diff options
| author | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
| commit | 8c4603c73675958efc960fbd4bb599a2909d106a (patch) | |
| tree | 714dc6fc9a1672d5fd7f89676b97e10959662abc | |
| parent | 990a8d0dbaefc996244097397259e92758b15cce (diff) | |
Source codes
392 files changed, 75260 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1cbaff1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +.vs/ +ComLightLib/x64/ +Whisper/x64/ +x64/ +Tools/CompressShaders/bin/ +Tools/CompressShaders/obj/ +Whisper/D3D/shaderData-Debug.inl +Whisper/D3D/shaderData-Release.inl +WhisperNet/bin/ +WhisperNet/obj/ +Examples/TranscribeCS/bin/ +Examples/TranscribeCS/obj/ +*.aps +*.json +*.user +Examples/MicrophoneCS/obj/ +Examples/MicrophoneCS/bin/
\ No newline at end of file diff --git a/ComLightLib/ComLightLib.vcxproj b/ComLightLib/ComLightLib.vcxproj new file mode 100644 index 0000000..ee7f07b --- /dev/null +++ b/ComLightLib/ComLightLib.vcxproj @@ -0,0 +1,116 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project DefaultTargets="Build" ToolsVersion="15.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup Label="ProjectConfigurations"> + <ProjectConfiguration Include="Debug|x64"> + <Configuration>Debug</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|x64"> + <Configuration>Release</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + </ItemGroup> + <ItemGroup> + <ClInclude Include="comLightClient.h" /> + <ClInclude Include="client\CComPtr.hpp" /> + <ClInclude Include="comLightServer.h" /> + <ClInclude Include="comLightCommon.h" /> + <ClInclude Include="server\freeThreadedMarshaller.h" /> + <ClInclude Include="hresult.h" /> + <ClInclude Include="server\ObjectRoot.hpp" /> + <ClInclude Include="pal\guiddef.h" /> + <ClInclude Include="server\Object.hpp" /> + <ClInclude Include="server\interfaceMap.h" /> + <ClInclude Include="server\RefCounter.hpp" /> + <ClInclude Include="Exception.hpp" /> + <ClInclude Include="streams.h" /> + <ClInclude Include="utils\guid_parse.hpp" /> + <ClInclude Include="pal\hresult.h" /> + <ClInclude Include="unknwn.h" /> + <ClInclude Include="utils\typeTraits.hpp" /> + </ItemGroup> + <ItemGroup> + <ClCompile Include="server\freeThreadedMarshaller.cpp" /> + </ItemGroup> + <ItemGroup> + <Text Include="Readme.txt" /> + </ItemGroup> + <PropertyGroup Label="Globals"> + <VCProjectVersion>15.0</VCProjectVersion> + <ProjectGuid>{52F486E7-830C-45D8-BE47-E76B5AAB2772}</ProjectGuid> + <Keyword>Win32Proj</Keyword> + <RootNamespace>ComLightLib</RootNamespace> + <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion> + </PropertyGroup> + <Import 
Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" /> + <ImportGroup Label="ExtensionSettings"> + </ImportGroup> + <ImportGroup Label="Shared"> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <PropertyGroup Label="UserMacros" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <LinkIncremental>true</LinkIncremental> + <OutDir>$(Platform)\$(Configuration)\</OutDir> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <LinkIncremental>false</LinkIncremental> + <OutDir>$(Platform)\$(Configuration)\</OutDir> + </PropertyGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <ClCompile> + 
<PrecompiledHeader>NotUsing</PrecompiledHeader> + <WarningLevel>Level3</WarningLevel> + <Optimization>Disabled</Optimization> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>_DEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + <LanguageStandard>stdcpp20</LanguageStandard> + </ClCompile> + <Link> + <SubSystem>Windows</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <ClCompile> + <PrecompiledHeader>NotUsing</PrecompiledHeader> + <WarningLevel>Level3</WarningLevel> + <Optimization>MaxSpeed</Optimization> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + <LanguageStandard>stdcpp20</LanguageStandard> + </ClCompile> + <Link> + <SubSystem>Windows</SubSystem> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> + <ImportGroup Label="ExtensionTargets"> + </ImportGroup> +</Project>
\ No newline at end of file diff --git a/ComLightLib/ComLightLib.vcxproj.filters b/ComLightLib/ComLightLib.vcxproj.filters new file mode 100644 index 0000000..aa2c81d --- /dev/null +++ b/ComLightLib/ComLightLib.vcxproj.filters @@ -0,0 +1,28 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup> + <ClInclude Include="pal\hresult.h" /> + <ClInclude Include="pal\guiddef.h" /> + <ClInclude Include="utils\guid_parse.hpp" /> + <ClInclude Include="unknwn.h" /> + <ClInclude Include="comLightClient.h" /> + <ClInclude Include="comLightServer.h" /> + <ClInclude Include="client\CComPtr.hpp" /> + <ClInclude Include="comLightCommon.h" /> + <ClInclude Include="server\RefCounter.hpp" /> + <ClInclude Include="server\interfaceMap.h" /> + <ClInclude Include="server\Object.hpp" /> + <ClInclude Include="utils\typeTraits.hpp" /> + <ClInclude Include="server\freeThreadedMarshaller.h" /> + <ClInclude Include="hresult.h" /> + <ClInclude Include="server\ObjectRoot.hpp" /> + <ClInclude Include="Exception.hpp" /> + <ClInclude Include="streams.h" /> + </ItemGroup> + <ItemGroup> + <ClCompile Include="server\freeThreadedMarshaller.cpp" /> + </ItemGroup> + <ItemGroup> + <Text Include="Readme.txt" /> + </ItemGroup> +</Project>
\ No newline at end of file diff --git a/ComLightLib/Exception.hpp b/ComLightLib/Exception.hpp new file mode 100644 index 0000000..57c1b78 --- /dev/null +++ b/ComLightLib/Exception.hpp @@ -0,0 +1,20 @@ +#pragma once + +namespace ComLight +{ + class Exception : public std::runtime_error + { + // I don't like C++ exceptions too much, but for some cases they are useful. + // You can throw ComLight::Exception from constructor, or from FinalConstruct() method, the library will catch & return the code from the class factory function. + // Unfortunately, for interface methods this doesn't work, the C++ parts of the library can't catch them without very complex trickery like code generation. + // You can still use this class in methods, but you'll need to catch them manually near the API boundary or the app will crash. + // C++ doesn't have an ABI, the framework can't catch C++ exception across the modules. + const HRESULT m_code; + + public: + + Exception( HRESULT hr ) : runtime_error( "ComLight HRESULT exception" ), m_code( hr ) { } + + HRESULT code() const { return m_code; } + }; +}
\ No newline at end of file diff --git a/ComLightLib/Readme.txt b/ComLightLib/Readme.txt new file mode 100644 index 0000000..1eabeec --- /dev/null +++ b/ComLightLib/Readme.txt @@ -0,0 +1,3 @@ +Copied from here: +https://github.com/Const-me/ComLightInterop/tree/master/ComLightLib +With only a few minor changes.
// ComLightLib/client/CComPtr.hpp
#pragma once
// std::swap
#include <utility>

namespace ComLight
{
	// COM smart pointer, very comparable to CComPtr from ATL.
	// Holds one reference to an object exposing AddRef() / Release(), and drops that reference in the destructor.
	template <class I>
	class CComPtr
	{
		I* p;

		// AddRef the current pointer, if any
		void callAddRef() const
		{
			if( nullptr == p )
				return;
			p->AddRef();
		}

	public:

		// Construct with nullptr
		CComPtr() : p( nullptr ) { }

		// Release the pointer, and reset to nullptr
		void release()
		{
			if( nullptr == p )
				return;
			p->Release();
			p = nullptr;
		}

		~CComPtr()
		{
			release();
		}

		// Attach without AddRef(): takes over a reference the caller already owns. The previously held pointer is released.
		void attach( I* raw )
		{
			release();
			p = raw;
		}

		// Detach without Release(), set this pointer to nullptr
		I* detach()
		{
			I* const result = p;
			p = nullptr;
			return result;
		}

		// Detach without Release() and place to the specified address, set this pointer to nullptr.
		// NOTE: *pp is read before it is overwritten, so it must be either nullptr or a valid pointer; an uninitialized destination is not safe here.
		template<class Other>
		void detach( Other** pp )
		{
			// If the argument points to a non-empty object, release the old instance: would leak memory otherwise.
			// The new value is stored before the old one is released, so the destination never holds a pointer that was already released.
			Other* const previous = *pp;
			( *pp ) = detach();
			if( nullptr != previous )
				previous->Release();
		}

		// Set and AddRef()
		void assign( I* raw )
		{
			// AddRef the new pointer before releasing the old one.
			// Doing it the other way around destroys the object when `raw` is the pointer already held and this is the last reference.
			if( nullptr != raw )
				raw->AddRef();
			I* const previous = p;
			p = raw;
			if( nullptr != previous )
				previous->Release();
		}

		// Exchange the pointers, reference counters stay the same
		void swap( CComPtr<I>& that )
		{
			std::swap( p, that.p );
		}

		// Set and AddRef()
		CComPtr( I* raw ) : p( raw )
		{
			callAddRef();
		}

		// Set and AddRef()
		CComPtr( const CComPtr<I>& that ) : CComPtr( that.p ) { }
		// Move constructor, leaves the other one empty
		CComPtr( CComPtr<I>&& that ) noexcept : p( that.p ) { that.p = nullptr; }

		// Set and AddRef()
		void operator=( I* raw )
		{
			assign( raw );
		}

		// Set and AddRef()
		void operator=( const CComPtr<I>& that )
		{
			assign( that.p );
		}

		// Move assignment operator, destroys the other one
		void operator=( CComPtr<I>&& that )
		{
			attach( that.detach() );
		}

		operator I*( ) const { return p; }
		I* operator -> () const { return p; }
		// Address of the wrapped raw pointer. The current value is NOT released; writing through the result while this pointer is non-empty leaks a reference.
		I** operator &() { return &p; }

		operator bool() const { return nullptr != p; }
	};
}
\ No newline at end of file diff --git a/ComLightLib/comLightClient.h b/ComLightLib/comLightClient.h new file mode 100644 index 0000000..3174c92 --- /dev/null +++ b/ComLightLib/comLightClient.h @@ -0,0 +1,23 @@ +#pragma once +#include "comLightCommon.h" +#include "client/CComPtr.hpp" +#include "utils/typeTraits.hpp" + +namespace ComLight +{ + namespace details + { + template<typename T> + inline constexpr void** castDoublePointerToVoid( T** pp ) + { + static_assert( pointersAssignable<IUnknown, T>(), "IID_PPV_ARGS macro should be used with IUnknown interfaces" ); + return reinterpret_cast<void**>( pp ); + } + } +} + +#ifdef IID_PPV_ARGS +#undef IID_PPV_ARGS +#endif + +#define IID_PPV_ARGS( pp ) decltype( **pp )::iid, ::ComLight::details::castDoublePointerToVoid( pp )
\ No newline at end of file diff --git a/ComLightLib/comLightCommon.h b/ComLightLib/comLightCommon.h new file mode 100644 index 0000000..c571910 --- /dev/null +++ b/ComLightLib/comLightCommon.h @@ -0,0 +1,11 @@ +#pragma once +#include "hresult.h" + +#ifdef _MSC_VER +#include <guiddef.h> +#else +#include "pal/guiddef.h" +using LPCTSTR = const char*; +#endif + +#include "unknwn.h"
// ComLightLib/comLightServer.h
// Umbrella header for code which implements objects: pulls in the common definitions, the smart pointer, and the server-side headers listed below.
#pragma once
#include "comLightCommon.h"
#include "client/CComPtr.hpp"

#include "server/ObjectRoot.hpp"
#include "server/interfaceMap.h"
#include "server/Object.hpp"
#include "server/freeThreadedMarshaller.h"

// DLLEXPORT marks a function with C linkage; on non-Microsoft compilers it also sets default symbol visibility.
#ifdef _MSC_VER
// On Windows, it's controlled by library.def module definition file. There's __declspec(dllexport), but it adds underscore, I don't like that.
#define DLLEXPORT extern "C"
#else
#define DLLEXPORT extern "C" __attribute__((visibility("default")))
#endif
\ No newline at end of file diff --git a/ComLightLib/hresult.h b/ComLightLib/hresult.h new file mode 100644 index 0000000..dbaee67 --- /dev/null +++ b/ComLightLib/hresult.h @@ -0,0 +1,26 @@ +#pragma once +#include <stdint.h> +#ifdef _MSC_VER +#include <winerror.h> +#include <OleCtl.h> +#else +#include "pal/hresult.h" +#endif + +#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } + +#ifndef _MSC_VER +inline constexpr HRESULT HRESULT_FROM_WIN32( int c ) +{ + return c < 0 ? c : ( ( 0xFFFF & c ) | 0x80070000 ); +} + +constexpr HRESULT OLE_E_BLANK = _HRESULT_TYPEDEF_( 0x80040007 ); +constexpr HRESULT E_BOUNDS = _HRESULT_TYPEDEF_( 0x8000000BL ); + +constexpr int ERROR_HANDLE_EOF = 38; +constexpr int ERROR_ALREADY_INITIALIZED = 1247; +#endif + +constexpr HRESULT E_EOF = HRESULT_FROM_WIN32( ERROR_HANDLE_EOF ); +constexpr HRESULT E_ALREADY_INITIALIZED = HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
// ComLightLib/pal/guiddef.h
#pragma once
#include <stddef.h>
#include <stdint.h>
#include <array>
// Tell utils/guid_parse.hpp that a GUID structure already exists, so it doesn't declare its own.
#ifndef GUID_DEFINED
#define GUID_DEFINED
#endif

// 128-bit identifier, with the same 4 fields as the Windows SDK structure of the same name.
struct GUID
{
	uint32_t Data1;
	uint16_t Data2;
	uint16_t Data3;
	std::array<uint8_t, 8> Data4;

	// True when all 4 fields are equal.
	// The bytes are compared one by one: std::array::operator== only became constexpr in C++20, and this way the comparison is a constant expression in earlier standards too.
	constexpr inline bool operator==( const GUID& that ) const
	{
		if( Data1 != that.Data1 || Data2 != that.Data2 || Data3 != that.Data3 )
			return false;
		for( size_t i = 0; i < 8; i++ )
			if( Data4[ i ] != that.Data4[ i ] )
				return false;
		return true;
	}

	// Declared explicitly because only C++20 derives != from ==
	constexpr inline bool operator!=( const GUID& that ) const
	{
		return !( *this == that );
	}
};

using REFIID = const GUID&;
// ComLightLib/pal/hresult.h
// Status codes and helpers for platforms without the Windows SDK.
#pragma once
#include <stdint.h>
using HRESULT = int32_t;
// The argument is parenthesised so the cast applies to the complete expression, not just to its first operand.
#define _HRESULT_TYPEDEF_(_sc) ((HRESULT)(_sc))
#define SEVERITY_ERROR 1
#define FACILITY_CONTROL 10

// Assemble a status code: severity goes to bit 31, facility to bits 16 and up, the code is OR-ed into the lowest bits.
// The arguments are not masked, callers are expected to pass values which fit their fields.
inline constexpr HRESULT MAKE_SCODE( uint32_t sev, uint32_t fac, uint32_t code )
{
	return (HRESULT)( ( sev << 31 ) | ( fac << 16 ) | code );
}

// ==== Copy-pasted from coreclr-master\src\pal\inc\rt\palrt.h ====
#define S_OK _HRESULT_TYPEDEF_(0x00000000L)
#define S_FALSE _HRESULT_TYPEDEF_(0x00000001L)

#define E_NOTIMPL _HRESULT_TYPEDEF_(0x80004001L)
#define E_NOINTERFACE _HRESULT_TYPEDEF_(0x80004002L)
#define E_UNEXPECTED _HRESULT_TYPEDEF_(0x8000FFFFL)
#define E_OUTOFMEMORY _HRESULT_TYPEDEF_(0x8007000EL)
#define E_INVALIDARG _HRESULT_TYPEDEF_(0x80070057L)
#define E_POINTER _HRESULT_TYPEDEF_(0x80004003L)
#define E_HANDLE _HRESULT_TYPEDEF_(0x80070006L)
#define E_ABORT _HRESULT_TYPEDEF_(0x80004004L)
#define E_FAIL _HRESULT_TYPEDEF_(0x80004005L)
#define E_ACCESSDENIED _HRESULT_TYPEDEF_(0x80070005L)
#define E_PENDING _HRESULT_TYPEDEF_(0x8000000AL)

#define DISP_E_PARAMNOTFOUND _HRESULT_TYPEDEF_(0x80020004L)
#define DISP_E_TYPEMISMATCH _HRESULT_TYPEDEF_(0x80020005L)
#define DISP_E_BADVARTYPE _HRESULT_TYPEDEF_(0x80020008L)
#define DISP_E_OVERFLOW _HRESULT_TYPEDEF_(0x8002000AL)
#define DISP_E_DIVBYZERO _HRESULT_TYPEDEF_(0x80020012L)

#define CLASS_E_CLASSNOTAVAILABLE _HRESULT_TYPEDEF_(0x80040111L)
#define CLASS_E_NOAGGREGATION _HRESULT_TYPEDEF_(0x80040110L)

#define CO_E_CLASSSTRING _HRESULT_TYPEDEF_(0x800401F3L)

#define MK_E_SYNTAX _HRESULT_TYPEDEF_(0x800401E4L)

#define STG_E_INVALIDFUNCTION _HRESULT_TYPEDEF_(0x80030001L)
#define STG_E_FILENOTFOUND _HRESULT_TYPEDEF_(0x80030002L)
#define STG_E_PATHNOTFOUND _HRESULT_TYPEDEF_(0x80030003L)
#define STG_E_WRITEFAULT _HRESULT_TYPEDEF_(0x8003001DL)
#define STG_E_FILEALREADYEXISTS _HRESULT_TYPEDEF_(0x80030050L)
#define STG_E_ABNORMALAPIEXIT _HRESULT_TYPEDEF_(0x800300FAL)

#define NTE_BAD_UID _HRESULT_TYPEDEF_(0x80090001L)
#define NTE_BAD_HASH _HRESULT_TYPEDEF_(0x80090002L)
#define NTE_BAD_KEY _HRESULT_TYPEDEF_(0x80090003L)
#define NTE_BAD_LEN _HRESULT_TYPEDEF_(0x80090004L)
#define NTE_BAD_DATA _HRESULT_TYPEDEF_(0x80090005L)
#define NTE_BAD_SIGNATURE _HRESULT_TYPEDEF_(0x80090006L)
#define NTE_BAD_VER _HRESULT_TYPEDEF_(0x80090007L)
#define NTE_BAD_ALGID _HRESULT_TYPEDEF_(0x80090008L)
#define NTE_BAD_FLAGS _HRESULT_TYPEDEF_(0x80090009L)
#define NTE_BAD_TYPE _HRESULT_TYPEDEF_(0x8009000AL)
#define NTE_BAD_KEY_STATE _HRESULT_TYPEDEF_(0x8009000BL)
#define NTE_BAD_HASH_STATE _HRESULT_TYPEDEF_(0x8009000CL)
#define NTE_NO_KEY _HRESULT_TYPEDEF_(0x8009000DL)
#define NTE_NO_MEMORY _HRESULT_TYPEDEF_(0x8009000EL)
#define NTE_SIGNATURE_FILE_BAD _HRESULT_TYPEDEF_(0x8009001CL)
#define NTE_FAIL _HRESULT_TYPEDEF_(0x80090020L)

#define CRYPT_E_HASH_VALUE _HRESULT_TYPEDEF_(0x80091007L)

#define TYPE_E_SIZETOOBIG _HRESULT_TYPEDEF_(0x800288C5L)
#define TYPE_E_DUPLICATEID _HRESULT_TYPEDEF_(0x800288C6L)

#define STD_CTL_SCODE(n) MAKE_SCODE(SEVERITY_ERROR, FACILITY_CONTROL, n)
#define CTL_E_OVERFLOW STD_CTL_SCODE(6)
#define CTL_E_OUTOFMEMORY STD_CTL_SCODE(7)
#define CTL_E_DIVISIONBYZERO STD_CTL_SCODE(11)
#define CTL_E_OUTOFSTACKSPACE STD_CTL_SCODE(28)
#define CTL_E_FILENOTFOUND STD_CTL_SCODE(53)
#define CTL_E_DEVICEIOERROR STD_CTL_SCODE(57)
#define CTL_E_PERMISSIONDENIED STD_CTL_SCODE(70)
#define CTL_E_PATHFILEACCESSERROR STD_CTL_SCODE(75)
#define CTL_E_PATHNOTFOUND STD_CTL_SCODE(76)

#define INET_E_CANNOT_CONNECT _HRESULT_TYPEDEF_(0x800C0004L)
#define INET_E_RESOURCE_NOT_FOUND _HRESULT_TYPEDEF_(0x800C0005L)
#define INET_E_OBJECT_NOT_FOUND _HRESULT_TYPEDEF_(0x800C0006L)
#define INET_E_DATA_NOT_AVAILABLE _HRESULT_TYPEDEF_(0x800C0007L)
#define INET_E_DOWNLOAD_FAILURE _HRESULT_TYPEDEF_(0x800C0008L)
#define INET_E_CONNECTION_TIMEOUT _HRESULT_TYPEDEF_(0x800C000BL)
#define INET_E_UNKNOWN_PROTOCOL _HRESULT_TYPEDEF_(0x800C000DL)

#define DBG_PRINTEXCEPTION_C _HRESULT_TYPEDEF_(0x40010006L)
// ==== Done pasting ====

// True for zero and positive codes
inline constexpr bool SUCCEEDED( HRESULT hr )
{
	return hr >= 0;
}

// True for negative codes
inline constexpr bool FAILED( HRESULT hr )
{
	return hr < 0;
}
\ No newline at end of file diff --git a/ComLightLib/server/Object.hpp b/ComLightLib/server/Object.hpp new file mode 100644 index 0000000..d2e3257 --- /dev/null +++ b/ComLightLib/server/Object.hpp @@ -0,0 +1,139 @@ +#pragma once +#include <type_traits> +#include "../comLightClient.h" +#include "../utils/typeTraits.hpp" +#include "../Exception.hpp" + +namespace ComLight +{ + namespace details + { + GENERATE_HAS_MEMBER( implQueryInterface ); + GENERATE_HAS_MEMBER( implAddRef ); + GENERATE_HAS_MEMBER( implRelease ); + } + + // Outer class of objects, implements IUnknown methods, also the class factory. The type argument must be your class implementing your interfaces, inherited from ObjectRoot<I> + template<class T> + class Object : public T + { + public: + Object() = default; + + template<typename ... Args> + Object( Args&& ... args ) : T{ std::forward<Args>( args )... } {}; + + inline virtual ~Object() override { } + + // Implement IUnknown methods + HRESULT COMLIGHTCALL QueryInterface( REFIID riid, void** ppvObject ) override + { + static_assert( details::has_member_implQueryInterface<T>::value, "Your object class must inherit from ComLight::ObjectRoot" ); + + if( nullptr == ppvObject ) + return E_POINTER; + + if( T::implQueryInterface( riid, ppvObject ) ) + return S_OK; + if( T::queryExtraInterfaces( riid, ppvObject ) ) + return S_OK; + + if( riid == IUnknown::iid() ) + { + ComLight::IUnknown* unk = T::getUnknown(); + unk->AddRef(); + *ppvObject = unk; + return S_OK; + } + + return E_NOINTERFACE; + } + + uint32_t COMLIGHTCALL AddRef() override + { + static_assert( details::has_member_implAddRef<T>::value, "Your object class must inherit from ComLight::ObjectRoot" ); + return T::implAddRef(); + } + + uint32_t COMLIGHTCALL Release() override + { + static_assert( details::has_member_implRelease<T>::value, "Your object class must inherit from ComLight::ObjectRoot" ); + const uint32_t ret = T::implRelease(); + if( 0 == ret ) + { + T::FinalRelease(); + delete this; + } 
+ return ret; + } + + // Create a new object on the heap, store in smart pointer + static inline HRESULT create( CComPtr<Object<T>>& result ) + { + CComPtr<Object<T>> ptr; + try + { + ptr = new Object<T>(); // The RefCounter constructor creates it with ref.counter 0. But then CComPtr constructor calls AddRef so we have RC=1 after this line. + + HRESULT hr = ptr->internalFinalConstruct(); + if( FAILED( hr ) ) + return hr; + + hr = ptr->FinalConstruct(); + if( FAILED( hr ) ) + return hr; + + ptr.swap( result ); + return S_OK; + } + catch( const Exception& ex ) + { + return ex.code(); + } + } + + // Create a new object on the heap, store in smart pointer + template<typename ... Args> + static inline HRESULT create( CComPtr<Object<T>>& result, Args&& ... args ) + { + CComPtr<Object<T>> ptr; + try + { + ptr = new Object<T>( std::forward<Args>( args )... ); + + HRESULT hr = ptr->internalFinalConstruct(); + if( FAILED( hr ) ) + return hr; + + hr = ptr->FinalConstruct(); + if( FAILED( hr ) ) + return hr; + + ptr.swap( result ); + return S_OK; + } + catch( const Exception& ex ) + { + return ex.code(); + } + catch( HRESULT hr ) + { + return hr; + } + } + + // Create a new object on the heap, return one of it's interfaces. The caller is assumed to take ownership of the new object. + template<class I> + static inline HRESULT create( I** pp ) + { + if( pp == nullptr ) + return E_POINTER; + + static_assert( details::pointersAssignable<I, T>(), "Object::create can't cast object to the requested interface" ); + CComPtr<Object<T>> ptr; + CHECK( create( ptr ) ); + ptr.detach( pp ); + return S_OK; + } + }; +}
\ No newline at end of file diff --git a/ComLightLib/server/ObjectRoot.hpp b/ComLightLib/server/ObjectRoot.hpp new file mode 100644 index 0000000..1cc8f4d --- /dev/null +++ b/ComLightLib/server/ObjectRoot.hpp @@ -0,0 +1,51 @@ +#pragma once +#include "RefCounter.hpp" +#include "../comLightCommon.h" +#include "../utils/typeTraits.hpp" + +namespace ComLight +{ + // Base class of objects, implements reference counting, also a few lifetime methods. + // The template argument is the interface you want clients to get when they ask for IID_IUnknown. By convention, that pointer defines object's identity. + template<class I> + class ObjectRoot : public RefCounter, public I + { + protected: + + inline HRESULT internalFinalConstruct() + { + return S_FALSE; + } + + inline HRESULT FinalConstruct() + { + return S_FALSE; + } + + inline void FinalRelease() { } + + IUnknown* getUnknown() + { + static_assert( details::pointersAssignable<IUnknown, I>(), "The interface doesn't derive from IUnknown" ); + return static_cast<I*>( this ); + } + + bool queryExtraInterfaces( REFIID riid, void **ppvObject ) const + { + return false; + } + + // Implement query interface with 2 entries, IUnknown and I. + bool implQueryInterface( REFIID riid, void** ppvObject ) + { + if( riid == I::iid() || riid == IUnknown::iid() ) + { + I* const result = this; + result->AddRef(); + *ppvObject = result; + return true; + } + return false; + } + }; +}
// ComLightLib/server/RefCounter.hpp
#pragma once
#include <atomic>
#include <assert.h>
#include <limits.h>
// uint32_t is used by the methods below; the header previously relied on another header having declared it.
#include <stdint.h>

namespace ComLight
{
	// Very base class of objects, implements reference counting.
	class RefCounter
	{
		std::atomic_uint referenceCounter;

	public:

		// New objects start with the counter at zero
		RefCounter() : referenceCounter( 0 ) { }

		inline virtual ~RefCounter() { }

		RefCounter( const RefCounter &that ) = delete;
		RefCounter( RefCounter &&that ) = delete;

	protected:

		// Atomically increment the counter, return the new value
		uint32_t implAddRef()
		{
			return ++referenceCounter;
		}

		// Atomically decrement the counter, return the new value
		uint32_t implRelease()
		{
			// Might be a good idea to use locks, at least in debug builds. They're much slower than atomics, but with locks it's possible to detect when 2 threads call release at the same time, for object with counter = 1.
			// It's a memory management bug, but it would be nice if debug builds would handle that case gracefully.
			const uint32_t rc = --referenceCounter;
			// UINT_MAX here means the counter was already zero before the decrement
			assert( rc != UINT_MAX );
			return rc;
		}
	};
}
\ No newline at end of file diff --git a/ComLightLib/server/freeThreadedMarshaller.cpp b/ComLightLib/server/freeThreadedMarshaller.cpp new file mode 100644 index 0000000..fc1ea80 --- /dev/null +++ b/ComLightLib/server/freeThreadedMarshaller.cpp @@ -0,0 +1,17 @@ +#include "freeThreadedMarshaller.h" +#ifdef _MSC_VER +#include <combaseapi.h> + +HRESULT ComLight::details::createFreeThreadedMarshaller( IUnknown* pUnkOuter, IUnknown** ppUnkMarshal ) +{ + return ::CoCreateFreeThreadedMarshaler( (LPUNKNOWN)pUnkOuter, (LPUNKNOWN *)ppUnkMarshal ); +} + +bool ComLight::details::queryMarshallerInterface( REFIID riid, void **ppvObject, IUnknown* marshaller ) +{ + if( riid != IID_IMarshal || nullptr == marshaller ) + return false; + const HRESULT hr = marshaller->QueryInterface( IID_IMarshal, ppvObject ); + return SUCCEEDED( hr ) ? true : false; +} +#endif
// ComLightLib/server/freeThreadedMarshaller.h
#pragma once
#ifdef _MSC_VER
#include "../comLightCommon.h"

namespace ComLight
{
	namespace details
	{
		// Implemented in freeThreadedMarshaller.cpp
		HRESULT createFreeThreadedMarshaller( IUnknown* pUnkOuter, IUnknown** ppUnkMarshal );
		bool queryMarshallerInterface( REFIID riid, void **ppvObject, IUnknown* marshaller );
	}
}

// Place inside the body of an object class. The macro adds 3 members:
// - a private smart pointer which holds the marshaller object;
// - internalFinalConstruct() which creates the marshaller, passing getUnknown() of the object as the outer unknown;
// - queryExtraInterfaces() which forwards the request to queryMarshallerInterface() together with that marshaller.
// The expansion ends in a `protected:` section.
#define DECLARE_FREE_THREADED_MARSHALLER() \
private: \
ComLight::CComPtr<ComLight::IUnknown> m_freeThreadedMarshaller; \
protected: \
HRESULT internalFinalConstruct() \
{ \
	return ComLight::details::createFreeThreadedMarshaller( getUnknown(), &m_freeThreadedMarshaller ); \
} \
bool queryExtraInterfaces( REFIID riid, void **ppvObject ) const \
{ \
	return ComLight::details::queryMarshallerInterface( riid, ppvObject, m_freeThreadedMarshaller ); \
}

#else
// With other compilers the macro expands to nothing
#define DECLARE_FREE_THREADED_MARSHALLER()
#endif
\ No newline at end of file diff --git a/ComLightLib/server/interfaceMap.h b/ComLightLib/server/interfaceMap.h new file mode 100644 index 0000000..605ed33 --- /dev/null +++ b/ComLightLib/server/interfaceMap.h @@ -0,0 +1,31 @@ +#pragma once +#include "../utils/typeTraits.hpp" + +// Unlike ATL, the interface map is optional for ComLight. +// If you won't declare a map, the object will support 2 interfaces: IUnknown, and whatever template argument was passed to ObjectRoot class. +#define BEGIN_COM_MAP() \ +protected: \ +bool implQueryInterface( REFIID iid, void** ppvObject ) { + +#define END_COM_MAP() return false; } + +namespace ComLight +{ + namespace details + { + template<typename I, typename C> + inline bool tryReturnInterface( REFIID iid, C* pThis, void** ppvResult ) + { + static_assert( pointersAssignable<IUnknown, I>(), "Trying to implement an interface that doesn't derive from IUnknown" ); + static_assert( pointersAssignable<I, C>(), "Declared support for an interface, but the class doesn't implement it" ); + if( I::iid() != iid ) + return false; + I* const result = pThis; + result->AddRef(); + *ppvResult = result; + return true; + } + } +} + +#define COM_INTERFACE_ENTRY( I ) if( ComLight::details::tryReturnInterface<I>( iid, this, ppvObject ) ) return true;
\ No newline at end of file diff --git a/ComLightLib/streams.h b/ComLightLib/streams.h new file mode 100644 index 0000000..87680e1 --- /dev/null +++ b/ComLightLib/streams.h @@ -0,0 +1,61 @@ +#pragma once +#include <vector> +#include "comLightCommon.h" + +// COM interfaces to marshal streams across the interop. +namespace ComLight +{ + enum struct eSeekOrigin : uint8_t + { + Begin = 0, + Current = 1, + End = 2 + }; + + namespace details + { + template<class E> + inline size_t sizeofVector( const std::vector<E>& vec ) + { + return sizeof( E ) * vec.size(); + } + } + + // COM interface for readonly stream. You'll get these interfaces what you use [ReadStream] attribute in C#. + struct DECLSPEC_NOVTABLE iReadStream : public IUnknown + { + DEFINE_INTERFACE_ID( "006af6db-734e-4595-8c94-19304b2389ac" ); + + virtual HRESULT COMLIGHTCALL read( void* lpBuffer, int nNumberOfBytesToRead, int &lpNumberOfBytesRead ) = 0; + virtual HRESULT COMLIGHTCALL seek( int64_t offset, eSeekOrigin origin ) = 0; + virtual HRESULT COMLIGHTCALL getPosition( int64_t& position ) = 0; + virtual HRESULT COMLIGHTCALL getLength( int64_t& length ) = 0; + + template<class E> + inline HRESULT read( std::vector<E>& vec ) + { + const int cb = (int)details::sizeofVector( vec ); + int cbRead = 0; + CHECK( read( vec.data(), cb, cbRead ) ); + if( cbRead >= cb ) + return S_OK; + return E_EOF; + } + }; + + // COM interface for readonly stream. You'll get these interfaces what you use [WriteStream] attribute in C#. + struct DECLSPEC_NOVTABLE iWriteStream : public IUnknown + { + DEFINE_INTERFACE_ID( "d7c3eb39-9170-43b9-ba98-2ea1f2fed8a8" ); + + virtual HRESULT COMLIGHTCALL write( const void* lpBuffer, int nNumberOfBytesToWrite ) = 0; + virtual HRESULT COMLIGHTCALL flush() = 0; + + template<class E> + inline HRESULT write( const std::vector<E>& vec ) + { + const int cb = (int)details::sizeofVector( vec ); + return write( vec.data(), cb ); + } + }; +}
\ No newline at end of file diff --git a/ComLightLib/unknwn.h b/ComLightLib/unknwn.h new file mode 100644 index 0000000..5f38359 --- /dev/null +++ b/ComLightLib/unknwn.h @@ -0,0 +1,36 @@ +#pragma once +#include <type_traits> + +// Calling conventions +#ifdef _MSC_VER +#define COMLIGHTCALL __stdcall +#define DECLSPEC_NOVTABLE __declspec(novtable) +#elif defined(__GNUC__) || defined(__clang__) +#if defined(__i386__) +#define COMLIGHTCALL __attribute__((stdcall)) +#else +#define COMLIGHTCALL +#endif +#define DECLSPEC_NOVTABLE +#else +#error Unsupported C++ compiler +#endif + +#include "utils/guid_parse.hpp" + +#define DEFINE_INTERFACE_ID( guidString ) static constexpr GUID iid() { return ::ComLight::make_guid( guidString ); } + +namespace ComLight +{ + // This thing is binary compatible with IUnknown from Windows SDK. See DesktopClient demo project, it uses normal COM interop in .NET framework 4.7 to call my implementation. + struct DECLSPEC_NOVTABLE IUnknown + { + DEFINE_INTERFACE_ID( "00000000-0000-0000-c000-000000000046" ); + + virtual HRESULT COMLIGHTCALL QueryInterface( REFIID riid, void **ppvObject ) = 0; + + virtual uint32_t COMLIGHTCALL AddRef() = 0; + + virtual uint32_t COMLIGHTCALL Release() = 0; + }; +}
// ComLightLib/utils/guid_parse.hpp
// https://github.com/tobias-loew/constexpr-GUID-cpp-11

//-------------------------------------------------------------------------------------------------------
// constexpr GUID parsing
// Written by Alexander Bessonov
// Written by Tobias Loew
//
// Licensed under the MIT license.
//-------------------------------------------------------------------------------------------------------

#pragma once
#include <stdexcept>
#include <string>
#include <cassert>
#include <cstddef>
#include <cstdint>

#if !defined(GUID_DEFINED)
#define GUID_DEFINED
// Fallback definition, only used when nothing else has defined GUID_DEFINED
struct GUID {
	uint32_t Data1;
	uint16_t Data2;
	uint16_t Data3;
	uint8_t Data4[ 8 ];
};
#endif

namespace ComLight
{
	namespace details
	{
		constexpr const size_t short_guid_form_length = 36;	// XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
		constexpr const size_t long_guid_form_length = 38;	// {XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}

		// Convert a single hexadecimal digit, either case, to its value in [ 0 .. 15 ].
		// Throws std::domain_error for any other character; in a constant expression this becomes a compile-time error.
		constexpr uint8_t parse_hex_digit( const char c )
		{
			using namespace std::string_literals;
			return
				( '0' <= c && c <= '9' )
				? c - '0'
				: ( 'a' <= c && c <= 'f' )
				? 10 + c - 'a'
				: ( 'A' <= c && c <= 'F' )
				? 10 + c - 'A'
				:
				throw std::domain_error{ "invalid character in GUID"s };
		}

		// Parse 2 hex digits at ptr[ 0 .. 1 ] into a byte, most significant digit first
		constexpr uint8_t parse_hex_uint8_t( const char *ptr )
		{
			return static_cast<uint8_t>( ( parse_hex_digit( ptr[ 0 ] ) << 4 ) | parse_hex_digit( ptr[ 1 ] ) );
		}

		// Parse 4 hex digits at ptr[ 0 .. 3 ] into a 16-bit value, most significant digit first
		constexpr uint16_t parse_hex_uint16_t( const char *ptr )
		{
			return static_cast<uint16_t>( ( parse_hex_uint8_t( ptr ) << 8 ) | parse_hex_uint8_t( ptr + 2 ) );
		}

		// Parse 8 hex digits at ptr[ 0 .. 7 ] into a 32-bit value, most significant digit first
		constexpr uint32_t parse_hex_uint32_t( const char *ptr )
		{
			// The upper half is widened to uint32_t before the shift.
			// Without the cast, the uint16_t is promoted to signed int, and a leading digit of 8 or above shifts into the sign bit.
			return ( static_cast<uint32_t>( parse_hex_uint16_t( ptr ) ) << 16 ) | parse_hex_uint16_t( ptr + 4 );
		}

		// True when the 4 group separators of the XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX form are all '-'
		constexpr bool has_guid_separators( const char *begin )
		{
			return begin[ 8 ] == '-' && begin[ 13 ] == '-' && begin[ 18 ] == '-' && begin[ 23 ] == '-';
		}

		// Parse the 36-character form, without braces, starting at `begin`.
		// Throws std::domain_error when a separator is missing, or when a digit position holds a non-hexadecimal character.
		constexpr GUID parse_guid( const char *begin )
		{
			using namespace std::string_literals;
			return !has_guid_separators( begin )
				? throw std::domain_error{ "Missing '-' separator in GUID"s }
				: GUID{
					parse_hex_uint32_t( begin ),
					parse_hex_uint16_t( begin + 8 + 1 ),
					parse_hex_uint16_t( begin + 8 + 1 + 4 + 1 ),
					{
						parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 ),
						parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 ),
						parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 ),
						parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 + 2 ),
						parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 + 2 + 2 ),
						parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 + 2 + 2 + 2 ),
						parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 + 2 + 2 + 2 + 2 ),
						parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 + 2 + 2 + 2 + 2 + 2 )
					}
				};
		}

		// Validate the overall shape of the string: length N is 36, or 38 with surrounding braces; then parse it.
		// N is the string length, without the terminating zero.
		constexpr GUID make_guid_helper( const char *str, size_t N )
		{
			using namespace std::string_literals;
			using namespace details;

			return ( !( N == long_guid_form_length || N == short_guid_form_length ) )
				? throw std::domain_error{ "String GUID of the form {XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX} or XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX is expected"s }
				: ( N == long_guid_form_length && ( str[ 0 ] != '{' || str[ long_guid_form_length - 1 ] != '}' ) )
				? throw std::domain_error{ "Missing opening or closing brace"s }

				: parse_guid( str + ( N == long_guid_form_length ? 1 : 0 ) );
		}


		// Parse a GUID from a string literal; usable in constant expressions.
		// Accepts both "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX" and "{XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}".
		template<size_t N>
		constexpr GUID make_guid( const char( &str )[ N ] )
		{
			// N counts the terminating zero of the literal
			return make_guid_helper( str, N - 1 );
		}
	}
	using details::make_guid;
}
// ComLightLib/utils/typeTraits.hpp
#pragma once
#include <type_traits>

namespace ComLight
{
	namespace details
	{
		// True when a value of type `TValue*` can be assigned to a variable of type `TResult*`
		template<class TResult, class TValue>
		constexpr bool pointersAssignable()
		{
			// See this for why `&` is required: https://stackoverflow.com/a/52429468/126995
			return std::is_assignable<TResult*&, TValue*>::value;
		}
	}
}

// https://en.wikibooks.org/wiki/More_C++_Idioms/Member_Detector
// GENERATE_HAS_MEMBER( foo ) defines two templates:
//   HasMember_foo<T>::RESULT   constexpr bool, true when T has a member named `foo`
//   has_member_foo<T>          the same value wrapped into std::integral_constant
// How it works: Derived inherits from both T and Fallback, and Fallback always has a member with that name.
// When T has one too, `U::member` is ambiguous, so the overload returning No& is discarded and the U* overload returning Yes& is selected.
// T is used as a base class, so it needs to be a class type one can inherit from.
#define GENERATE_HAS_MEMBER(member)                                               \
                                                                                  \
template < class T >                                                              \
class HasMember_##member                                                          \
{                                                                                 \
private:                                                                          \
    using Yes = char[2];                                                          \
    using No = char[1];                                                           \
                                                                                  \
    struct Fallback { int member; };                                              \
    struct Derived : T, Fallback { };                                             \
                                                                                  \
    template < class U >                                                          \
    static No& test ( decltype(U::member)* );                                     \
    template < typename U >                                                       \
    static Yes& test ( U* );                                                      \
                                                                                  \
public:                                                                           \
    static constexpr bool RESULT = sizeof(test<Derived>(nullptr)) == sizeof(Yes); \
};                                                                                \
                                                                                  \
template < class T >                                                              \
struct has_member_##member                                                        \
: public std::integral_constant<bool, HasMember_##member<T>::RESULT>              \
{                                                                                 \
};

// ComputeShaders/ComputeShaders.cpp
// Empty placeholder function, the body does nothing
void fnComputeShaders()
{
}
<?xml version="1.0" encoding="utf-8"?>
<!-- ComputeShaders/ComputeShaders.vcxproj -->
<!-- Static library project whose real payload is the HLSL compute shaders listed in the FxCompile items below. -->
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <!-- Each Configuration|Platform pair is declared exactly once; the duplicated entries were merged. -->
  <ItemGroup Label="ProjectConfigurations">
    <ProjectConfiguration Include="Release|x64">
      <Configuration>Release</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Debug|x64">
      <Configuration>Debug</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
  </ItemGroup>
  <PropertyGroup Label="Globals">
    <VCProjectVersion>16.0</VCProjectVersion>
    <Keyword>Win32Proj</Keyword>
    <ProjectGuid>{1c39d386-96d0-47a1-bbfa-68bbdb24439c}</ProjectGuid>
    <RootNamespace>ComputeShaders</RootNamespace>
    <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
    <ConfigurationType>StaticLibrary</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <!-- NOTE(review): Debug uses the same C++ settings as Release (release libraries, WPO, NDEBUG); only the FxCompile options differ. Confirm this is intended. -->
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
    <ConfigurationType>StaticLibrary</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
  <ImportGroup Label="ExtensionSettings">
  </ImportGroup>
  <ImportGroup Label="Shared">
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <PropertyGroup Label="UserMacros" />
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <MultiProcFXC>true</MultiProcFXC>
    <OutDir>$(Platform)\$(Configuration)\</OutDir>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <OutDir>$(Platform)\$(Configuration)\</OutDir>
    <MultiProcFXC>true</MultiProcFXC>
  </PropertyGroup>
  <!-- Release: shaders are compiled as shader model 5.0 compute shaders with default optimizations -->
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
    </ClCompile>
    <Link>
      <SubSystem>
      </SubSystem>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
    <FxCompile>
      <ShaderModel>5.0</ShaderModel>
      <ShaderType>Compute</ShaderType>
    </FxCompile>
  </ItemDefinitionGroup>
  <!-- Debug: same as Release, plus unoptimized shaders with debugging information -->
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
    </ClCompile>
    <Link>
      <SubSystem>
      </SubSystem>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
    <FxCompile>
      <ShaderModel>5.0</ShaderModel>
      <ShaderType>Compute</ShaderType>
      <DisableOptimizations>true</DisableOptimizations>
      <EnableDebuggingInformation>true</EnableDebuggingInformation>
    </FxCompile>
  </ItemDefinitionGroup>
  <ItemGroup>
    <ClCompile Include="ComputeShaders.cpp" />
  </ItemGroup>
  <!-- Shader entry points, one compiled shader per file -->
  <ItemGroup>
    <FxCompile Include="add.hlsl" />
    <FxCompile Include="addInPlace.hlsl" />
    <FxCompile Include="addRepeat.hlsl" />
    <FxCompile Include="addRepeat64.hlsl" />
    <FxCompile Include="addRepeatGelu.hlsl" />
    <FxCompile Include="addRepeatGelu64.hlsl" />
    <FxCompile Include="addRepeatScale.hlsl" />
    <FxCompile Include="addRows.hlsl" />
    <FxCompile Include="convolutionMain.hlsl" />
    <FxCompile Include="convolutionMain2.hlsl" />
    <FxCompile Include="convolutionMain2Fixed.hlsl" />
    <FxCompile Include="convolutionPrep1.hlsl" />
    <FxCompile Include="convolutionPrep2.hlsl" />
    <FxCompile Include="copyConvert.hlsl" />
    <FxCompile Include="copyTranspose.hlsl" />
    <FxCompile Include="diagMaskInf.hlsl" />
    <FxCompile Include="flashAttention.hlsl" />
    <FxCompile Include="flashAttentionCompat1.hlsl" />
    <FxCompile Include="flashAttentionCompat2.hlsl" />
    <FxCompile Include="flashAttentionCompat3.hlsl" />
    <FxCompile Include="fmaRepeat1.hlsl" />
    <FxCompile Include="fmaRepeat164.hlsl" />
    <FxCompile Include="fmaRepeat2.hlsl" />
    <FxCompile Include="matReshapePanels.hlsl" />
    <FxCompile Include="mulMatByRow.hlsl" />
    <FxCompile Include="mulMatByRow64.hlsl" />
    <FxCompile Include="mulMatByRowTiled.hlsl" />
    <FxCompile Include="mulMatByRowTiled64.hlsl" />
    <FxCompile Include="mulMatByRowTiledEx.hlsl" />
    <FxCompile Include="mulMatByScalar.hlsl" />
    <FxCompile Include="mulMatDotMain.hlsl" />
    <FxCompile Include="mulMatDotReshape.hlsl" />
    <FxCompile Include="mulMatMadMain.hlsl" />
    <FxCompile Include="mulMatTiled.hlsl" />
    <FxCompile Include="mulMatTiled64.hlsl" />
    <FxCompile Include="mulMatTiledEx.hlsl" />
    <FxCompile Include="norm.hlsl" />
    <FxCompile Include="normCompat.hlsl" />
    <FxCompile Include="normFixed.hlsl" />
    <FxCompile Include="normFixed64.hlsl" />
    <FxCompile Include="scaleInPlace.hlsl" />
    <FxCompile Include="softMax.hlsl" />
    <FxCompile Include="softMax64.hlsl" />
    <FxCompile Include="softMaxCompat.hlsl" />
    <FxCompile Include="softMaxFixed.hlsl" />
    <FxCompile Include="zeroMemory.hlsl" />
  </ItemGroup>
  <!-- Shared HLSL include files; listed as None so they are not compiled on their own -->
  <ItemGroup>
    <None Include="componentwiseBinaryOp.hlsli" />
    <None Include="flashAttentionCommon.hlsli" />
    <None Include="fp64Utils.hlsli" />
    <None Include="groupReduce.hlsli" />
    <None Include="groupReduce64.hlsli" />
    <None Include="miscUtils.hlsli" />
    <None Include="repeatUtils.hlsli" />
  </ItemGroup>
  <ItemGroup>
    <Text Include="Readme.txt" />
  </ItemGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
  <ImportGroup Label="ExtensionTargets">
  </ImportGroup>
</Project>
<?xml version="1.0" encoding="utf-8"?>
<!-- ComputeShaders/ComputeShaders.vcxproj.filters -->
<!-- No filters are defined; every item sits at the project root. Items are listed alphabetically, same as in the .vcxproj. -->
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <ItemGroup>
    <ClCompile Include="ComputeShaders.cpp" />
  </ItemGroup>
  <!-- Shader entry points -->
  <ItemGroup>
    <FxCompile Include="add.hlsl" />
    <FxCompile Include="addInPlace.hlsl" />
    <FxCompile Include="addRepeat.hlsl" />
    <FxCompile Include="addRepeat64.hlsl" />
    <FxCompile Include="addRepeatGelu.hlsl" />
    <FxCompile Include="addRepeatGelu64.hlsl" />
    <FxCompile Include="addRepeatScale.hlsl" />
    <FxCompile Include="addRows.hlsl" />
    <FxCompile Include="convolutionMain.hlsl" />
    <FxCompile Include="convolutionMain2.hlsl" />
    <FxCompile Include="convolutionMain2Fixed.hlsl" />
    <FxCompile Include="convolutionPrep1.hlsl" />
    <FxCompile Include="convolutionPrep2.hlsl" />
    <FxCompile Include="copyConvert.hlsl" />
    <FxCompile Include="copyTranspose.hlsl" />
    <FxCompile Include="diagMaskInf.hlsl" />
    <FxCompile Include="flashAttention.hlsl" />
    <FxCompile Include="flashAttentionCompat1.hlsl" />
    <FxCompile Include="flashAttentionCompat2.hlsl" />
    <FxCompile Include="flashAttentionCompat3.hlsl" />
    <FxCompile Include="fmaRepeat1.hlsl" />
    <FxCompile Include="fmaRepeat164.hlsl" />
    <FxCompile Include="fmaRepeat2.hlsl" />
    <FxCompile Include="matReshapePanels.hlsl" />
    <FxCompile Include="mulMatByRow.hlsl" />
    <FxCompile Include="mulMatByRow64.hlsl" />
    <FxCompile Include="mulMatByRowTiled.hlsl" />
    <FxCompile Include="mulMatByRowTiled64.hlsl" />
    <FxCompile Include="mulMatByRowTiledEx.hlsl" />
    <FxCompile Include="mulMatByScalar.hlsl" />
    <FxCompile Include="mulMatDotMain.hlsl" />
    <FxCompile Include="mulMatDotReshape.hlsl" />
    <FxCompile Include="mulMatMadMain.hlsl" />
    <FxCompile Include="mulMatTiled.hlsl" />
    <FxCompile Include="mulMatTiled64.hlsl" />
    <FxCompile Include="mulMatTiledEx.hlsl" />
    <FxCompile Include="norm.hlsl" />
    <FxCompile Include="normCompat.hlsl" />
    <FxCompile Include="normFixed.hlsl" />
    <FxCompile Include="normFixed64.hlsl" />
    <FxCompile Include="scaleInPlace.hlsl" />
    <FxCompile Include="softMax.hlsl" />
    <FxCompile Include="softMax64.hlsl" />
    <FxCompile Include="softMaxCompat.hlsl" />
    <FxCompile Include="softMaxFixed.hlsl" />
    <FxCompile Include="zeroMemory.hlsl" />
  </ItemGroup>
  <!-- Shared HLSL include files -->
  <ItemGroup>
    <None Include="componentwiseBinaryOp.hlsli" />
    <None Include="flashAttentionCommon.hlsli" />
    <None Include="fp64Utils.hlsli" />
    <None Include="groupReduce.hlsli" />
    <None Include="groupReduce64.hlsli" />
    <None Include="miscUtils.hlsli" />
    <None Include="repeatUtils.hlsli" />
  </ItemGroup>
  <ItemGroup>
    <Text Include="Readme.txt" />
  </ItemGroup>
</Project>
\ No newline at end of file diff --git a/ComputeShaders/Readme.txt b/ComputeShaders/Readme.txt new file mode 100644 index 0000000..18d089e --- /dev/null +++ b/ComputeShaders/Readme.txt @@ -0,0 +1,11 @@ +This project compiles all the compute shaders which implement the model. + +Many shaders come in 2 versions, something.hlsl and something64.hlsl. + +The version with the `64` suffix is used on AMD GPUs, the version without the suffix is used on nVidia and Intel GPUs. + +Not all of these shaders are actually used for anything. +Some of them implement binary compatibility with the reference CPU version, and are not used unless you change the `constexpr` flags in the MlContext C++ class. +Such shaders often require FP64 support, which is an optional feature in D3D11. +The CompressShaders tool detects such shaders by looking at the SFI0 chunk in the binary, and outputs a bitmap of the FP64 shaders. +This way, missing FP64 hardware support shouldn’t break the library.
// ComputeShaders/add.hlsl
// Element-wise addition: supplies the compute() function consumed by the shader body in componentwiseBinaryOp.hlsli
inline float compute( float a, float b )
{
	const float sum = a + b;
	return sum;
}

#include "componentwiseBinaryOp.hlsli"
// ComputeShaders/addInPlace.hlsl
// Compute result = result + arg0, in place. Each thread group handles one row, selected by the 3D group index.
#ifndef THREADS
#define THREADS 512
#endif

Buffer<float> arg0: register( t0 );
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	// Size and strides of the destination tensor
	uint4 size: packoffset( c0 );
	uint4 strides: packoffset( c1 );
	// Strides of the argument tensor; note the c2 register is not used by this shader
	uint4 argStrides: packoffset( c3 );
}

// Offset of the first element of the row, from the three outer coordinates and the matching strides
inline uint rowOffset( uint3 idx, uint4 stride )
{
	return idx.x * stride.y + idx.y * stride.z + idx.z * stride.w;
}

[ numthreads( THREADS, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	const uint dstRow = rowOffset( group, strides );
	const uint srcRow = rowOffset( group, argStrides );
	const uint dstEnd = dstRow + size.x * strides.x;

	// Every iteration of the loop, the complete group advances by THREADS elements
	const uint dstStep = THREADS * strides.x;
	const uint srcStep = THREADS * argStrides.x;

	uint dst = dstRow + thread * strides.x;
	uint src = srcRow + thread * argStrides.x;
	while( dst < dstEnd )
	{
		result[ dst ] = result[ dst ] + arg0[ src ];
		dst += dstStep;
		src += srcStep;
	}
}
// ComputeShaders/addRepeat.hlsl
// Compute tensor = tensor + repeat( pattern, tensor ) in 1 shot, without VRAM allocations
// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
RWBuffer<float> tensor: register( u0 );
Buffer<float> pattern: register( t0 );

cbuffer Constants: register( b0 )
{
	uint4 tensorSize: packoffset( c0 );
	uint4 tensorStrides: packoffset( c1 );
	uint4 patternSize: packoffset( c2 );
	uint4 patternStrides: packoffset( c3 );
}

// The default can be overridden by defining THREADS before including this file
#ifndef THREADS
#define THREADS 256
#endif

// tensorIteratorState, rowOffset, ROW_LOOP and ROW_LOOP_EX used below are not defined in this file, they are expected from this include
#include "repeatUtils.hlsli"

// Add a value to a single element of the tensor
inline void computeSimple( uint idx, float add )
{
	float f = tensor[ idx ];
	f += add;
	tensor[ idx ] = f;
}

[ numthreads( THREADS, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
	// The modulo wraps the group index into the pattern, that's what repeats the pattern along the outer dimensions
	uint rsi = rowOffset( group % patternSize.yzw, patternStrides );

	if( patternSize[ 0 ] == 1 )
	{
		// The pattern only has 1 column - broadcasting over the row
		const float p = pattern[ rsi ];
		ROW_LOOP( it )
			computeSimple( it.x, p );
	}
	else if( patternSize[ 0 ] <= THREADS )
	{
		// pattern size doesn't exceed thread group size: load pattern value outside of the loop
		// Largest multiple of the pattern row length which fits in the group; the remaining threads exit immediately
		const uint threadsPerGroup = THREADS - ( THREADS % patternSize[ 0 ] );
		if( thread >= threadsPerGroup )
			return;

		const float p = pattern[ rsi + ( thread % patternSize[ 0 ] ) * patternStrides[ 0 ] ];
		// Presumably ROW_LOOP_EX advances by threadsPerGroup elements per iteration, so the thread stays on the same pattern column — TODO confirm in repeatUtils.hlsli
		ROW_LOOP_EX( it, threadsPerGroup, tensorStrides )
			computeSimple( it.x, p );
	}
	else
	{
		// Pattern rows are larger than the thread group, need to stream from both buffers
		const uint rsiInc = THREADS * patternStrides[ 0 ];
		const uint rsiDec = patternSize[ 0 ] * patternStrides[ 0 ];
		const uint rsiEnd = rsi + rsiDec;
		rsi += thread * patternStrides[ 0 ];

		ROW_LOOP( it )
		{
			float f = tensor[ it.x ];
			float p = pattern[ rsi ];
			// Advance the pattern offset, wrapping to the start of the pattern row once past the end
			rsi += rsiInc;
			if( rsi >= rsiEnd )
				rsi -= rsiDec;
			f += p;
			tensor[ it.x ] = f;
		}
	}
}
// ComputeShaders/addRepeat64.hlsl
// Same shader as addRepeat.hlsl, compiled with 64 threads per group instead of the default
#define THREADS 64
#include "addRepeat.hlsl"
// ComputeShaders/addRepeatGelu.hlsl
// Compute tensor = GELU( tensor + repeat( pattern, tensor ) ) in 1 shot, without VRAM allocations
// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
RWBuffer<float> tensor: register( u0 );
Buffer<float> pattern: register( t0 );
// Lookup table for the gelu() function below; the elements are converted with f16tof32 after loading
Buffer<uint> lookupTable: register( t1 );

cbuffer Constants: register( b0 )
{
	uint4 tensorSize: packoffset( c0 );
	uint4 tensorStrides: packoffset( c1 );
	uint4 patternSize: packoffset( c2 );
	uint4 patternStrides: packoffset( c3 );
}

// The default can be overridden by defining THREADS before including this file
#ifndef THREADS
#define THREADS 1024
#endif

// tensorIteratorState, rowOffset, ROW_LOOP, ROW_LOOP_EX and fp16Rounded used below are not defined in this file, they are expected from these includes
#include "repeatUtils.hlsli"
#include "miscUtils.hlsli"

// GELU activation of a single value
inline float gelu( float x )
{
#if 1
	// Table lookup; presumably fp16Rounded returns the bits of x rounded to FP16, and the table has an entry for every such value — TODO confirm in miscUtils.hlsli
	const uint index = fp16Rounded( x );
	const uint res16 = lookupTable[ index ];
	return f16tof32( res16 );
#else
	// This version is much slower, at least on AMD, despite saving these VRAM loads.
	const float GELU_COEF_A = 0.044715;
	const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
	return 0.5 * x * ( 1.0 + tanh( SQRT_2_OVER_PI * x * ( 1.0 + GELU_COEF_A * x * x ) ) );
#endif
}

// Add a value to a single element of the tensor, then apply GELU to the sum
inline void computeSimple( uint idx, float add )
{
	float f = tensor[ idx ];
	f += add;
	f = gelu( f );
	tensor[ idx ] = f;
}

[ numthreads( THREADS, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
	// The modulo wraps the group index into the pattern, that's what repeats the pattern along the outer dimensions
	uint rsi = rowOffset( group % patternSize.yzw, patternStrides );

	if( patternSize[ 0 ] == 1 )
	{
		// The pattern only has 1 column - broadcasting over the row
		const float p = pattern[ rsi ];
		ROW_LOOP( it )
			computeSimple( it.x, p );
	}
	else if( patternSize[ 0 ] <= THREADS )
	{
		// pattern size doesn't exceed thread group size: load pattern value outside of the loop
		// Largest multiple of the pattern row length which fits in the group; the remaining threads exit immediately
		const uint threadsPerGroup = THREADS - ( THREADS % patternSize[ 0 ] );
		if( thread >= threadsPerGroup )
			return;

		const float p = pattern[ rsi + ( thread % patternSize[ 0 ] ) * patternStrides[ 0 ] ];
		// Presumably ROW_LOOP_EX advances by threadsPerGroup elements per iteration, so the thread stays on the same pattern column — TODO confirm in repeatUtils.hlsli
		ROW_LOOP_EX( it, threadsPerGroup, tensorStrides )
			computeSimple( it.x, p );
	}
	else
	{
		// Pattern rows are larger than the thread group, need to stream from both buffers
		const uint rsiInc = THREADS * patternStrides[ 0 ];
		const uint rsiDec = patternSize[ 0 ] * patternStrides[ 0 ];
		const uint rsiEnd = rsi + rsiDec;
		rsi += thread * patternStrides[ 0 ];

		ROW_LOOP( it )
		{
			float f = tensor[ it.x ];
			float p = pattern[ rsi ];
			// Advance the pattern offset, wrapping to the start of the pattern row once past the end
			rsi += rsiInc;
			if( rsi >= rsiEnd )
				rsi -= rsiDec;
			f += p;
			f = gelu( f );
			tensor[ it.x ] = f;
		}
	}
}
// ComputeShaders/addRepeatGelu64.hlsl
// Same shader as addRepeatGelu.hlsl, compiled with 64 threads per group instead of the default
#define THREADS 64
#include "addRepeatGelu.hlsl"
// ComputeShaders/addRepeatScale.hlsl
// Compute tensor = ( tensor + repeat( pattern, tensor ) ) * scale in 1 shot, without VRAM allocations
// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
RWBuffer<float> tensor: register( u0 );
Buffer<float> pattern: register( t0 );

cbuffer Constants: register( b0 )
{
	uint4 tensorSize: packoffset( c0 );
	uint4 tensorStrides: packoffset( c1 );
	uint4 patternSize: packoffset( c2 );
	uint4 patternStrides: packoffset( c3 );
	// The multiplier applied after the addition
	float scalingMul : packoffset( c4.x );
}

// The default can be overridden by defining THREADS before including this file
#ifndef THREADS
#define THREADS 512
#endif

// tensorIteratorState, rowOffset, ROW_LOOP and ROW_LOOP_EX used below are not defined in this file, they are expected from this include
#include "repeatUtils.hlsli"

// Add a value to a single element of the tensor, then scale the sum
inline void computeSimple( uint idx, float add )
{
	float f = tensor[ idx ];
	f += add;
	f *= scalingMul;
	tensor[ idx ] = f;
}

[ numthreads( THREADS, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
	// The modulo wraps the group index into the pattern, that's what repeats the pattern along the outer dimensions
	uint rsi = rowOffset( group % patternSize.yzw, patternStrides );

	if( patternSize[ 0 ] == 1 )
	{
		// The pattern only has 1 column - broadcasting over the row
		const float p = pattern[ rsi ];
		ROW_LOOP( it )
			computeSimple( it.x, p );
	}
	else if( patternSize[ 0 ] <= THREADS )
	{
		// pattern size doesn't exceed thread group size: load pattern value outside of the loop
		// Largest multiple of the pattern row length which fits in the group; the remaining threads exit immediately
		const uint threadsPerGroup = THREADS - ( THREADS % patternSize[ 0 ] );
		if( thread >= threadsPerGroup )
			return;

		const float p = pattern[ rsi + ( thread % patternSize[ 0 ] ) * patternStrides[ 0 ] ];
		// Presumably ROW_LOOP_EX advances by threadsPerGroup elements per iteration, so the thread stays on the same pattern column — TODO confirm in repeatUtils.hlsli
		ROW_LOOP_EX( it, threadsPerGroup, tensorStrides )
			computeSimple( it.x, p );
	}
	else
	{
		// Pattern rows are larger than the thread group, need to stream from both buffers
		const uint rsiInc = THREADS * patternStrides[ 0 ];
		const uint rsiDec = patternSize[ 0 ] * patternStrides[ 0 ];
		const uint rsiEnd = rsi + rsiDec;
		rsi += thread * patternStrides[ 0 ];

		ROW_LOOP( it )
		{
			float f = tensor[ it.x ];
			float p = pattern[ rsi ];
			// Advance the pattern offset, wrapping to the start of the pattern row once past the end
			rsi += rsiInc;
			if( rsi >= rsiEnd )
				rsi -= rsiDec;
			f += p;
			f *= scalingMul;
			tensor[ it.x ] = f;
		}
	}
}
// ComputeShaders/addRows.hlsl
// For every input token, write the sum of its token embedding row and its positional embedding row.
// One thread group per output row; the row index is group.x
#ifndef THREADS
#define THREADS 256
#endif

// dec.tokenEmbedding tensor
Buffer<float> tokenEmbedding: register( t0 );
// dec.positionalEmbedding tensor
Buffer<float> positionalEmbedding: register( t1 );
// R32_UINT buffer with the input tokens
Buffer<uint> embd: register( t2 );
// Output tensor
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint rowLength: packoffset( c0.x );
	uint pastTokensCount: packoffset( c0.y );
	uint outputRowStride: packoffset( c0.z );
	uint2 embStrides: packoffset( c1.x );
	uint2 posStrides: packoffset( c1.z );
}

[ numthreads( THREADS, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	const uint row = group.x;
	// Source rows: the token selects the embedding row, the absolute position selects the positional row
	const uint token = embd[ row ];
	const uint position = row + pastTokensCount;

	const uint dstRow = row * outputRowStride;
	const uint dstEnd = dstRow + rowLength;

	// Every iteration, the complete group advances by THREADS columns
	const uint tokStep = THREADS * embStrides.x;
	const uint posStep = THREADS * posStrides.x;

	uint dst = dstRow + thread;
	uint srcTok = token * embStrides.y + thread * embStrides.x;
	uint srcPos = position * posStrides.y + thread * posStrides.x;
	while( dst < dstEnd )
	{
		result[ dst ] = tokenEmbedding[ srcTok ] + positionalEmbedding[ srcPos ];
		dst += THREADS;
		srcTok += tokStep;
		srcPos += posStep;
	}
}
// ComputeShaders/componentwiseBinaryOp.hlsli
// Shared body of element-wise binary shaders: result = compute( arg0, arg1 )
// The including file must define `float compute( float a, float b )` before including this one.
// One thread group of 32 threads per row; the row index is group.x
Buffer<float> arg0: register( t0 );
Buffer<float> arg1: register( t1 );
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 src0_elements: packoffset( c0 );
	uint4 src0_strides: packoffset( c1 );
	uint4 src1_elements: packoffset( c2 );
	uint4 src1_strides: packoffset( c3 );
	uint4 result_elements: packoffset( c4 );
	uint4 result_strides: packoffset( c5 );
}

[ numthreads( 32, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	const uint row = group.x;
	const uint rowLength = src0_elements.x;
	// Only the second argument uses an element stride; first argument and result advance by 1 element per column
	const uint src1Stride = src1_strides.x;

	const uint src0Row = row * src0_strides.y;
	const uint src0End = src0Row + rowLength;
	const uint src1Step = 32 * src1Stride;

	uint src0 = src0Row + thread;
	uint src1 = row * src1_strides.y + thread * src1Stride;
	uint dst = row * result_strides.y + thread;
	while( src0 < src0End )
	{
		result[ dst ] = compute( arg0[ src0 ], arg1[ src1 ] );
		src0 += 32;
		src1 += src1Step;
		dst += 32;
	}
}
// ComputeShaders/convolutionMain.hlsl
// ggml_compute_forward_conv_1d_1s_f16_f32, GGML_TASK_COMPUTE implementation
// Dispatch [ ne10, ne02, 1 ] thread groups
Buffer<float> arg0: register( t0 );
Buffer<float> arg1: register( t1 );
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 src0_elements: packoffset( c0 );
	uint4 src0_strides: packoffset( c1 );
	uint4 src1_elements: packoffset( c2 );
	uint4 result_elements: packoffset( c4 );
	uint4 result_strides: packoffset( c5 );
}

// horizontalSumCompatNew used below is not defined in this file, it's expected from this include
#include "groupReduce.hlsli"

// Compute the dot product of `len` elements starting at arg0[ s0 ] and arg1[ s1 ], and add it to `acc`.
// Only thread 0 of the group updates `acc`.
inline void computeDotProduct( uint s0, uint s1, uint len, uint thread, inout float acc )
{
	float curr = 0;
	// Complete 32-element vectors: each of the 32 threads accumulates its own lane
	const uint completeVectors = len / 32;
	uint i;
	for( i = 0; i < completeVectors; i++, s0 += 32, s1 += 32 )
		curr = mad( arg0[ s0 + thread ], arg1[ s1 + thread ], curr );

	// Presumably reduces the per-thread values so that thread 0 ends up with the sum in `curr` — TODO confirm in groupReduce.hlsli
	horizontalSumCompatNew( thread, curr );

	if( 0 == thread )
	{
		// Remaining len % 32 elements are handled by thread 0 alone, accumulating in FP64
		const uint rem = len % 32;
		if( 0 != rem )
		{
			double f64 = curr;
			for( i = 0; i < rem; i++ )
			{
				precise float a = arg0[ s0 + i ];
				precise float b = arg1[ s1 + i ];
				precise float prod = a * b;
				f64 += prod;
			}
			curr = (float)f64;
		}
		acc += curr;
	}
}

// roundUp32 used below is not defined in this file, it's expected from this include
#include "miscUtils.hlsli"

[ numthreads( 32, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	// Coordinates of the output element produced by this thread group
	const uint i1 = group.y;
	const uint i0 = group.x;

	const uint ne00 = src0_elements[ 0 ];
	const uint nk = ne00;
	// Half of the kernel size, rounded down; the loop below runs over 2 * nh + 1 kernel taps
	const int nh = (int)( nk / 2 );

	const uint ne01 = src0_elements[ 1 ];
	// Length of the rows passed to computeDotProduct
	const int ew0 = roundUp32( ne01 );

	float res = 0;
	for( int k = -nh; k <= nh; k++ )
	{
		const uint source0 = i1 * ew0 * ne00 + uint( nh + k ) * ew0;
		const uint source1 = uint( i0 + nh + k ) * ew0;
		computeDotProduct( source0, source1, ew0, thread, res );
	}

	// Only thread 0 has the complete result
	if( 0 != thread )
		return;

	const uint nb1 = result_strides[ 1 ];
	const uint rdi = i1 * nb1 + i0;
	result[ rdi ] = res;
}
// ComputeShaders/convolutionMain2.hlsl
// ggml_compute_forward_conv_1d_2s_f16_f32, GGML_TASK_COMPUTE implementation
// Dispatch [ ne10 / 2, ne02, 1 ] thread groups
Buffer<float> arg0: register( t0 );
Buffer<float> arg1: register( t1 );
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 src0_elements: packoffset( c0 );
	uint4 src0_strides: packoffset( c1 );
	uint4 src1_elements: packoffset( c2 );
	uint4 result_elements: packoffset( c4 );
	uint4 result_strides: packoffset( c5 );
}

#include "groupReduce.hlsli"

// Compute the dot product of `len` elements starting at arg0[ s0 ] and arg1[ s1 ], and add it to `acc`.
// Only thread 0 of the group updates `acc`.
inline void computeDotProduct( uint s0, uint s1, uint len, uint thread, inout float acc )
{
	const uint end0 = s0 + len;
	float partial = 0;
	// Each of the 32 threads accumulates every 32-nd element
	for( uint i0 = s0 + thread, i1 = s1 + thread; i0 < end0; i0 += 32, i1 += 32 )
		partial = mad( arg0[ i0 ], arg1[ i1 ], partial );

	horizontalSumCompatNew( thread, partial );
	if( 0 == thread )
		acc += partial;
}

#include "miscUtils.hlsli"

[ numthreads( 32, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	const uint ne00 = src0_elements[ 0 ];
	// Length of the rows passed to computeDotProduct
	const int ew0 = roundUp32( src0_elements[ 1 ] );

	// Offsets of the first rows in both arguments; the second one moves by 2 rows per output element
	uint s0 = group.y * ew0 * ne00;
	uint s1 = group.x * 2 * ew0;

	// The original implementation did the following:
	// int nh = (int)( nk / 2 );
	// for( int k = -nh; k <= nh; k++ )
	// What we're doing instead:
	// for( uint len = ( nk / 2 ) * 2 + 1, i = 0; i < len; i++ )
	// len = ( nk / 2 ) * 2 + 1 is equal to ( nk | 1 )
	const uint s0End = s0 + ( ne00 | 1u ) * ew0;

	float res = 0;
	while( s0 < s0End )
	{
		computeDotProduct( s0, s1, ew0, thread, res );
		s0 += ew0;
		s1 += ew0;
	}

	// Only thread 0 has the complete result, and only that thread stores the output
	if( 0 == thread )
	{
		const uint rdi = group.y * result_strides[ 1 ] + group.x;
		result[ rdi ] = res;
	}
}
\ No newline at end of file diff --git a/ComputeShaders/convolutionMain2Fixed.hlsl b/ComputeShaders/convolutionMain2Fixed.hlsl new file mode 100644 index 0000000..d21dcbb --- /dev/null +++ b/ComputeShaders/convolutionMain2Fixed.hlsl @@ -0,0 +1,119 @@ +// Optimized version of convolutionMain2.hlsl for kernel size = 3 +// Dispatch [ ( ( ne10 / 2 ) + TILE_Y - 1 ) / TILE_Y, ne02, 1 ] thread groups of this shader +#ifndef TILE_Y +static const uint TILE_Y = 8; +#endif +#ifndef THREADS +static const uint THREADS = 64; +#endif + +Buffer<float> arg0: register( t0 ); +Buffer<float> arg1: register( t1 ); +RWBuffer<float> result: register( u0 ); + +cbuffer Constants: register( b0 ) +{ + uint4 src0_elements: packoffset( c0 ); + uint4 src0_strides: packoffset( c1 ); + uint4 src1_elements: packoffset( c2 ); + uint4 result_elements: packoffset( c4 ); + uint4 result_strides: packoffset( c5 ); +} + +// The accumulators we're after +groupshared float resTemp[ TILE_Y ][ THREADS ]; + +// Multiply + accumulate the specified row +inline void accumulate( float a0, float a1, const uint resultRow, const uint thread ) +{ + float acc = resTemp[ resultRow ][ thread ]; + acc = mad( a0, a1, acc ); + resTemp[ resultRow ][ thread ] = acc; +} + +// Accumulate one column of the tile: 3 kernel values from arg0 against ( height * 2 + 1 ) consecutive rows of arg1, into `height` output rows of resTemp +inline void convolutionTile( const uint s0, uint s1, const uint thread, const uint stride, const uint height ) +{ + // Load 3 rows from arg0 + const float3 a0 = float3( arg0[ s0 ], arg0[ s0 + stride ], arg0[ s0 + stride * 2 ] ); + + // Row 0 + float a1 = arg1[ s1 ]; + accumulate( a0[ 0 ], a1, 0, thread ); + s1 += stride; + + for( uint i = 1; i < height; i++ ) + { + // Row i*2-1 + // Odd-indexed rows only contribute to a single output row, after being multiplied by kernel row #1 + a1 = arg1[ s1 ]; + accumulate( a0[ 1 ], a1, i - 1, thread ); + s1 += stride; + + // Row i*2, contributes to 2 output rows corresponding to kernel rows #0 and #2 + a1 = arg1[ s1 ]; + accumulate( a0[ 2 ], a1, i - 1, thread ); + accumulate( a0[ 0 ], a1, i, thread ); + s1 += stride; + } +
+ // Row height*2 - 1 + a1 = arg1[ s1 ]; + accumulate( a0[ 1 ], a1, height - 1, thread ); + s1 += stride; + + // Row height*2 + a1 = arg1[ s1 ]; + accumulate( a0[ 2 ], a1, height - 1, thread ); +} + +#include "miscUtils.hlsli" + +[ numthreads( THREADS, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + uint i; + // Zero out the accumulators + for( i = 0; i < TILE_Y; i++ ) + resTemp[ i ][ thread ] = 0.0; + GroupMemoryBarrierWithGroupSync(); + + const uint i1 = group.y; + const uint i0 = group.x * TILE_Y * 2; + // The last tile along X may be incomplete, clamp the count of output rows + const uint height = min( TILE_Y, ( src1_elements.x / 2 ) - group.x * TILE_Y ); + + const uint ne00 = src0_elements[ 0 ]; + const uint ne01 = src0_elements[ 1 ]; + const int ew0 = roundUp32( ne01 ); + + uint s0 = i1 * ew0 * ne00; + const uint s0End = s0 + ew0; + uint s1 = i0 * ew0; + s0 += thread; + s1 += thread; + for( ; s0 < s0End; s0 += THREADS, s1 += THREADS ) + convolutionTile( s0, s1, thread, ew0, height ); + + GroupMemoryBarrierWithGroupSync(); + + // Now we need horizontal sums of these shared accumulators, i.e. reduce [height][THREADS] shared array into [height][1] column + for( i = THREADS / 2; i > 0; i /= 2 ) + { + if( thread < i ) + { + for( uint j = 0; j < height; j++ ) + { + float sum = resTemp[ j ][ thread ]; + sum += resTemp[ j ][ thread + i ]; + resTemp[ j ][ thread ] = sum; + } + } + GroupMemoryBarrierWithGroupSync(); + } + + // And finally, store that column to global memory + if( thread >= height ) + return; + const uint nb1 = result_strides[ 1 ]; + const uint rdi = i1 * nb1 + group.x * TILE_Y + thread; + result[ rdi ] = resTemp[ thread ][ 0 ]; +}
\ No newline at end of file diff --git a/ComputeShaders/convolutionPrep1.hlsl b/ComputeShaders/convolutionPrep1.hlsl new file mode 100644 index 0000000..528ff16 --- /dev/null +++ b/ComputeShaders/convolutionPrep1.hlsl @@ -0,0 +1,39 @@ +// ggml_compute_forward_conv_1d_1s_f16_f32, prepare kernel data (src0) +// Dispatch [ ne01, ne02, 1 ] thread groups +Buffer<float> arg0: register( t0 ); +RWBuffer<float> result: register( u0 ); + +cbuffer Constants: register( b0 ) +{ + uint4 src0_elements: packoffset( c0 ); + uint4 src0_strides: packoffset( c1 ); +} + +// Round up the number to the next multiple of 32 +inline uint roundUp32( uint x ) +{ + return ( x + 31 ) & ( ~31u ); +} + +[ numthreads( 32, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + const uint nb01 = src0_strides[ 1 ]; + const uint nb02 = src0_strides[ 2 ]; + + const uint ne00 = src0_elements[ 0 ]; + const uint ne01 = src0_elements[ 1 ]; + const uint ew0 = roundUp32( ne01 ); + + const uint i02 = group.y; + const uint i01 = group.x; + + uint rsi = i02 * nb02 + i01 * nb01; + const uint rsiEnd = rsi + ne00; + uint rdi = i02 * ew0 * ne00 + i01; + rsi += thread; + rdi += thread * ew0; + const uint rdiInc = 32 * ew0; + + // Each thread copies every 32nd element of the source row; consecutive source elements are written ew0 elements apart in the result + for( ; rsi < rsiEnd; rsi += 32, rdi += rdiInc ) + result[ rdi ] = arg0[ rsi ]; +}
\ No newline at end of file diff --git a/ComputeShaders/convolutionPrep2.hlsl b/ComputeShaders/convolutionPrep2.hlsl new file mode 100644 index 0000000..a7e7172 --- /dev/null +++ b/ComputeShaders/convolutionPrep2.hlsl @@ -0,0 +1,43 @@ +// ggml_compute_forward_conv_1d_1s_f16_f32, prepare source data (src1) +// Dispatch [ ne11, 1, 1 ] thread groups +Buffer<float> arg1: register( t0 ); +RWBuffer<float> result: register( u0 ); + +cbuffer Constants: register( b0 ) +{ + uint4 src0_elements: packoffset( c0 ); + uint4 src1_elements: packoffset( c2 ); + uint4 src1_strides: packoffset( c3 ); +} + +#include "miscUtils.hlsli" + +[ numthreads( 32, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + const uint i11 = group.x; + + const uint ne00 = src0_elements[ 0 ]; + const uint ne01 = src0_elements[ 1 ]; + const uint ne10 = src1_elements[ 0 ]; + const uint nb11 = src1_strides[ 1 ]; + + const uint nk = ne00; + const uint nh = nk / 2; + const int ew0 = roundUp32( ne01 ); + + uint rsi = i11 * nb11; + // The output starts nh rows of ew0 elements into the result buffer; this shader does not write the first nh rows + uint rdi = nh * ew0 + i11; + const uint rdiInc = ew0 * 32; + const uint rsiEnd = rsi + ne10; + + rsi += thread; + rdi += thread * ew0; + + // Each thread copies every 32nd element of the source row; consecutive source elements are written ew0 elements apart in the result + for( ; rsi < rsiEnd; rsi += 32, rdi += rdiInc ) + { + float f = arg1[ rsi ]; + // adjustFp16() is not defined in this file, presumably it comes from miscUtils.hlsli + f = adjustFp16( f ); + result[ rdi ] = f; + } +}
\ No newline at end of file diff --git a/ComputeShaders/copyConvert.hlsl b/ComputeShaders/copyConvert.hlsl new file mode 100644 index 0000000..399147f --- /dev/null +++ b/ComputeShaders/copyConvert.hlsl @@ -0,0 +1,50 @@ +// ggml_compute_forward_dup_f32 when we only need to convert types, but not reshape the tensor +// Dispatch [ ne01, ne02, ne03 ] thread groups of this shader +Buffer<float> arg0: register( t0 ); +RWBuffer<float> result: register( u0 ); + +cbuffer Constants: register( b0 ) +{ + uint4 src0_elements: packoffset( c0 ); + uint4 src0_strides: packoffset( c1 ); + bool downcastFp32 : packoffset( c2.x ); +} + +#include "miscUtils.hlsli" + +[ numthreads( 32, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + const uint nb00 = src0_strides[ 0 ]; + const uint nb01 = src0_strides[ 1 ]; + const uint nb02 = src0_strides[ 2 ]; + const uint nb03 = src0_strides[ 3 ]; + + const uint ne00 = src0_elements[ 0 ]; + const uint ne01 = src0_elements[ 1 ]; + const uint ne02 = src0_elements[ 2 ]; + const uint ne03 = src0_elements[ 3 ]; + + const uint i01 = group.x; + const uint i02 = group.y; + const uint i03 = group.z; + + const uint rs = ne00 * nb00; + // Index of the destination row in the dense output, where i01 is the fastest-changing index: + // id = i01 + i02 * ne01 + i03 * ne01 * ne02 + // The previous formula multiplied i02 by ne02, which produced colliding row indices whenever ne01 != ne02 and ne02 > 1 + const uint id = ( i03 * ne02 + i02 ) * ne01 + i01; + + uint rsi = i01 * nb01 + i02 * nb02 + i03 * nb03; + uint rdi = id * rs; + + const uint rsiEnd = rsi + rs; + rsi += thread; + rdi += thread; + // Each thread copies every 32nd element of the row, optionally passing the values through adjustFp16() + for( ; rsi < rsiEnd; rsi += 32, rdi += 32 ) + { + float f = arg0[ rsi ]; + [branch] + if( downcastFp32 ) + f = adjustFp16( f ); + result[ rdi ] = f; + } +}
\ No newline at end of file diff --git a/ComputeShaders/copyTranspose.hlsl b/ComputeShaders/copyTranspose.hlsl new file mode 100644 index 0000000..fc3e82f --- /dev/null +++ b/ComputeShaders/copyTranspose.hlsl @@ -0,0 +1,56 @@ +// ggml_compute_forward_dup_f32 when we actually need to reshape the tensor +// Dispatch [ ne01, ne02, ne03 ] thread groups of this shader +Buffer<float> arg0: register( t0 ); +RWBuffer<float> result: register( u0 ); + +cbuffer Constants: register( b0 ) +{ + uint4 src0_elements: packoffset( c0 ); + uint4 src0_strides: packoffset( c1 ); + bool downcastFp32 : packoffset( c2.x ); +} + +#include "miscUtils.hlsli" + +[ numthreads( 32, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + const uint nb00 = src0_strides[ 0 ]; + const uint nb01 = src0_strides[ 1 ]; + const uint nb02 = src0_strides[ 2 ]; + const uint nb03 = src0_strides[ 3 ]; + + const uint ne00 = src0_elements[ 0 ]; + const uint ne01 = src0_elements[ 1 ]; + const uint ne02 = src0_elements[ 2 ]; + const uint ne03 = src0_elements[ 3 ]; + + const uint i01 = group.x; + const uint i02 = group.y; + const uint i03 = group.z; + + // We need the following integer: i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02 + // We want to minimize the count of integer multiplications + // Also, DXBC assembly features `imad` instruction which computes a*b+c for integers, the actual hardware hopefully has an equivalent + // i03*ne00*ne01*ne02 + i02*ne00*ne01 + i01*ne00 + // ( i03*ne01*ne02 + i02*ne01 + i01 ) * ne00 + // ( ( i03*ne02 + i02) * ne01 + i01 ) * ne00 + uint rdi = ( ( i03 * ne02 + i02 ) * ne01 + i01 ) * ne00; + + const uint rdiEnd = rdi + ne00; + + uint rsi = i01 * nb01 + i02 * nb02 + i03 * nb03; + const uint rsiInc = 32 * nb00; + + rdi += thread; + rsi += thread * nb00; + + // The destination row is written densely, the source row is read with the nb00 stride between elements + for( ; rdi < rdiEnd; rdi += 32, rsi += rsiInc ) + { + float f = arg0[ rsi ]; + [branch] + if( downcastFp32 ) + f = adjustFp16( f ); + result[ rdi ] = f; + } +}
\ No newline at end of file diff --git a/ComputeShaders/diagMaskInf.hlsl b/ComputeShaders/diagMaskInf.hlsl new file mode 100644 index 0000000..18e3938 --- /dev/null +++ b/ComputeShaders/diagMaskInf.hlsl @@ -0,0 +1,30 @@ +// ggml_compute_forward_diag_mask_inf_f32 +// For the row selected by the thread group, overwrites the elements at column indices above ( n_past + j ) with negative infinity; j is group.x +RWBuffer<float> result: register( u0 ); + +cbuffer Constants: register( b0 ) +{ + uint4 elements: packoffset( c0 ); + uint4 strides: packoffset( c1 ); + uint n_past : packoffset( c2.x ); +} + +// Bit pattern of the FP32 negative infinity +static const float negativeInfinity = asfloat( 0xff800000 ); + +[numthreads( 32, 1, 1 )] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + const uint k = group.y; + const uint j = group.x; + + // Start of the row + uint rdi = k * strides[ 2 ] + j * strides[ 1 ]; + // End of the row + const uint rdiEnd = rdi + elements[ 0 ] * strides[ 0 ]; + // First index to write in this thread + rdi += ( n_past + j + thread + 1 ) * strides[ 0 ]; + // Index increment + const uint rdiInc = 32 * strides[ 0 ]; + + for( ; rdi < rdiEnd; rdi += rdiInc ) + result[ rdi ] = negativeInfinity; +}
\ No newline at end of file diff --git a/ComputeShaders/flashAttention.hlsl b/ComputeShaders/flashAttention.hlsl new file mode 100644 index 0000000..65212d0 --- /dev/null +++ b/ComputeShaders/flashAttention.hlsl @@ -0,0 +1,170 @@ +// Ported from ggml_compute_forward_flash_attn_f16 +// Dispatch with [ neq1*neq2*neq3, 1, 1 ] thread groups + +#include "flashAttentionCommon.hlsli" +Buffer<uint> lookupTable: register( t3 ); +#include "groupReduce.hlsli" + +// Dot product of `len` elements of two read-only buffers; the result is only correct on the thread #0 of the group +inline void computeDotProduct( Buffer<float> buff0, Buffer<float> buff1, uint s0, uint s1, const uint len, const uint thread, inout float acc ) +{ + acc = 0; + const uint s0End = s0 + len; + s0 += thread; + s1 += thread; + for( ; s0 < s0End; s0 += 32, s1 += 32 ) + acc = mad( buff0[ s0 ], buff1[ s1 ], acc ); + + horizontalSum( thread, acc ); +} + +// Same as above, for the case when the second buffer is a read-write one +inline void computeDotProduct( Buffer<float> buff0, RWBuffer<float> buff1, uint s0, uint s1, const uint len, const uint thread, inout float acc ) +{ + acc = 0; + const uint s0End = s0 + len; + s0 += thread; + s1 += thread; + for( ; s0 < s0End; s0 += 32, s1 += 32 ) + acc = mad( buff0[ s0 ], buff1[ s1 ], acc ); + + horizontalSum( thread, acc ); +} + +// Multiply `length` elements of the temp buffer starting at index i by the multiplier, optionally rounding the products to FP16 precision +void scaleTempVector( uint i, const uint length, const uint thread, const float multiplier, bool round ) +{ + const uint end = i + length; + for( i += thread; i < end; i += 32 ) + { + float f = temp[ i ]; + f *= multiplier; + if( round ) + f = roundToFp16( f ); + temp[ i ] = f; + } +} + +#include "miscUtils.hlsli" + +// Transform temp[ i ] = exp( temp[ i ] - tempMax ), and return the sum of these values +inline float applySoftMax( uint i, const uint length, const uint thread, const float tempMax ) +{ + // Transform the values, and compute per-thread sum + const uint end = i + length; + float sum = 0; + for( i += thread; i < end; i += 32 ) + { + float f = temp[ i ]; + [branch] + if( f != negativeInfinity ) + { + f -= tempMax; + const uint index = fp16Rounded( f ); + const uint res16 = lookupTable[ index ]; + f = f16tof32( res16 ); +
} + else + f = 0; + + temp[ i ] = f; + sum += f; + } + + // Reduce per-thread sum to the global one, over all threads of the group + horizontalSumBroadcast( thread, sum ); + return sum; +} + +[ numthreads( 32, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + const uint neq0 = q_elements[ 0 ]; + const uint neq1 = q_elements[ 1 ]; + const uint neq2 = q_elements[ 2 ]; + const uint neq3 = q_elements[ 3 ]; + + const uint nek0 = k_elements[ 0 ]; + const uint nek1 = k_elements[ 1 ]; + + const uint nev1 = v_elements[ 1 ]; + + const uint ne0 = res_elements[ 0 ]; + const uint ne1 = res_elements[ 1 ]; + + const uint nbk0 = k_strides[ 0 ]; + const uint nbk1 = k_strides[ 1 ]; + const uint nbk2 = k_strides[ 2 ]; + const uint nbk3 = k_strides[ 3 ]; + + const uint nbq0 = q_strides[ 0 ]; + const uint nbq1 = q_strides[ 1 ]; + const uint nbq2 = q_strides[ 2 ]; + const uint nbq3 = q_strides[ 3 ]; + + const uint nbv0 = v_strides[ 0 ]; + const uint nbv1 = v_strides[ 1 ]; + const uint nbv2 = v_strides[ 2 ]; + const uint nbv3 = v_strides[ 3 ]; + + const uint nb0 = res_strides[ 0 ]; + const uint nb1 = res_strides[ 1 ]; + const uint nb2 = res_strides[ 2 ]; + const uint nb3 = res_strides[ 3 ]; + + const uint D = neq0; + const uint N = neq1; + const uint P = nek1 - N; + const uint M = nek1; + + const uint ir = group.x; + const uint iq3 = ir / ( neq2 * neq1 ); + const uint iq2 = ( ir - iq3 * neq2 * neq1 ) / neq1; + const uint iq1 = ( ir - iq3 * neq2 * neq1 - iq2 * neq1 ); + + const uint tempIndex = ir * tempBufferStride; + + uint ic; + float tvm = negativeInfinity; + const uint s1 = iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3; + uint s0 = iq2 * nbk2 + iq3 * nbk3; + for( ic = 0; ic < nek1; ic++, s0 += nbk1 ) + { + if( masked ) + { + if( ic > P + iq1 ) + { + if( 0 == thread ) + temp[ tempIndex + ic ] = negativeInfinity; + continue; + } + } + + float dp; + computeDotProduct( k, q, s0, s1, neq0, thread, dp ); + if( 0 == thread ) + { + dp *= scale; + temp[ tempIndex + ic ]
= dp; + tvm = max( tvm, dp ); + } + } + + if( 0 == thread ) + sharedAccumulators[ 0 ] = tvm; + // The loop above wrote the temp buffer from thread #0 only, and applySoftMax() below reads these values on all threads of the group. + // The temp buffer is in device memory, that's why a group shared memory barrier is not enough here; this barrier covers both the temp buffer and sharedAccumulators. + AllMemoryBarrierWithGroupSync(); + tvm = sharedAccumulators[ 0 ]; + + // Softmax + { + float sum = applySoftMax( tempIndex, M, thread, tvm ); + scaleTempVector( tempIndex, M, thread, 1.0 / sum, true ); + } + + s0 = iq2 * nbv2 + iq3 * nbv3; + uint rdi = iq1 * nb1 + iq2 * nb2 + iq3 * nb3; + for( ic = 0; ic < nev1; ic++, s0 += nbv1, rdi += nb0 ) + { + float dp; + computeDotProduct( v, temp, s0, tempIndex, nek1, thread, dp ); + if( 0 == thread ) + result[ rdi ] = dp; + } +}
\ No newline at end of file diff --git a/ComputeShaders/flashAttentionCommon.hlsli b/ComputeShaders/flashAttentionCommon.hlsli new file mode 100644 index 0000000..68ed30b --- /dev/null +++ b/ComputeShaders/flashAttentionCommon.hlsli @@ -0,0 +1,67 @@ +// Ported from ggml_compute_forward_flash_attn_f16 +// Dispatch with [ neq1*neq2*neq3, 1, 1 ] thread groups +Buffer<float> q: register( t0 ); +Buffer<float> k: register( t1 ); +Buffer<float> v: register( t2 ); + +RWBuffer<float> result: register( u0 ); +// This temporary buffer should fit tempBufferStride * neq1 * neq2 * neq3 elements, FP32 precision +RWBuffer<float> temp: register( u1 ); + +cbuffer Constants: register( b0 ) +{ + uint4 q_elements: packoffset( c0 ); + uint4 q_strides: packoffset( c1 ); + uint4 k_elements: packoffset( c2 ); + uint4 k_strides: packoffset( c3 ); + uint4 v_elements: packoffset( c4 ); + uint4 v_strides: packoffset( c5 ); + uint4 res_elements: packoffset( c6 ); + uint4 res_strides: packoffset( c7 ); + + bool masked : packoffset( c8.x ); + // 1.0 / sqrt( (double) D ) + float scale : packoffset( c8.y ); + // This number is required to be >= nek1, and ideally rounded up to either 32 (L2 line) or 128 (L1 line) bytes + uint tempBufferStride: packoffset( c8.z ); +} + +// Bit pattern of the FP32 negative infinity +static const float negativeInfinity = asfloat( 0xff800000 ); + +// Convert FP32 number to FP16 using rounding to nearest, then upcast back to FP32 +// NOTE(review): the code picks between the f32tof16() result and that result + 1, i.e. it assumes f32tof16() truncates towards zero - confirm for the target compiler and GPUs +inline float roundToFp16( const float src ) +{ + const uint trunc16 = f32tof16( src ); + const float trunc32 = f16tof32( trunc16 ); + + // Exponent field of the FP16 value, 0x1F means INF or NAN + const uint truncExp = ( trunc16 >> 10 ) & 0x1F; + if( truncExp != 0x1F ) + { + const uint next16 = trunc16 + 1; + const float next32 = f16tof32( next16 ); + + const float errTrunc = abs( src - trunc32 ); + const float errNext = abs( src - next32 ); + + if( errTrunc < errNext ) + { + // Truncated was closer to the source + return trunc32; + } + else if( errTrunc > errNext ) + { + // Truncated + 1 was closer to the source + return next32; + } + else + { + //
Exactly half, doing banker's rounding to nearest even + return ( 0 == ( trunc16 & 1 ) ) ? trunc32 : next32; + } + } + else + { + // INF or NAN + return trunc32; + } +}
\ No newline at end of file diff --git a/ComputeShaders/flashAttentionCompat1.hlsl b/ComputeShaders/flashAttentionCompat1.hlsl new file mode 100644 index 0000000..f1f2ddb --- /dev/null +++ b/ComputeShaders/flashAttentionCompat1.hlsl @@ -0,0 +1,125 @@ +// Dispatch with [ neq1*neq2*neq3, 1, 1 ] thread groups +// This shader fills a row of the temp buffer with scaled dot products of K and Q rows, masked elements are set to negative infinity. It does not write to the result buffer. +#include "flashAttentionCommon.hlsli" +#include "groupReduce.hlsli" + +// Dot product of `len` elements of the two buffers; the result is only correct on the thread #0 of the group. +// Complete batches of 32 elements are accumulated on all threads and reduced with horizontalSumCompatNew(), the remaining elements are added sequentially on the thread #0. +inline void computeDotProduct( Buffer<float> buff0, Buffer<float> buff1, uint s0, uint s1, const uint len, const uint thread, inout float acc ) +{ + acc = 0; + /* + const uint s0End = s0 + len; + s0 += thread; + s1 += thread; + for( ; s0 < s0End; s0 += 32, s1 += 32 ) + acc = mad( buff0[ s0 ], buff1[ s1 ], acc ); + horizontalSumCompatNew( thread, acc ); + */ + + const uint completeVectors = len / 32; + uint i; + for( i = 0; i < completeVectors; i++, s0 += 32, s1 += 32 ) + acc = mad( buff0[ s0 + thread ], buff1[ s1 + thread ], acc ); + + horizontalSumCompatNew( thread, acc ); + + if( 0 == thread ) + { + const uint rem = len % 32; + for( i = 0; i < rem; i++ ) + { + precise float a = buff0[ s0 + i ]; + precise float b = buff1[ s1 + i ]; + precise float prod = a * b; + acc += prod; + } + } +} + +// Multiply `length` elements of the temp buffer starting at index i by the multiplier +// NOTE(review): main() in this file does not call this function +void scaleTempVector( uint i, const uint length, const uint thread, const float multiplier ) +{ + const uint end = i + length; + for( i += thread; i < end; i += 32 ) + { + float f = temp[ i ]; + f *= multiplier; + temp[ i ] = f; + } +} + +[ numthreads( 32, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + const uint neq0 = q_elements[ 0 ]; + const uint neq1 = q_elements[ 1 ]; + const uint neq2 = q_elements[ 2 ]; + const uint neq3 = q_elements[ 3 ]; + + const uint nek0 = k_elements[ 0 ]; + const uint nek1 = k_elements[ 1 ]; + + const uint nev1 = v_elements[ 1 ]; + + const uint ne0 = res_elements[ 0 ]; + const uint ne1 = res_elements[ 1 ]; + + const uint nbk0 = k_strides[ 0 ]; + const uint nbk1 = k_strides[ 1 ]; + const uint nbk2 = k_strides[ 2 ]; + const uint nbk3 = k_strides[
3 ]; + + const uint nbq0 = q_strides[ 0 ]; + const uint nbq1 = q_strides[ 1 ]; + const uint nbq2 = q_strides[ 2 ]; + const uint nbq3 = q_strides[ 3 ]; + + const uint nbv0 = v_strides[ 0 ]; + const uint nbv1 = v_strides[ 1 ]; + const uint nbv2 = v_strides[ 2 ]; + const uint nbv3 = v_strides[ 3 ]; + + const uint nb0 = res_strides[ 0 ]; + const uint nb1 = res_strides[ 1 ]; + const uint nb2 = res_strides[ 2 ]; + const uint nb3 = res_strides[ 3 ]; + + const uint D = neq0; + const uint N = neq1; + const uint P = nek1 - N; + // const uint M = P + N; + const uint M = nek1; + + // Unpack the thread group index into 3 indices of the Q row + const uint ir = group.x; + const uint iq3 = ir / ( neq2 * neq1 ); + const uint iq2 = ( ir - iq3 * neq2 * neq1 ) / neq1; + const uint iq1 = ( ir - iq3 * neq2 * neq1 - iq2 * neq1 ); + + const uint tempIndex = ir * tempBufferStride; + + uint ic; + for( ic = 0; ic < nek1; ic++ ) + { + // k indices + const uint ik3 = iq3; + const uint ik2 = iq2; + const uint ik1 = ic; + + // S indices + const uint i1 = ik1; + + if( masked ) + { + if( ic > P + iq1 ) + { + if( 0 == thread ) + temp[ tempIndex + ic ] = negativeInfinity; + continue; + } + } + + const uint s0 = ik1 * nbk1 + ik2 * nbk2 + ik3 * nbk3; + const uint s1 = iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3; + float dp; + computeDotProduct( k, q, s0, s1, neq0, thread, dp ); + if( 0 == thread ) + temp[ tempIndex + ic ] = dp * scale; + } +}
\ No newline at end of file diff --git a/ComputeShaders/flashAttentionCompat2.hlsl b/ComputeShaders/flashAttentionCompat2.hlsl new file mode 100644 index 0000000..73f5fce --- /dev/null +++ b/ComputeShaders/flashAttentionCompat2.hlsl @@ -0,0 +1,114 @@ +// Dispatch with [ ( neq1*neq2*neq3 + 31 ) / 32, 1, 1 ] thread groups +// This shader applies softmax in place to the rows of the temp buffer. Unlike the other flash attention shaders, every thread processes a complete row sequentially. +#include "flashAttentionCommon.hlsli" +Buffer<uint> lookupTable: register( t3 ); + +// Multiply `length` elements of the temp buffer starting at index i by the multiplier +void scaleTempVector( uint i, const uint length, const float multiplier ) +{ + const uint end = i + length; + for( ; i < end; i++ ) + { + float f = temp[ i ]; + f *= multiplier; + // Rounding in this shader causes numerical errors on my GeForce 1080 Ti GPU, driver 527.56 + // f = roundToFp16( f ); + temp[ i ] = f; + } +} + +// Compute maximum of `length` elements of the temp buffer starting at index i +inline float computeTempVectorMax( uint i, const uint length ) +{ + // Compute per-thread maximum + const uint end = i + length; + float ax = negativeInfinity; + for( ; i < end; i++ ) + ax = max( ax, temp[ i ] ); + return ax; +} + +#include "miscUtils.hlsli" +#include "fp64Utils.hlsli" + +// Transform temp[ i ] = exp( temp[ i ] - tempMax ), and return the sum of these values +inline double applySoftMax( uint i, const uint length, const float tempMax ) +{ + // Transform the values, and compute per-thread sum + const uint end = i + length; + double sum = 0; + for( ; i < end; i++ ) + { + float f = temp[ i ]; + [branch] + if( f != negativeInfinity ) + { + f -= tempMax; + const uint index = fp16Rounded( f ); + const uint res16 = lookupTable[ index ]; + f = f16tof32( res16 ); + sum += f; + } + else + f = 0; + + temp[ i ] = f; + } + return sum; +} + +[ numthreads( 32, 1, 1 ) ] +void main( uint3 dtid: SV_DispatchThreadID ) +{ + const uint neq0 = q_elements[ 0 ]; + const uint neq1 = q_elements[ 1 ]; + const uint neq2 = q_elements[ 2 ]; + const uint neq3 = q_elements[ 3 ]; + + const uint nek0 = k_elements[ 0 ]; + const uint nek1 = k_elements[ 1 ]; + + const uint nev1 = v_elements[ 1 ]; + + const uint ne0 = res_elements[ 0 ]; + const uint ne1 =
res_elements[ 1 ]; + + const uint nbk0 = k_strides[ 0 ]; + const uint nbk1 = k_strides[ 1 ]; + const uint nbk2 = k_strides[ 2 ]; + const uint nbk3 = k_strides[ 3 ]; + + const uint nbq0 = q_strides[ 0 ]; + const uint nbq1 = q_strides[ 1 ]; + const uint nbq2 = q_strides[ 2 ]; + const uint nbq3 = q_strides[ 3 ]; + + const uint nbv0 = v_strides[ 0 ]; + const uint nbv1 = v_strides[ 1 ]; + const uint nbv2 = v_strides[ 2 ]; + const uint nbv3 = v_strides[ 3 ]; + + const uint nb0 = res_strides[ 0 ]; + const uint nb1 = res_strides[ 1 ]; + const uint nb2 = res_strides[ 2 ]; + const uint nb3 = res_strides[ 3 ]; + + const uint D = neq0; + const uint N = neq1; + const uint P = nek1 - N; + // const uint M = P + N; + const uint M = nek1; + + // One thread per row; the dispatch is rounded up to 32 threads, the extra threads of the last group have nothing to do + const uint ir = dtid.x; + if( ir >= neq1 * neq2 * neq3 ) + return; + + const uint iq3 = ir / ( neq2 * neq1 ); + const uint iq2 = ( ir - iq3 * neq2 * neq1 ) / neq1; + const uint iq1 = ( ir - iq3 * neq2 * neq1 - iq2 * neq1 ); + + const uint tempIndex = ir * tempBufferStride; + + // Softmax + float tvm = computeTempVectorMax( tempIndex, M ); + double sum = applySoftMax( tempIndex, M, tvm ); + + // NOTE(review): `sum` is FP64, so this is an FP64 division, while fp64Utils.hlsli is included and its div64() is not used anywhere in this file - confirm the division compiles for GPUs without ExtendedDoublesShaderInstructions + scaleTempVector( tempIndex, M, (float)( 1.0 / sum ) ); +}
\ No newline at end of file diff --git a/ComputeShaders/flashAttentionCompat3.hlsl b/ComputeShaders/flashAttentionCompat3.hlsl new file mode 100644 index 0000000..e0a4061 --- /dev/null +++ b/ComputeShaders/flashAttentionCompat3.hlsl @@ -0,0 +1,118 @@ +// Dispatch with [ neq1*neq2*neq3, 1, 1 ] thread groups +// This shader rounds a row of the temp buffer to FP16 precision, then writes dot products of V rows with that temp row to the result buffer +#include "flashAttentionCommon.hlsli" +#include "groupReduce.hlsli" +#include "miscUtils.hlsli" + +// Round `len` elements of the temp buffer starting at index i to FP16 precision, in place +inline void roundTempVector( uint i, const uint len, const uint thread ) +{ + const uint iEnd = i + len; + for( i += thread; i < iEnd; i += 32 ) + { + float f = temp[ i ]; + f = roundToFp16( f ); + temp[ i ] = f; + } +} + +// Dot product of `len` elements of the two buffers; the result is only correct on the thread #0 of the group. +// Complete batches of 32 elements are accumulated on all threads and reduced with horizontalSumCompatNew(), the remaining elements are accumulated in FP64 on the thread #0. +inline void computeDotProduct( Buffer<float> buff0, RWBuffer<float> buff1, uint s0, uint s1, const uint len, const uint thread, inout float acc ) +{ + acc = 0; +/* const uint s0End = s0 + len; + s0 += thread; + s1 += thread; + for( ; s0 < s0End; s0 += 32, s1 += 32 ) + acc = mad( buff0[ s0 ], buff1[ s1 ], acc ); + + horizontalSumCompatNew( thread, acc ); */ + const uint completeVectors = len / 32; + uint i; + for( i = 0; i < completeVectors; i++, s0 += 32, s1 += 32 ) + acc = mad( buff0[ s0 + thread ], buff1[ s1 + thread ], acc ); + + horizontalSumCompatNew( thread, acc ); + + if( 0 == thread ) + { + const uint rem = len % 32; + if( 0 != rem ) + { + double f64 = acc; + for( i = 0; i < rem; i++ ) + { + precise float a = buff0[ s0 + i ]; + precise float b = buff1[ s1 + i ]; + precise float prod = a * b; + f64 += prod; + } + acc = (float)f64; + } + } +} + +[ numthreads( 32, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + const uint neq0 = q_elements[ 0 ]; + const uint neq1 = q_elements[ 1 ]; + const uint neq2 = q_elements[ 2 ]; + const uint neq3 = q_elements[ 3 ]; + + const uint nek0 = k_elements[ 0 ]; + const uint nek1 = k_elements[ 1 ]; + + const uint nev1 = v_elements[ 1 ]; + + const uint ne0 = res_elements[ 0 ]; + const uint ne1 = res_elements[ 1 ]; + + const uint nbk0 = k_strides[ 0 ]; + const uint nbk1 =
k_strides[ 1 ]; + const uint nbk2 = k_strides[ 2 ]; + const uint nbk3 = k_strides[ 3 ]; + + const uint nbq0 = q_strides[ 0 ]; + const uint nbq1 = q_strides[ 1 ]; + const uint nbq2 = q_strides[ 2 ]; + const uint nbq3 = q_strides[ 3 ]; + + const uint nbv0 = v_strides[ 0 ]; + const uint nbv1 = v_strides[ 1 ]; + const uint nbv2 = v_strides[ 2 ]; + const uint nbv3 = v_strides[ 3 ]; + + const uint nb0 = res_strides[ 0 ]; + const uint nb1 = res_strides[ 1 ]; + const uint nb2 = res_strides[ 2 ]; + const uint nb3 = res_strides[ 3 ]; + + const uint D = neq0; + const uint N = neq1; + const uint P = nek1 - N; + // const uint M = P + N; + const uint M = nek1; + + const uint ir = group.x; + const uint iq3 = ir / ( neq2 * neq1 ); + const uint iq2 = ( ir - iq3 * neq2 * neq1 ) / neq1; + const uint iq1 = ( ir - iq3 * neq2 * neq1 - iq2 * neq1 ); + + const uint tempIndex = ir * tempBufferStride; + + roundTempVector( tempIndex, nek1, thread ); + // The remainder loop of computeDotProduct() reads, on the thread #0, temp elements which were written by other threads of the group + AllMemoryBarrierWithGroupSync(); + + uint rdi = iq1 * nb1 + iq2 * nb2 + iq3 * nb3; + for( uint ic = 0; ic < nev1; ic++, rdi += nb0 ) + { + // dst indices + const uint i1 = iq1; + const uint i2 = iq2; + const uint i3 = iq3; + + const uint s0 = ic * nbv1 + i2 * nbv2 + i3 * nbv3; + float dp; + computeDotProduct( v, temp, s0, tempIndex, nek1, thread, dp ); + if( 0 == thread ) + result[ rdi ] = dp; + } +}
\ No newline at end of file diff --git a/ComputeShaders/fmaRepeat1.hlsl b/ComputeShaders/fmaRepeat1.hlsl new file mode 100644 index 0000000..3db3827 --- /dev/null +++ b/ComputeShaders/fmaRepeat1.hlsl @@ -0,0 +1,77 @@ +// Implementation of fmaRepeat() when both source arguments have the same size and strides +// Updates the tensor in place: tensor = tensor * patternMul + patternAdd, where the pattern indices wrap around the pattern size +// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor +RWBuffer<float> tensor: register( u0 ); +Buffer<float> patternMul: register( t0 ); +Buffer<float> patternAdd: register( t1 ); + +cbuffer Constants: register( b0 ) +{ + uint4 tensorSize: packoffset( c0 ); + uint4 tensorStrides: packoffset( c1 ); + uint4 patternSize: packoffset( c2 ); + uint4 patternStrides: packoffset( c3 ); +} + +#ifndef THREADS +#define THREADS 512 +#endif + +#include "repeatUtils.hlsli" + +// Update a single element of the tensor: tensor[ idx ] = tensor[ idx ] * mul + add, as two separate `precise` operations +inline void computeSimple( uint idx, float mul, float add ) +{ + precise float f = tensor[ idx ]; + f *= mul; + f += add; + tensor[ idx ] = f; +} + +[ numthreads( THREADS, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides ); + uint rsi = rowOffset( group % patternSize.yzw, patternStrides ); + + if( patternSize[ 0 ] == 1 ) + { + // The pattern only has 1 column - broadcasting over the row + const float pMul = patternMul[ rsi ]; + const float pAdd = patternAdd[ rsi ]; + ROW_LOOP( it ) + computeSimple( it.x, pMul, pAdd ); + } + else if( patternSize[ 0 ] <= THREADS ) + { + // pattern size doesn't exceed thread group size: load pattern value outside of the loop + // Only use the largest multiple of the pattern row length which fits in the group; the remaining threads exit + const uint threadsPerGroup = THREADS - ( THREADS % patternSize[ 0 ] ); + if( thread >= threadsPerGroup ) + return; + + rsi += ( thread % patternSize[ 0 ] ) * patternStrides[ 0 ]; + const float pMul = patternMul[ rsi ]; + const float pAdd = patternAdd[ rsi ]; + ROW_LOOP_EX( it, threadsPerGroup, tensorStrides ) + computeSimple( it.x, pMul, pAdd ); + } + else + { + // Pattern rows are larger
than the thread group, need to stream from both buffers + const uint rsiInc = THREADS * patternStrides[ 0 ]; + const uint rsiDec = patternSize[ 0 ] * patternStrides[ 0 ]; + const uint rsiEnd = rsi + rsiDec; + rsi += thread * patternStrides[ 0 ]; + + ROW_LOOP( it ) + { + precise float f = tensor[ it.x ]; + float mul = patternMul[ rsi ]; + float add = patternAdd[ rsi ]; + // Advance the pattern index, wrapping around at the end of the pattern row + rsi += rsiInc; + if( rsi >= rsiEnd ) + rsi -= rsiDec; + f *= mul; + f += add; + tensor[ it.x ] = f; + } + } +}
\ No newline at end of file diff --git a/ComputeShaders/fmaRepeat164.hlsl b/ComputeShaders/fmaRepeat164.hlsl new file mode 100644 index 0000000..99813f6 --- /dev/null +++ b/ComputeShaders/fmaRepeat164.hlsl @@ -0,0 +1,2 @@ +// Same shader as fmaRepeat1.hlsl, compiled with 64 threads per group instead of the default 512 +#define THREADS 64 +#include "fmaRepeat1.hlsl"
\ No newline at end of file diff --git a/ComputeShaders/fmaRepeat2.hlsl b/ComputeShaders/fmaRepeat2.hlsl new file mode 100644 index 0000000..edadf0a --- /dev/null +++ b/ComputeShaders/fmaRepeat2.hlsl @@ -0,0 +1,45 @@ +// Implementation of fmaRepeat() when source arguments have different shape or VRAM layout +// Updates the tensor in place: tensor = tensor * patternMul + patternAdd, each pattern has its own size and strides +// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor +RWBuffer<float> tensor: register( u0 ); +Buffer<float> patternMul: register( t0 ); +Buffer<float> patternAdd: register( t1 ); + +cbuffer Constants: register( b0 ) +{ + uint4 tensorSize: packoffset( c0 ); + uint4 tensorStrides: packoffset( c1 ); + uint4 patternSizeMul: packoffset( c2 ); + uint4 patternStridesMul: packoffset( c3 ); + uint4 patternSizeAdd: packoffset( c4 ); + uint4 patternStridesAdd: packoffset( c5 ); +} + +#ifndef THREADS +#define THREADS 32 +#endif + +#include "repeatUtils.hlsli" + +// Load the pattern element at column ( i modulo the pattern row length ) from the row which starts at rowStart +inline float loadPattern( Buffer<float> buffer, uint rowStart, uint i, uint4 size, uint4 stride ) +{ + i %= size.x; + return buffer[ i * stride.x + rowStart ]; +} + +[ numthreads( THREADS, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides ); + const uint rsiMul = rowOffset( group % patternSizeMul.yzw, patternStridesMul ); + const uint rsiAdd = rowOffset( group % patternSizeAdd.yzw, patternStridesAdd ); + + // NOTE(review): the pattern column index `i` advances by 1 per iteration while it.x advances by it.y; tensorIteratorState() is defined in repeatUtils.hlsli, confirm these two agree + for( uint i = thread; it.x < it.z; it.x += it.y, i++ ) + { + precise float f = tensor[ it.x ]; + float mul = loadPattern( patternMul, rsiMul, i, patternSizeMul, patternStridesMul ); + float add = loadPattern( patternAdd, rsiAdd, i, patternSizeAdd, patternStridesAdd ); + f *= mul; + f += add; + tensor[ it.x ] = f; + } +}
\ No newline at end of file diff --git a/ComputeShaders/fp64Utils.hlsli b/ComputeShaders/fp64Utils.hlsli new file mode 100644 index 0000000..9782718 --- /dev/null +++ b/ComputeShaders/fp64Utils.hlsli @@ -0,0 +1,28 @@ +// TODO: compile another version of these shaders, and use it on GPUs with the ExtendedDoublesShaderInstructions flag, it will become slightly faster +// https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ns-d3d11-d3d11_feature_data_d3d11_options +#ifndef ExtendedDoublesShaderInstructions +#define ExtendedDoublesShaderInstructions 0 +#endif + +// Compute num/den in FP64 precision +inline double div64( double num, double den ) +{ +#if ExtendedDoublesShaderInstructions + return num / den; +#else + // https://en.wikipedia.org/wiki/Division_algorithm#Newton%E2%80%93Raphson_division + // Start with the FP32 estimate of 1/den, then refine it with two Newton-Raphson iterations which only use FP64 multiplications and additions + double x = 1.0f / (float)den; + x += x * ( 1.0 - den * x ); + x += x * ( 1.0 - den * x ); + return num * x; +#endif +} + +// Compute sqrt(x) in FP64 precision +inline double sqrt64( double x ) +{ + // Start with the FP32 estimate of the root, then refine it with two iterations of root = ( root + x / root ) / 2 + double root = sqrt( (float)x ); + root = 0.5 * ( root + div64( x, root ) ); + root = 0.5 * ( root + div64( x, root ) ); + return root; +}
// ComputeShaders/groupReduce.hlsli
// Reductions across a thread group. The indexing below requires `thread` to be in [ 0 .. 31 ].
// Every function contains group barriers, so all threads of the group must call it.
// NOTE(review): none of these functions ends or begins with a barrier around the first store;
// confirm that callers do not run two reductions back to back while slower threads still read the buffer.
groupshared float sharedAccumulators[ 32 ];

// Compute horizontal sum of the numbers. The result is only correct on the thread #0 of the group.
void horizontalSum( const uint thread, inout float sum )
{
	sharedAccumulators[ thread ] = sum;

	// Tree reduction, 32 -> 16 -> 8 -> 4 -> 2 partial sums
	uint half = 16;
	while( half > 1 )
	{
		GroupMemoryBarrierWithGroupSync();
		if( thread < half )
		{
			sum += sharedAccumulators[ thread + half ];
			sharedAccumulators[ thread ] = sum;
		}
		half /= 2;
	}

	// The last step only needs the first thread, and doesn't need the store
	GroupMemoryBarrierWithGroupSync();
	if( thread == 0 )
		sum += sharedAccumulators[ 1 ];
}

// Compute horizontal sum of the numbers, and broadcast to all threads of the group.
void horizontalSumBroadcast( const uint thread, inout float sum )
{
	horizontalSum( thread, sum );
	// Thread #0 publishes the total, everyone reads it back after the barrier
	if( thread == 0 )
		sharedAccumulators[ 0 ] = sum;
	GroupMemoryBarrierWithGroupSync();
	sum = sharedAccumulators[ 0 ];
}

// Compute horizontal sum of the numbers, in the order equal to the CPU-running dot product implementation.
// The CPU intrinsics quoted in the comments are the steps being mirrored.
// The result is only correct on the thread #0 of the group.
void horizontalSumCompat( const uint thread, inout float sum )
{
	sharedAccumulators[ thread ] = sum;
	GroupMemoryBarrierWithGroupSync();

	// Threads [ 0 .. 7 ] and [ 16 .. 23 ]
	//   sum01 = _mm256_add_ps( sum0, sum1 );
	//   sum23 = _mm256_add_ps( sum2, sum3 );
	if( ( thread & 8 ) == 0 )
	{
		sum += sharedAccumulators[ thread + 8 ];
		sharedAccumulators[ thread ] = sum;
	}
	GroupMemoryBarrierWithGroupSync();

	// Threads [ 0 .. 7 ]
	//   sum0123 = _mm256_add_ps( sum01, sum23 );
	if( thread < 8 )
	{
		sum += sharedAccumulators[ thread + 16 ];
		sharedAccumulators[ thread ] = sum;
	}
	GroupMemoryBarrierWithGroupSync();

	// Threads [ 0 .. 3 ]
	//   const __m128 r4 = _mm_add_ps( _mm256_castps256_ps128( sum0123 ), _mm256_extractf128_ps( sum0123, 1 ) );
	if( thread < 4 )
	{
		sum += sharedAccumulators[ thread + 4 ];
		sharedAccumulators[ thread ] = sum;
	}
	GroupMemoryBarrierWithGroupSync();

	// Threads [ 0, 1 ]
	//   const __m128 r2 = _mm_add_ps( r4, _mm_movehl_ps( r4, r4 ) );
	if( thread < 2 )
	{
		sum += sharedAccumulators[ thread + 2 ];
		sharedAccumulators[ thread ] = sum;
	}
	GroupMemoryBarrierWithGroupSync();

	// Thread #0
	//   const __m128 r1 = _mm_add_ss( r2, _mm_movehdup_ps( r2 ) );
	if( thread == 0 )
		sum += sharedAccumulators[ 1 ];
}

// Compute horizontal sum of the numbers, in yet another summation order, the one of GGML_F32x8_REDUCE
// recently implemented in the upstream. The final addition only happens on the thread #0 of the group.
void horizontalSumCompatNew( const uint thread, inout float sum )
{
	sharedAccumulators[ thread ] = sum;
	GroupMemoryBarrierWithGroupSync();

	// Threads [ 0 .. 7 ] and [ 16 .. 23 ] add the lane which differs in bit 3
	if( ( thread & 8 ) == 0 )
	{
		sum += sharedAccumulators[ thread | 8 ];
		sharedAccumulators[ thread ] = sum;
	}
	GroupMemoryBarrierWithGroupSync();

	// Threads [ 0 .. 7 ] add the lane which differs in bit 4
	if( thread < 8 )
	{
		sum += sharedAccumulators[ thread | 0x10 ];
		sharedAccumulators[ thread ] = sum;
	}
	GroupMemoryBarrierWithGroupSync();

	// Threads [ 0 .. 3 ] add the lane which differs in bit 2
	if( thread < 4 )
	{
		sum += sharedAccumulators[ thread | 4 ];
		sharedAccumulators[ thread ] = sum;
	}
	GroupMemoryBarrierWithGroupSync();

	// Threads [ 0, 2 ] add their odd neighbour
	if( thread == 0 || thread == 2 )
	{
		sum += sharedAccumulators[ thread | 1 ];
		sharedAccumulators[ thread ] = sum;
	}
	GroupMemoryBarrierWithGroupSync();

	// Thread #0 merges the two remaining partial sums
	if( thread == 0 )
		sum += sharedAccumulators[ 2 ];
}

// Compute horizontal maximum of the numbers, and broadcast to all threads of the group.
void horizontalMaxBroadcast( const uint thread, inout float ax )
{
	sharedAccumulators[ thread ] = ax;

	// Unlike the sums, the tree runs all the way down so the maximum ends up in slot #0
	uint half = 16;
	while( half > 0 )
	{
		GroupMemoryBarrierWithGroupSync();
		if( thread < half )
		{
			ax = max( ax, sharedAccumulators[ thread + half ] );
			sharedAccumulators[ thread ] = ax;
		}
		half /= 2;
	}

	GroupMemoryBarrierWithGroupSync();
	ax = sharedAccumulators[ 0 ];
}
// ComputeShaders/groupReduce64.hlsli
// Reductions across a thread group. The indexing below requires `thread` to be in [ 0 .. 63 ].
// Every function contains group barriers, so all threads of the group must call it.
groupshared float sharedAccumulators[ 64 ];

// Compute horizontal sum of the numbers. The result is only correct on the thread #0 of the group.
void horizontalSum( const uint thread, inout float sum )
{
	sharedAccumulators[ thread ] = sum;

	// Tree reduction, 64 -> 32 -> 16 -> 8 -> 4 -> 2 partial sums
	uint half = 32;
	while( half > 1 )
	{
		GroupMemoryBarrierWithGroupSync();
		if( thread < half )
		{
			sum += sharedAccumulators[ thread + half ];
			sharedAccumulators[ thread ] = sum;
		}
		half /= 2;
	}

	// The last step only needs the first thread, and doesn't need the store
	GroupMemoryBarrierWithGroupSync();
	if( thread == 0 )
		sum += sharedAccumulators[ 1 ];
}

// Compute horizontal sum of the numbers, and broadcast to all threads of the group.
void horizontalSumBroadcast( const uint thread, inout float sum )
{
	horizontalSum( thread, sum );
	// Thread #0 publishes the total, everyone reads it back after the barrier
	if( thread == 0 )
		sharedAccumulators[ 0 ] = sum;
	GroupMemoryBarrierWithGroupSync();
	sum = sharedAccumulators[ 0 ];
}

// Compute horizontal maximum of the numbers, and broadcast to all threads of the group.
void horizontalMaxBroadcast( const uint thread, inout float ax )
{
	sharedAccumulators[ thread ] = ax;

	// The tree runs all the way down so the maximum ends up in slot #0
	uint half = 32;
	while( half > 0 )
	{
		GroupMemoryBarrierWithGroupSync();
		if( thread < half )
		{
			ax = max( ax, sharedAccumulators[ thread + half ] );
			sharedAccumulators[ thread ] = ax;
		}
		half /= 2;
	}

	GroupMemoryBarrierWithGroupSync();
	ax = sharedAccumulators[ 0 ];
}
// ComputeShaders/matReshapePanels.hlsl
// Reshapes a matrix into the layout expected by mulMatTiledEx.hlsl and mulMatByRowTiledEx.hlsl compute shaders:
// a sequence of horizontal panels, each one column major, TILE_SIZE rows tall and as wide as the source matrix.
// Runs both at runtime and while loading models from disk. So far, it's only used when running on AMD GPUs.
#ifndef TILE_SIZE
static const uint TILE_SIZE = 32;
#endif

// Input tensor
Buffer<float> source: register( t0 );
// Output tensor
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 arg0Size: packoffset( c0 );
	uint4 arg0Strides: packoffset( c1 );
	// Count of elements per panel
	uint panelSize : packoffset( c2.y );
	// Layer strides of the output matrix
	uint2 layerStrides: packoffset( c2.z );
}

// Sum of the two lanes
inline uint hadd( uint2 v2 ) { return v2.x + v2.y; }

// Staging area for transposing one square tile, indexed as [ column ][ row ] of the tile
groupshared float tileBuffer[ TILE_SIZE ][ TILE_SIZE ];

// Each thread group produces one panel: group.x is the panel index, group.yz select the layer
[ numthreads( TILE_SIZE, 1, 1 ) ]
void main( const uint3 group: SV_GroupID, const uint thread : SV_GroupIndex )
{
	// Store index: start of this panel in the output, plus the row handled by this thread
	uint dst = hadd( group.yz * layerStrides ) + group.x * panelSize + thread;

	// Load index: start of the layer in the source
	uint src = hadd( group.yz * arg0Strides.zw );

	// First source row of the panel
	const uint firstRow = group.x * TILE_SIZE;
	// Columns which remain to be copied by the generic loop at the bottom
	uint columnsLeft = arg0Size.x;
	// Rows of the panel which exist in the source. Usually TILE_SIZE;
	// can be less for the last panel of the matrix, the missing rows are filled with zeros.
	const uint rows = min( TILE_SIZE, arg0Size.y - firstRow );

	if( arg0Strides.x == 1 )
	{
		// Row major input: transpose complete tiles through the group shared buffer,
		// that way both the loads and the stores are coalesced.
		src += firstRow * arg0Strides.y;
		const uint completeTiles = columnsLeft / TILE_SIZE;

		if( rows < TILE_SIZE )
		{
			// Last panel of the matrix without enough rows.
			// The loads below never touch these entries, so zero them once up front.
			for( uint pad = rows; pad < TILE_SIZE; pad++ )
				tileBuffer[ thread ][ pad ] = 0.0;
		}

		for( uint tile = 0; tile < completeTiles; tile++ )
		{
			// Load a [ TILE_SIZE ] * [ rows ] block; here the thread index is the column within the tile.
			// Each iteration reads one source row across the group, that's a fully coalesced load.
			// The store is scattered, but the local memory is fast and this works rather well in practice.
			uint srcRow = src + thread;
			for( uint r = 0; r < rows; r++ )
			{
				tileBuffer[ thread ][ r ] = source[ srcRow ];
				srcRow += arg0Strides.y;
			}

			GroupMemoryBarrierWithGroupSync();

			// Write the tile out; here the thread index is the row within the panel.
			// Each iteration stores one complete column of the panel, coalesced again.
			for( uint c = 0; c < TILE_SIZE; c++ )
			{
				result[ dst ] = tileBuffer[ c ][ thread ];
				dst += TILE_SIZE;
			}

			// The buffer is overwritten by the next iteration
			GroupMemoryBarrierWithGroupSync();

			src += TILE_SIZE;
		}

		// Whatever remains is less than one tile wide
		columnsLeft %= TILE_SIZE;
		if( columnsLeft == 0 )
			return;
		// From now on the thread index is the row within the panel
		src += thread * arg0Strides.y;
	}
	else
	{
		// Arbitrary layout: every thread walks its own source row
		src += ( firstRow + thread ) * arg0Strides.y;
	}

	// Generic path, one element at a time
	for( uint col = 0; col < columnsLeft; col++ )
	{
		float f;
		[branch]
		if( thread < rows )
			f = source[ src ];
		else
			f = 0.0;	// Row past the end of the matrix
		src += arg0Strides.x;

		result[ dst ] = f;
		dst += TILE_SIZE;
	}
}
// ComputeShaders/miscUtils.hlsli
// When GPUs are converting FP32 to FP16, they always truncate towards 0, documented there:
// https://learn.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-data-conversion#conververting-from-a-higher-range-representation-to-a-lower-range-representation
// Whisper code uses _mm_cvtps_ph( x, 0 ), the 0 stands for "Round to nearest even": https://www.felixcloutier.com/x86/vcvtps2ph

// Adjust an FP32 value so that truncating it towards 0 gives the FP16 value the CPU would have produced.
// Returns the source unchanged when truncation is already the right answer,
// otherwise the FP32 equivalent of the next FP16 value away from zero.
inline float adjustFp16( const float src )
{
	const uint lower16 = f32tof16( src );

	// All exponent bits set means INF or NAN, nothing to adjust
	if( ( ( lower16 >> 10 ) & 0x1F ) == 0x1F )
		return src;

	// Two candidates: the truncated value, and the next one with larger magnitude
	const float lower32 = f16tof32( lower16 );
	const float upper32 = f16tof32( lower16 + 1 );

	const float errLower = abs( src - lower32 );
	const float errUpper = abs( src - upper32 );

	// Truncated was closer to the source
	if( errLower < errUpper )
		return src;
	// Truncated + 1 was closer to the source
	if( errLower > errUpper )
		return upper32;
	// Exactly half, doing banker's rounding to nearest even
	return ( ( lower16 & 1 ) == 0 ) ? src : upper32;
}

// Convert FP32 number to FP16 bits, using rounding to nearest, ties to even
inline uint fp16Rounded( const float src )
{
	const uint lower16 = f32tof16( src );

	// All exponent bits set means INF or NAN, keep the converted bits
	if( ( ( lower16 >> 10 ) & 0x1F ) == 0x1F )
		return lower16;

	// Two candidates: the value from f32tof16, and the next one with larger magnitude
	const uint upper16 = lower16 + 1;
	const float lower32 = f16tof32( lower16 );
	const float upper32 = f16tof32( upper16 );

	const float errLower = abs( src - lower32 );
	const float errUpper = abs( src - upper32 );

	// The first candidate was closer to the source
	if( errLower < errUpper )
		return lower16;
	// The second candidate was closer to the source
	if( errLower > errUpper )
		return upper16;
	// Exactly half, doing banker's rounding to nearest even
	return ( ( lower16 & 1 ) == 0 ) ? lower16 : upper16;
}

// Round up the number to be a multiple of 32
inline uint roundUp32( uint x )
{
	const uint mask = 31;
	return ( x + mask ) & ~mask;
}
// ComputeShaders/mulMatByRow.hlsl
// Matrix * row product, like [ E0, E1, E2, E3 ] * [ E0, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
// Dispatch [ E1, E2, E3 ] groups of this shader; each group computes one dot product.
Buffer<float> arg0: register( t0 );
Buffer<float> arg1: register( t1 );
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 arg0Size: packoffset( c0 );
	uint4 arg0Strides: packoffset( c1 );
	uint4 arg1Size: packoffset( c2 );
	uint4 arg1Strides: packoffset( c3 );
	uint4 resultSize: packoffset( c4 );
	uint4 resultStrides: packoffset( c5 );
}

#include "groupReduce.hlsli"

// Sum of the three lanes
inline uint hadd( uint3 vec )
{
	return vec.x + vec.y + vec.z;
}
// Sum of the two lanes
inline uint hadd( uint2 vec )
{
	return vec.x + vec.y;
}

[ numthreads( 32, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	// Start of the matrix row, and of the vector in the same layer
	const uint rowBegin = hadd( group * arg0Strides.yzw );
	const uint rowEnd = rowBegin + arg0Size.x * arg0Strides.x;
	const uint vecBegin = hadd( group.yz * arg1Strides.zw );

	// The 32 threads of the group interleave over the elements
	const uint rowStep = 32 * arg0Strides.x;
	const uint vecStep = 32 * arg1Strides.x;
	uint rowIdx = rowBegin + thread * arg0Strides.x;
	uint vecIdx = vecBegin + thread * arg1Strides.x;

	// Partial dot product of this thread
	float dp = 0;
	while( rowIdx < rowEnd )
	{
		dp = mad( arg0[ rowIdx ], arg1[ vecIdx ], dp );
		rowIdx += rowStep;
		vecIdx += vecStep;
	}

	// After the reduction, only the thread #0 has the complete sum
	horizontalSum( thread, dp );
	if( thread == 0 )
	{
		// The index uses group.x as is, i.e. unit stride along the first dimension of the output
		const uint rdi = group.x + hadd( group.yz * resultStrides.zw );
		result[ rdi ] = dp;
	}
}
// ComputeShaders/mulMatByRow64.hlsl
// Matrix * row product, like [ E0, E1, E2, E3 ] * [ E0, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
// Dispatch [ E1, E2, E3 ] groups of this shader; each group computes one dot product.
Buffer<float> arg0: register( t0 );
Buffer<float> arg1: register( t1 );
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 arg0Size: packoffset( c0 );
	uint4 arg0Strides: packoffset( c1 );
	uint4 arg1Size: packoffset( c2 );
	uint4 arg1Strides: packoffset( c3 );
	uint4 resultSize: packoffset( c4 );
	uint4 resultStrides: packoffset( c5 );
}

// Sum of the three lanes
inline uint hadd( uint3 vec )
{
	return vec.x + vec.y + vec.z;
}
// Sum of the two lanes
inline uint hadd( uint2 vec )
{
	return vec.x + vec.y;
}

// No idea why, but that particular configuration appears to be the fastest one on Ryzen 7 5700G iGPU
// Not by much, though: when trying a few numbers I saw 1.30 - 1.42 seconds for this compute shader
static const uint THREADS = 64;
static const uint REDUCTION_BUFFER = 32;
groupshared float sharedAccumulators[ REDUCTION_BUFFER ];

// Compute horizontal sum of the numbers. The result is only correct on the thread #0 of the group.
void horizontalSum( const uint thread, inout float sum )
{
	if( THREADS > REDUCTION_BUFFER )
	{
		// More threads than buffer slots: fold the upper slices of threads onto the first one.
		// NOTE(review): no barrier separates the read at the end of one iteration from the store
		// at the start of the next, this looks safe only while the loop runs once, i.e. THREADS <= 2 * REDUCTION_BUFFER
		for( uint base = REDUCTION_BUFFER; base < THREADS; base += REDUCTION_BUFFER )
		{
			// Threads [ base .. base + REDUCTION_BUFFER ) store into the buffer
			const bool inSlice = ( thread >= base ) && ( thread < base + REDUCTION_BUFFER );
			if( inSlice )
				sharedAccumulators[ thread - base ] = sum;

			GroupMemoryBarrierWithGroupSync();

			// Threads [ 0 .. REDUCTION_BUFFER ) accumulate the value loaded from the buffer
			if( thread < REDUCTION_BUFFER )
				sum += sharedAccumulators[ thread ];
		}
	}

	// From here it's a regular tree reduction over the first REDUCTION_BUFFER threads
	if( thread < REDUCTION_BUFFER )
		sharedAccumulators[ thread ] = sum;

	uint half = REDUCTION_BUFFER / 2;
	while( half > 1 )
	{
		GroupMemoryBarrierWithGroupSync();
		if( thread < half )
		{
			sum += sharedAccumulators[ thread + half ];
			sharedAccumulators[ thread ] = sum;
		}
		half /= 2;
	}

	// The last step only needs the first thread, and doesn't need the store
	GroupMemoryBarrierWithGroupSync();
	if( thread == 0 )
		sum += sharedAccumulators[ 1 ];
}

[ numthreads( THREADS, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	// Start of the matrix row, and of the vector in the same layer
	const uint rowBegin = hadd( group * arg0Strides.yzw );
	const uint rowEnd = rowBegin + arg0Size.x * arg0Strides.x;
	const uint vecBegin = hadd( group.yz * arg1Strides.zw );

	// The threads of the group interleave over the elements
	const uint rowStep = THREADS * arg0Strides.x;
	const uint vecStep = THREADS * arg1Strides.x;
	uint rowIdx = rowBegin + thread * arg0Strides.x;
	uint vecIdx = vecBegin + thread * arg1Strides.x;

	// Partial dot product of this thread
	float dp = 0;
	while( rowIdx < rowEnd )
	{
		dp = mad( arg0[ rowIdx ], arg1[ vecIdx ], dp );
		rowIdx += rowStep;
		vecIdx += vecStep;
	}

	// After the reduction, only the thread #0 has the complete sum
	horizontalSum( thread, dp );
	if( thread == 0 )
	{
		// The index uses group.x as is, i.e. unit stride along the first dimension of the output
		const uint rdi = group.x + hadd( group.yz * resultStrides.zw );
		result[ rdi ] = dp;
	}
}
// ComputeShaders/mulMatByRowTiled.hlsl
// Matrix * row product, like [ E0, E1, E2, E3 ] * [ E0, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
// Dispatch [ ( E1 + TILE_Y - 1 ) / TILE_Y, E2, E3 ] thread groups of this shader.
// Each thread group computes up to TILE_Y elements of the output.

// Count of matrix rows per thread group
#ifndef TILE_Y
static const uint TILE_Y = 64;
#endif
// Thread group size; X runs along the rows of the matrix, Y across the rows
#ifndef THREADS_X
static const uint THREADS_X = 32;
#endif
#ifndef THREADS_Y
static const uint THREADS_Y = 16;
#endif

Buffer<float> arg0: register( t0 );
Buffer<float> arg1: register( t1 );
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 arg0Size: packoffset( c0 );
	uint4 arg0Strides: packoffset( c1 );
	uint4 arg1Size: packoffset( c2 );
	uint4 arg1Strides: packoffset( c3 );
	uint4 resultSize: packoffset( c4 );
	uint4 resultStrides: packoffset( c5 );
}

// Partial dot products: one row per matrix row of the tile, one column per X thread
groupshared float resTemp[ TILE_Y ][ THREADS_X ];

// Sum of the two lanes
inline uint hadd( uint2 vec )
{
	return vec.x + vec.y;
}

[ numthreads( THREADS_X, THREADS_Y, 1 ) ]
void main( uint3 group: SV_GroupID, uint3 thread : SV_GroupThreadID, uint threadFlattenned : SV_GroupIndex )
{
	// Zero out the shared buffer
	for( uint z = thread.y; z < TILE_Y; z += THREADS_Y )
		resTemp[ z ][ thread.x ] = 0.0;
	GroupMemoryBarrierWithGroupSync();

	// Count of rows to compute in this thread group, less than TILE_Y for the last tile of the matrix
	const uint rows = min( TILE_Y, arg0Size.y - group.x * TILE_Y );
	// Distance between two rows handled by the same thread
	const uint rowStep = arg0Strides.y * THREADS_Y;

	// arg0 load index: layer of the group, first row of the group, element of the thread
	uint matIdx = hadd( group.yz * arg0Strides.zw );
	matIdx += group.x * TILE_Y * arg0Strides.y;
	matIdx += hadd( arg0Strides.xy * thread.xy );

	// arg1 load index: layer of the group, element of the thread
	uint vecIdx = hadd( group.yz * arg1Strides.zw );
	vecIdx += thread.x * arg1Strides.x;

	// Each iteration of that loop loads THREADS_X elements from arg1,
	// a block of [ THREADS_X, rows ] elements from arg0,
	// and accumulates the products in the shared buffer
	const uint completeTiles = arg0Size.x / THREADS_X;
	for( uint tile = 0; tile < completeTiles; tile++ )
	{
		const float v1 = arg1[ vecIdx ];

		uint rsi = matIdx;
		for( uint r = thread.y; r < rows; r += THREADS_Y )
		{
			const float v0 = arg0[ rsi ];
			// Multiply and accumulate in the shared buffer
			resTemp[ r ][ thread.x ] = mad( v0, v1, resTemp[ r ][ thread.x ] );
			rsi += rowStep;
		}
		GroupMemoryBarrierWithGroupSync();

		matIdx += THREADS_X * arg0Strides.x;
		vecIdx += THREADS_X * arg1Strides.x;
	}

	const uint rem = arg0Size.x % THREADS_X;
	if( rem != 0 )
	{
		// E0 ain't a multiple of THREADS_X, we have a remainder.
		// Only the first `rem` threads along X have elements to process.
		if( thread.x < rem )
		{
			const float v1 = arg1[ vecIdx ];
			for( uint r = thread.y; r < rows; r += THREADS_Y )
			{
				const float v0 = arg0[ matIdx ];
				resTemp[ r ][ thread.x ] = mad( v0, v1, resTemp[ r ][ thread.x ] );
				matIdx += rowStep;
			}
		}
		GroupMemoryBarrierWithGroupSync();
	}

	// Now we need horizontal sums of these shared accumulators,
	// i.e. reduce the [ rows ][ THREADS_X ] shared array into a [ rows ][ 1 ] column
	for( uint half = THREADS_X / 2; half > 0; half /= 2 )
	{
		if( thread.x < half )
		{
			for( uint r = thread.y; r < rows; r += THREADS_Y )
			{
				float sum = resTemp[ r ][ thread.x ];
				sum += resTemp[ r ][ thread.x + half ];
				resTemp[ r ][ thread.x ] = sum;
			}
		}
		GroupMemoryBarrierWithGroupSync();
	}

	// And finally, store that column to global memory; the first `rows` threads of the group write one element each
	if( threadFlattenned < rows )
	{
		uint rdi = hadd( group.yz * resultStrides.zw ) + group.x * TILE_Y * resultStrides.x;
		rdi += threadFlattenned * resultStrides.x;
		result[ rdi ] = resTemp[ threadFlattenned ][ 0 ];
	}
}
// ComputeShaders/mulMatByRowTiled64.hlsl
// Another build of mulMatByRowTiled.hlsl: a 32 x 32 thread group computing 128 rows of the matrix.
// The macros must be defined before the include, the included file only supplies defaults for the missing ones.
#define TILE_Y 128
#define THREADS_X 32
#define THREADS_Y 32
#include "mulMatByRowTiled.hlsl"
// ComputeShaders/mulMatByRowTiledEx.hlsl
// matrix*row vector product, needs first argument reshaped into a sequence of horizontal column major panels
// Each thread group computes up to TILE_SIZE elements of the output; group.x is the panel index, group.yz select the layer.

// Height of the panels of the first argument = count of outputs per thread group = thread group size along X
#ifndef TILE_SIZE
static const uint TILE_SIZE = 32;
#endif
// Count of elements along the reduction dimension consumed by one iteration of the main loop
#ifndef TILE_HEIGHT
static const uint TILE_HEIGHT = 32;
#endif
#ifndef THREADS_Y
static const uint THREADS_Y = 16;
#endif

// First tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
Buffer<float> arg0: register( t0 );
// Second tensor, the row vectors. Not reshaped: addressed with arg1Strides, x for elements and zw for layers
Buffer<float> arg1: register( t1 );
// FP32 output tensor, row major and continuous
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 arg0Size: packoffset( c0 );
	// Distance between consecutive panels of arg0, in elements
	uint arg0panel: packoffset( c1.y );
	uint2 arg0LayerStrides: packoffset( c1.z );
	// uint4 arg1Size: packoffset( c2 );
	uint4 arg1Strides: packoffset( c3 );
	uint4 resultSize: packoffset( c4 );
	uint4 resultStrides: packoffset( c5 );
}

// Accumulators, one per thread of the group
groupshared float tileOutput[ THREADS_Y ][ TILE_SIZE ];
// Current block of the matrix, [ reduction index ][ output index ]
groupshared float tile0[ TILE_HEIGHT ][ TILE_SIZE ];
// Current block of the vector
groupshared float tile1[ TILE_HEIGHT ];

// Multiply the two tiles in the group shared memory, add the products to the accumulator of this thread.
// The threads with the same X split the reduction range by interleaving over Y.
void multiplyTiles( const uint3 thread )
{
	float r = 0.0;
	for( uint i = thread.y; i < TILE_HEIGHT; i += THREADS_Y )
	{
		float a = tile0[ i ][ thread.x ];
		float b = tile1[ i ];
		r = mad( a, b, r );
	}
	tileOutput[ thread.y ][ thread.x ] += r;
}

// Sum the accumulators over Y, leaving the totals in tileOutput[ 0 ][ .. ]
// The caller must issue a barrier between the last multiplyTiles() and this function.
void reduceOutput( const uint3 thread )
{
	float curr = 0.0;
	[branch]
	if( thread.y < THREADS_Y / 2 )
		curr = tileOutput[ thread.y ][ thread.x ];

	for( uint i = THREADS_Y / 2; i > 0; i /= 2 )
	{
		[branch]
		if( thread.y < i )
		{
			curr += tileOutput[ thread.y + i ][ thread.x ];
			tileOutput[ thread.y ][ thread.x ] = curr;
		}
		// Outside of the branch: all threads of the group need to reach the barrier
		GroupMemoryBarrierWithGroupSync();
	}
}

// Store the first `size` totals into the output tensor.
// `pos` is the 4D position of the first one; consecutive elements are stored at consecutive indices.
void storeTile( const uint threadFlat, const uint4 pos, const uint size )
{
	if( threadFlat >= size )
		return;
	const uint4 prod4 = pos * resultStrides;
	const uint2 prod2 = prod4.xy + prod4.zw;
	uint rdi = prod2.x + prod2.y;
	result[ rdi + threadFlat ] = tileOutput[ 0 ][ threadFlat ];
}

[ numthreads( TILE_SIZE, THREADS_Y, 1 ) ]
void main( const uint3 group: SV_GroupID, const uint3 thread : SV_GroupThreadID, uint threadFlat : SV_GroupIndex )
{
	uint i;
	// Zero all 3 shared buffers
	tileOutput[ thread.y ][ thread.x ] = 0.0;
	for( i = thread.y; i < TILE_HEIGHT; i += THREADS_Y )
		tile0[ i ][ thread.x ] = 0.0;
	// tile1 has TILE_HEIGHT entries; the bound used to be THREADS_Y which left a part of the buffer untouched
	if( threadFlat < TILE_HEIGHT )
		tile1[ threadFlat ] = 0.0;

	const uint2 layer = group.yz;
	// Load index for arg0: start of the panel in the layer, plus the element of this thread
	uint rsi0 = group.x * arg0panel + layer.x * arg0LayerStrides.x + layer.y * arg0LayerStrides.y;
	// Load index for arg1: start of the layer, plus the element of this thread
	uint rsi1 = layer.x * arg1Strides.z + layer.y * arg1Strides.w;

	const uint threadOffset = thread.y * TILE_SIZE + thread.x;
	rsi0 += threadOffset;
	rsi1 += threadFlat * arg1Strides.x;

	const uint completeTiles = arg0Size.x / TILE_HEIGHT;
	for( i = 0; i < completeTiles; i++ )
	{
		// Load [ TILE_SIZE, TILE_HEIGHT ] block from the first source tensor into the groupshared buffer
		// The panel is dense and column major, every pass of the loop reads THREADS_Y complete columns
		for( uint j = thread.y; j < TILE_HEIGHT; j += THREADS_Y )
		{
			tile0[ j ][ thread.x ] = arg0[ rsi0 ];
			rsi0 += THREADS_Y * TILE_SIZE;
		}
		// Load [ TILE_HEIGHT ] row from the second source into another groupshared buffer
		[ branch ]
		if( threadFlat < TILE_HEIGHT )
			tile1[ threadFlat ] = arg1[ rsi1 ];
		rsi1 += TILE_HEIGHT * arg1Strides.x;

		GroupMemoryBarrierWithGroupSync();

		multiplyTiles( thread );

		// The tiles are overwritten by the next iteration
		GroupMemoryBarrierWithGroupSync();
	}

	const uint rem = arg0Size.x % TILE_HEIGHT;
	if( rem != 0 )
	{
		// Incomplete last tile: load what's there, pad both tiles with zeros
		for( uint j = thread.y; j < TILE_HEIGHT; j += THREADS_Y )
		{
			float a;
			[branch]
			if( j < rem )
			{
				a = arg0[ rsi0 ];
				rsi0 += THREADS_Y * TILE_SIZE;
			}
			else
				a = 0.0;
			tile0[ j ][ thread.x ] = a;
		}

		if( threadFlat < TILE_HEIGHT )
		{
			float b;
			[branch]
			if( threadFlat < rem )
				b = arg1[ rsi1 ];
			else
				b = 0.0;
			tile1[ threadFlat ] = b;
		}

		GroupMemoryBarrierWithGroupSync();

		multiplyTiles( thread );

		GroupMemoryBarrierWithGroupSync();
	}

	reduceOutput( thread );

	// The last panel of the matrix may produce less than TILE_SIZE outputs
	const uint resultPos = group.x * TILE_SIZE;
	const uint outputSize = min( TILE_SIZE, resultSize.x - resultPos );
	storeTile( threadFlat, uint4( resultPos, 0, layer ), outputSize );
}
// ComputeShaders/mulMatByScalar.hlsl
// Matrix * scalar product, like [ 1, E1, E2, E3 ] * [ 1, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
// Dispatch [ E2, E3, 1 ] thread groups of this shader; each group scales one layer.
Buffer<float> arg0: register( t0 );
Buffer<float> arg1: register( t1 );
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 arg0Size: packoffset( c0 );
	uint4 arg0Strides: packoffset( c1 );
	uint4 arg1Size: packoffset( c2 );
	uint4 arg1Strides: packoffset( c3 );
	uint4 resultSize: packoffset( c4 );
	uint4 resultStrides: packoffset( c5 );
}

// Sum of the two lanes
inline uint hadd( uint2 vec )
{
	return vec.x + vec.y;
}

[ numthreads( 32, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	// The group coordinates are the indices along the last two dimensions of the tensors
	const uint2 layer = group.xy;

	// Single multiplier for the complete layer
	const float scalarValue = arg1[ hadd( layer * arg1Strides.zw ) ];

	// Source elements are arg0Strides.y apart, the 32 threads of the group interleave over them
	const uint srcStep = 32 * arg0Strides.y;
	uint src = hadd( layer * arg0Strides.zw ) + thread * arg0Strides.y;

	// Destination elements of the layer are stored at consecutive indices
	const uint dstBegin = hadd( layer * resultStrides.zw );
	const uint dstEnd = dstBegin + arg0Size.y;
	uint dst = dstBegin + thread;

	while( dst < dstEnd )
	{
		result[ dst ] = arg0[ src ] * scalarValue;
		dst += 32;
		src += srcStep;
	}
}
// ComputeShaders/mulMatDotMain.hlsl
// GGML_TASK_COMPUTE step for matrix*matrix product, where nb01 >= nb00;
// Dispatch with [ ne11, ne01*ne02*ne03 ] thread groups
// Each thread group computes a single dot product
// The local variables follow the naming of the ggml code quoted in the comments.
Buffer<float> arg0: register( t0 );
Buffer<float> arg1: register( t1 );
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 src0_elements: packoffset( c0 );
	uint4 src0_strides: packoffset( c1 );
	uint4 src1_elements: packoffset( c2 );
	uint4 result_elements: packoffset( c4 );
	uint4 result_strides: packoffset( c5 );
}

// Product of the three lanes; not used by main() of this file
inline uint product( uint3 vec )
{
	return vec.x * vec.y * vec.z;
}

// Product of the four lanes; not used by main() of this file
inline uint product( uint4 vec )
{
	uint2 tmp = vec.xy * vec.zw;
	return tmp.x * tmp.y;
}

// Partial dot product of two dense vectors of the specified length, starting at arg0[ i0 ] and arg1[ i1 ]
// The 32 threads of the group interleave over the elements.
inline float dotProductInner( uint i0, uint i1, uint length, uint thread )
{
	float res = 0;
	for( uint i = thread; i < length; i += 32 )
		res = mad( arg0[ i0 + i ], arg1[ i1 + i ], res );
	return res;
}

#include "groupReduce.hlsli"

[numthreads( 32, 1, 1 )]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	const uint ne00 = src0_elements.x;
	const uint ne01 = src0_elements.y;
	const uint ne02 = src0_elements.z;

	const uint ne11 = src1_elements.y;
	const uint ne12 = src1_elements.z;

	const uint nb01 = src0_strides.y;
	const uint nb02 = src0_strides.z;
	const uint nb03 = src0_strides.w;

	// Flat index of the src0 row, in [ 0 .. ne01*ne02*ne03 )
	const uint ir = group.y;

	// src0 indices: split the flat row index into the 3 outer coordinates
	const uint i03 = ir / ( ne02 * ne01 );
	const uint i02 = ( ir - i03 * ne02 * ne01 ) / ne01;
	const uint i01 = ( ir - i03 * ne02 * ne01 - i02 * ne01 );

	// src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
	// src1_col = wdata + ( i13 * ne12 * ne11 + i12 * ne11 + 0 ) * ne00;
	// where i13 = i03 and i12 = i02; the second argument is indexed as a dense array of rows, ne00 elements each
	const uint src0_row = i01 * nb01 + i02 * nb02 + i03 * nb03;
	const uint src1_col = ( i03 * ne12 * ne11 + i02 * ne11 ) * ne00;

	// Index of the src1 row within the layer
	const uint ic = group.x;
	float curr = dotProductInner( src0_row, src1_col + ic * ne00, ne00, thread );
	horizontalSumCompatNew( thread, curr );

	// Only the thread #0 has the complete sum
	if( 0 != thread )
		return;

	const uint nb0 = result_strides.x;
	const uint nb2 = result_strides.z;
	const uint nb3 = result_strides.w;

	const uint ne0 = result_elements.x;
	// float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
	// where i0 = i01, i2 = i02, i3 = i03
	const uint dst_col = i01 * nb0 + i02 * nb2 + i03 * nb3;
	result[ dst_col + ic * ne0 ] = curr;
}
// ComputeShaders/mulMatDotReshape.hlsl
// GGML_TASK_INIT step for matrix*matrix product, where nb01 >= nb00;
// Dispatch with [ ne11, ne12 ] groups
Buffer<float> arg0: register( t0 );
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 src0_elements: packoffset( c0 );
	uint4 src0_strides: packoffset( c1 );
}

#include "miscUtils.hlsli"

// Each thread group of this shader copies a single row of the matrix into a dense output,
// passing every element through adjustFp16() on the way.
[ numthreads( 32, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	// group.x is the row within the layer, group.y is the layer
	const uint rowLength = src0_elements.x;
	const uint rowsPerLayer = src0_elements.y;

	// The output is dense: rows of a layer back to back, layers back to back
	const uint dstRow = group.x * rowLength + group.y * rowLength * rowsPerLayer;
	// The input is addressed with the strides; elements of a row are read from consecutive indices
	const uint srcRow = group.y * src0_strides.z + group.x * src0_strides.y;

	// The 32 threads of the group interleave over the elements of the row
	for( uint col = thread; col < rowLength; col += 32 )
		result[ dstRow + col ] = adjustFp16( arg0[ srcRow + col ] );
}
\ No newline at end of file diff --git a/ComputeShaders/mulMatMadMain.hlsl b/ComputeShaders/mulMatMadMain.hlsl new file mode 100644 index 0000000..0bd5753 --- /dev/null +++ b/ComputeShaders/mulMatMadMain.hlsl @@ -0,0 +1,148 @@
+// GGML_TASK_COMPUTE step for matrix*matrix product, where nb01 < nb00
+// Every helper below spreads its range over the 32 threads of the group: a thread only touches elements thread, thread + 32, thread + 64, ...
+// That is why main() contains no group barriers: each thread only reads back values it has written itself.
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> resultTensor: register( u0 );
+RWBuffer<float> tempBuffer: register( u1 );
+
+cbuffer Constants: register( b0 )
+{
+	uint4 aSize: packoffset( c0 );
+	uint4 aStride: packoffset( c1 );
+	uint4 bSize: packoffset( c2 );
+	uint4 bStride: packoffset( c3 );
+	uint4 resSize: packoffset( c4 );
+	bool resultFp16 : packoffset( c5.x );
+	uint ne: packoffset( c5.y );
+}
+
+#include "miscUtils.hlsli"
+
+// tempBuffer[ rdi .. rdi + len ) = 0.0
+inline void writeTempZeros( uint rdi, const uint len, const uint thread )
+{
+	const uint rdiEnd = rdi + len;
+	for( rdi += thread; rdi < rdiEnd; rdi += 32 )
+		tempBuffer[ rdi ] = 0.0;
+}
+
+// tempBuffer[ rdi .. ] += mul * arg0[ rsi .. ], for len consecutive elements
+// When resultFp16 is set, every updated value is additionally passed through adjustFp16()
+inline void vectorMad( uint rsi, uint rdi, const uint len, const float mul, const uint thread )
+{
+	const uint rsiEnd = rsi + len;
+	rsi += thread;
+	rdi += thread;
+	for( ; rsi < rsiEnd; rsi += 32, rdi += 32 )
+	{
+		float f = tempBuffer[ rdi ];
+		f = mad( mul, arg0[ rsi ], f );
+		[branch]
+		if( resultFp16 )
+			f = adjustFp16( f );
+		tempBuffer[ rdi ] = f;
+	}
+}
+
+// resultTensor[ rdi .. ] = tempBuffer[ rsi .. ], for len consecutive elements
+inline void copyRow( uint rsi, uint rdi, const uint len, const uint thread )
+{
+	const uint rsiEnd = rsi + len;
+	rsi += thread;
+	rdi += thread;
+	for( ; rsi < rsiEnd; rsi += 32, rdi += 32 )
+	{
+		float f = tempBuffer[ rsi ];
+		resultTensor[ rdi ] = f;
+	}
+}
+
+// resultTensor[ rdi .. ] += tempBuffer[ rsi .. ], for len consecutive elements
+inline void addRow( uint rsi, uint rdi, const uint len, const uint thread )
+{
+	const uint rsiEnd = rsi + len;
+	rsi += thread;
+	rdi += thread;
+	for( ; rsi < rsiEnd; rsi += 32, rdi += 32 )
+	{
+		float f = resultTensor[ rdi ];
+		f += tempBuffer[ rsi ];
+		resultTensor[ rdi ] = f;
+	}
+}
+
+[numthreads( 32, 1, 1 )]
+void main( const uint3 group: SV_GroupID, const uint thread : SV_GroupIndex )
+{
+	const uint i1 = group[ 0 ];
+	const uint i2 = group[ 1 ];
+	const uint i3 = group[ 2 ];
+
+	// Count of elements in every accumulated row
+	const uint ne01 = aSize[ 1 ];
+	// Total columns in src1
+	const uint nc = bSize[ 0 ];
+
+	const uint ne0 = resSize[ 0 ];
+	const uint ne1 = resSize[ 1 ];
+	const uint ne2 = resSize[ 2 ];
+
+	const uint nb00 = aStride[ 0 ];
+	const uint nb02 = aStride[ 2 ];
+	const uint nb03 = aStride[ 3 ];
+
+	const uint nb10 = bStride[ 0 ];
+	const uint nb11 = bStride[ 1 ];
+	const uint nb12 = bStride[ 2 ];
+	const uint nb13 = bStride[ 3 ];
+
+	// dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+	const uint tempRowThread0 = i3 * ne2 * ne1 * ne0 + i2 * ne1 * ne0 + i1 * ne0;
+
+	// Faking 4 CPU threads trying to achieve bitwise compatibility with the CPU version
+	const uint nth = 4;
+
+	// GGML_TASK_COMPUTE
+	{
+		// src0_col = src0->data + ( i00 * nb00 + i02 * nb02 + i03 * nb03 );
+		const uint aBase = i2 * nb02 + i3 * nb03;
+		// src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+		const uint bBase = i1 * nb11 + i2 * nb12 + i3 * nb13;
+
+		// columns per thread
+		const uint dc = ( nc + nth - 1 ) / nth;
+
+		// One partial accumulator per faked CPU thread; consecutive accumulators are `ne` elements apart in tempBuffer
+		uint tempRow = tempRowThread0;
+		for( uint ith = 0; ith < nth; ith++, tempRow += ne )
+		{
+			writeTempZeros( tempRow, ne01, thread );
+
+			// column range for this thread
+			const uint ic0 = dc * ith;
+			const uint ic1 = min( ic0 + dc, nc );
+
+			for( uint ic = ic0; ic < ic1; ic++ )
+			{
+				const uint idxA = aBase + ic * nb00;
+				const uint idxB = bBase + ic * nb10;
+				const float bValue = arg1[ idxB ];
+				vectorMad( idxA, tempRow, ne01, bValue, thread );
+			}
+		}
+	}
+
+	// GGML_TASK_FINALIZE
+	{
+		// The destination offset in resultTensor equals the offset of the first accumulator in tempBuffer
+		const uint rdi = tempRowThread0;
+
+		uint tempRow = tempRowThread0;
+		copyRow( tempRow, rdi, ne01, thread );
+
+		tempRow += ne;
+		for( uint ith = 1; ith < nth; ith++, tempRow += ne )
+			addRow( tempRow, rdi, ne01, thread );
+	}
+}
\ No newline at end of file diff --git a/ComputeShaders/mulMatTiled.hlsl b/ComputeShaders/mulMatTiled.hlsl new file mode 100644 index 0000000..7e4d7d8 --- /dev/null +++ b/ComputeShaders/mulMatTiled.hlsl @@ -0,0 +1,238 @@
+// This compute shader implements matrix*matrix product, using tiling and other tricks to improve the performance
+#ifndef TILE_SIZE
+static const uint TILE_SIZE = 32;
+#endif
+
+#ifndef THREADS_Y
+// Performance measures on Ryzen 7 5700G iGPU, the time is just for this shader:
+// 1 (32 threads per group) - 17.1 seconds, 2 - 9.02424 seconds, 4 - 6.95762 seconds, 6 - 6.79011 seconds, 8 - 6.67279 seconds, 10 - 6.9456 seconds, 16 - 7.20502 seconds
+// On nVidia, 8 is also the fastest option.
+static const uint THREADS_Y = 8;
+#endif
+
+#ifndef STREAM_SECOND_MATRIX
+#define STREAM_SECOND_MATRIX 0
+#endif
+
+#ifndef LOAD_ORDER
+
+// Load with coalesced loads from global memory whenever possible, store into groupshared buffer with random stores
+// #define LOAD_ORDER bool2( ( 1 == arg0Strides[ 0 ] ) || ( 1 != arg0Strides[ 1 ] ), ( 1 == arg1Strides[ 0 ] ) || ( 1 != arg1Strides[ 1 ] ) )
+
+// Load with random loads from global memory, store into groupshared buffer with coalesced stores
+// On my AMD iGPU inside Ryzen 7 5700G, there's a whopping 15% performance win with that tactic, from 6.67 to 5.66 seconds for this shader.
+// My nVidia GPU does about the same
+#define LOAD_ORDER bool2( false, true )
+
+#endif
+
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+	uint4 arg0Size: packoffset( c0 );
+	uint4 arg0Strides: packoffset( c1 );
+	uint4 arg1Strides: packoffset( c3 );
+	uint4 resultSize: packoffset( c4 );
+	uint4 resultStrides: packoffset( c5 );
+}
+
+groupshared float tile0[ TILE_SIZE ][ TILE_SIZE ];
+#if !STREAM_SECOND_MATRIX
+groupshared float tile1[ TILE_SIZE ][ TILE_SIZE ];
+#endif
+groupshared float resTemp[ TILE_SIZE ][ TILE_SIZE ];
+
+#if STREAM_SECOND_MATRIX
+void multiplyTiles( uint rsi, const uint3 thread, const uint w, const uint h )
+{
+	for( uint i = thread.y; i < h; i += THREADS_Y, rsi += THREADS_Y * arg1Strides.y )
+	{
+		float r = 0;
+		uint rsiRow = rsi;
+		for( uint j = 0; j < w; j++, rsiRow += arg1Strides.x )
+		{
+			// One TILE_SIZE * 4 bytes coalesced load, broadcasted into THREADS_Y copies
+			const float s0 = tile0[ j ][ thread.x ];
+			// THREADS_Y broadcasts from global memory, each one is 4 bytes broadcasted into TILE_SIZE copies
+			const float s1 = arg1[ rsiRow ];
+			// Multiply and accumulate
+			r = mad( s0, s1, r );
+		}
+		// Accumulate into the output tile
+		// THREADS_Y * 128 bytes coalesced loads and stores
+		resTemp[ i ][ thread.x ] += r;
+	}
+}
+#else
+// Compute resTemp += tile0 * tile1, for TILE_SIZE^2 square matrices
+// The group size is TILE_SIZE*THREADS_Y threads in this shader
+void multiplyTiles( const uint3 thread )
+{
+	for( uint i = thread.y; i < TILE_SIZE; i += THREADS_Y )
+	{
+		float r = 0;
+		for( uint j = 0; j < TILE_SIZE; j++ )
+		{
+			// One TILE_SIZE * 4 bytes coalesced load, broadcasted into THREADS_Y copies
+			const float s0 = tile0[ j ][ thread.x ];
+			// THREADS_Y broadcasts, each one is 4 bytes broadcasted into TILE_SIZE copies
+			const float s1 = tile1[ i ][ j ];
+			// Multiply and accumulate
+			r = mad( s0, s1, r );
+		}
+		// Accumulate into the output tile
+		// THREADS_Y * 128 bytes coalesced loads and stores
+		resTemp[ i ][ thread.x ] += r;
+	}
+}
+#endif
+
+// Load a [ w, h ] block of arg0 starting at rsi into tile0; the tile is transposed while loading, and the padding is zero
+void loadTile0( uint rsi, const uint3 thread, const uint w, const uint h, const bool rowMajor )
+{
+	uint i;
+	if( rowMajor )
+	{
+		rsi += arg0Strides.y * thread.y;
+		for( i = thread.y; i < h; i += THREADS_Y, rsi += arg0Strides.y * THREADS_Y )
+		{
+			if( thread.x < w )
+				tile0[ thread.x ][ i ] = arg0[ rsi + thread.x * arg0Strides.x ];
+			else
+				tile0[ thread.x ][ i ] = 0.0;
+		}
+	}
+	else
+	{
+		// Unlike width which is smaller for the last tile, the height is always the same, and all these tiles are zero-initialized
+		if( thread.x >= h )
+			return;
+
+		rsi += arg0Strides.x * thread.y;
+		for( i = thread.y; i < w; i += THREADS_Y, rsi += arg0Strides.x * THREADS_Y )
+			tile0[ i ][ thread.x ] = arg0[ rsi + thread.x * arg0Strides.y ];
+
+		if( i >= TILE_SIZE )
+			return;
+		for( ; i < TILE_SIZE; i += THREADS_Y )
+			tile0[ i ][ thread.x ] = 0.0;
+	}
+}
+
+#if !STREAM_SECOND_MATRIX
+void loadTile1( uint rsi, const uint3 thread, const uint w, const uint h, const bool rowMajor )
+{
+	uint i;
+	if( rowMajor )
+	{
+		rsi += thread.y * arg1Strides.y;
+
+		for( i = thread.y; i < h; i += THREADS_Y, rsi += arg1Strides.y * THREADS_Y )
+		{
+			if( thread.x < w )
+				tile1[ i ][ thread.x ] = arg1[ rsi + thread.x * arg1Strides.x ];
+			else
+				tile1[ i ][ thread.x ] = 0.0;
+		}
+	}
+	else
+	{
+		// Unlike width which is smaller for the last tile, the height is always the same, and all these tiles are zero-initialized
+		if( thread.x >= h )
+			return;
+
+		rsi += thread.y * arg1Strides.x;
+		for( i = thread.y; i < w; i += THREADS_Y, rsi += arg1Strides.x * THREADS_Y )
+			tile1[ thread.x ][ i ] = arg1[ rsi + thread.x * arg1Strides.y ];
+		if( i >= TILE_SIZE )
+			return;
+		for( ; i < TILE_SIZE; i += THREADS_Y )
+			tile1[ thread.x ][ i ] = 0.0;
+	}
+}
+#endif
+
+void storeTile( const uint3 thread, const uint4 pos, const uint2 size )
+{
+	if( thread.x >= size.x )
+		return;
+	const uint4 prod4 = pos * resultStrides;
+	const uint2 prod2 = prod4.xy + prod4.zw;
+	uint rdi = prod2.x + prod2.y;
+	rdi += resultStrides.y * thread.y;
+	for( uint i = thread.y; i < size.y; i += THREADS_Y, rdi += resultStrides.y * THREADS_Y )
+		result[ rdi + thread.x * resultStrides.x ] = resTemp[ i ][ thread.x ];
+}
+
+[ numthreads( TILE_SIZE, THREADS_Y, 1 ) ]
+void main( uint3 group: SV_GroupID, uint3 thread : SV_GroupThreadID )
+{
+	// Zero out these shared buffers
+	for( uint i = 0; i < TILE_SIZE; i += THREADS_Y )
+	{
+		tile0[ i + thread.y ][ thread.x ] = 0.0;
+#if !STREAM_SECOND_MATRIX
+		tile1[ i + thread.y ][ thread.x ] = 0.0;
+#endif
+		resTemp[ i + thread.y ][ thread.x ] = 0.0;
+	}
+	// Depending on LOAD_ORDER, the loaders below store transposed elements which were zeroed by other threads of the group
+	GroupMemoryBarrierWithGroupSync();
+
+	const uint2 resultPos = group.xy * TILE_SIZE;
+	const uint2 layer = uint2( group.z % resultSize.z, group.z / resultSize.z );
+	uint rsi0 = resultPos.x * arg0Strides.y + layer.x * arg0Strides.z + layer.y * arg0Strides.w;
+	uint rsi1 = resultPos.y * arg1Strides.y + layer.x * arg1Strides.z + layer.y * arg1Strides.w;
+
+	const uint rsi0Inc = TILE_SIZE * arg0Strides.x;
+	const uint rsi1Inc = TILE_SIZE * arg1Strides.x;
+
+	const uint completeTiles = arg0Size.x / TILE_SIZE;
+	const uint rsi0EndAligned = rsi0 + rsi0Inc * completeTiles;
+	// Output tile size
+	// Normally TILE_SIZE^2, less than that for the tiles at the right and bottom edges of the output matrix
+	const uint2 outputSize = min( TILE_SIZE, resultSize.xy - resultPos );
+
+	const bool2 loadOrder = LOAD_ORDER;
+
+#if STREAM_SECOND_MATRIX
+	rsi1 += thread.y * arg1Strides.y;
+#endif
+	for( ; rsi0 < rsi0EndAligned; rsi0 += rsi0Inc, rsi1 += rsi1Inc )
+	{
+		loadTile0( rsi0, thread, TILE_SIZE, outputSize.x, loadOrder.x );
+#if STREAM_SECOND_MATRIX
+		GroupMemoryBarrierWithGroupSync();
+		multiplyTiles( rsi1, thread, TILE_SIZE, outputSize.y );
+#else
+		loadTile1( rsi1, thread, TILE_SIZE, outputSize.y, loadOrder.y );
+		GroupMemoryBarrierWithGroupSync();
+		multiplyTiles( thread );
+#endif
+		// One more barrier is required here.
+		// Otherwise, some threads of the group are loading the next tile into tile0/tile1 groupshared buffers on the next iteration of the loop,
+		// while other threads of the same group are still computing the matrix product, and getting incorrect values from that groupshared buffer.
+		// The missing barrier only caused a bug on AMD, and only with "ggml-large.bin" model; no idea why that is.
+		GroupMemoryBarrierWithGroupSync();
+	}
+
+	const uint rem = arg0Size.x % TILE_SIZE;
+	if( 0 != rem )
+	{
+		loadTile0( rsi0, thread, rem, outputSize.x, loadOrder.x );
+#if STREAM_SECOND_MATRIX
+		GroupMemoryBarrierWithGroupSync();
+		multiplyTiles( rsi1, thread, rem, outputSize.y );
+#else
+		loadTile1( rsi1, thread, rem, outputSize.y, loadOrder.y );
+		GroupMemoryBarrierWithGroupSync();
+		multiplyTiles( thread );
+#endif
+	}
+
+	GroupMemoryBarrierWithGroupSync();
+	storeTile( thread, uint4( resultPos, layer ), outputSize );
+}
\ No newline at end of file diff --git a/ComputeShaders/mulMatTiled64.hlsl b/ComputeShaders/mulMatTiled64.hlsl new file mode 100644 index 0000000..45d77b1 --- /dev/null +++ b/ComputeShaders/mulMatTiled64.hlsl @@ -0,0 +1,3 @@
+#define TILE_SIZE 64
+#define STREAM_SECOND_MATRIX 1
+#include "mulMatTiled.hlsl" // Same shader source, compiled with the two overrides above: TILE_SIZE of 64, and the STREAM_SECOND_MATRIX code path
\ No newline at end of file diff --git a/ComputeShaders/mulMatTiledEx.hlsl b/ComputeShaders/mulMatTiledEx.hlsl new file mode 100644 index 0000000..0f23da2 --- /dev/null +++ b/ComputeShaders/mulMatTiledEx.hlsl @@ -0,0 +1,194 @@
+// This compute shader implements yet another version of matrix*matrix product
+// For optimal VRAM access pattern, it requires both arguments to be reshaped into a sequence of horizontal column major panels.
+// The panel height is TILE_SIZE, and the last panel of the matrix needs to be padded with zeros; see matReshapePanels.hlsl shader for the reshaping.
+// So far, it's only used when running on AMD GPUs.
+#ifndef TILE_SIZE
+static const uint TILE_SIZE = 32;
+#endif
+#ifndef TILE_HEIGHT
+static const uint TILE_HEIGHT = 32;
+#endif
+#ifndef THREADS_Y
+static const uint THREADS_Y = 16;
+#endif
+
+#ifndef STREAM_SECOND_MATRIX
+#define STREAM_SECOND_MATRIX 1
+#endif
+
+// First tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
+Buffer<float> arg0: register( t0 );
+// Second tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
+Buffer<float> arg1: register( t1 );
+// FP32 output tensor, row major and contiguous
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+	uint4 arg0Size: packoffset( c0 );
+	uint arg0panel: packoffset( c1.y ); // Multiplied by group.x in main() to find the first element of the panel
+	uint2 arg0LayerStrides: packoffset( c1.z );
+
+	// uint4 arg1Size: packoffset( c2 );
+	uint arg1panel: packoffset( c3.y ); // Multiplied by group.y in main() to find the first element of the panel
+	uint2 arg1LayerStrides: packoffset( c3.z );
+
+	uint4 resultSize: packoffset( c4 );
+	uint4 resultStrides: packoffset( c5 );
+}
+
+// Accumulator for the output tile
+// That last `+1` helps a bit, I'm not sure why exactly but probably because of memory bank conflicts.
+groupshared float tileOutput[ TILE_SIZE ][ TILE_SIZE + 1 ];
+// A smaller tile loaded from the first source matrix
+groupshared float tile0[ TILE_HEIGHT ][ TILE_SIZE ];
+#if !STREAM_SECOND_MATRIX
+// A smaller tile loaded from the second source matrix
+groupshared float tile1[ TILE_HEIGHT ][ TILE_SIZE ];
+#endif
+
+#if STREAM_SECOND_MATRIX
+void multiplyTiles( const uint3 thread, uint rsi, const uint h ) // tileOutput += tile0 * ( h columns of arg1 starting at rsi ), the second matrix is read directly from arg1
+{
+	uint2 i = uint2( thread.y, rsi ); // x = row of the output tile, y = index of the matching element in arg1
+	const uint2 iInc = uint2( THREADS_Y, THREADS_Y );
+	for( ; i.x < TILE_SIZE; i += iInc )
+	{
+		float r = 0.0;
+		uint2 j = uint2( 0, i.y ); // x = row of tile0, y = index in arg1
+		const uint2 jInc = uint2( 1, TILE_SIZE ); // One row of tile0 corresponds to TILE_SIZE elements of arg1
+		for( ; j.x < h; j += jInc )
+		{
+			float a = tile0[ j.x ][ thread.x ];
+			float b = arg1[ j.y ];
+			r = mad( a, b, r );
+		}
+		tileOutput[ i.x ][ thread.x ] += r;
+	}
+}
+#else
+void multiplyTiles( const uint3 thread ) // tileOutput += tile0 * tile1, over all TILE_HEIGHT rows of both tiles
+{
+	for( uint row = thread.y; row < TILE_SIZE; row += THREADS_Y )
+	{
+		float r = 0.0;
+		for( uint j = 0; j < TILE_HEIGHT; j++ )
+		{
+			float a = tile0[ j ][ thread.x ];
+			float b = tile1[ j ][ row ];
+			r = mad( a, b, r );
+		}
+		tileOutput[ row ][ thread.x ] += r;
+	}
+}
+#endif
+
+void storeTile( const uint3 thread, const uint4 pos, const uint2 size ) // Copy the top-left [ size.x, size.y ] block of tileOutput into the result tensor at position pos
+{
+	if( thread.x >= size.x )
+		return;
+	const uint4 prod4 = pos * resultStrides;
+	const uint2 prod2 = prod4.xy + prod4.zw;
+	uint rdi = prod2.x + prod2.y;
+	rdi += resultStrides.y * thread.y;
+	rdi += thread.x; // Consecutive elements of an output row are adjacent; resultStrides.x is not used here
+	for( uint i = thread.y; i < size.y; i += THREADS_Y, rdi += resultStrides.y * THREADS_Y )
+		result[ rdi ] = tileOutput[ i ][ thread.x ];
+}
+
+[numthreads( TILE_SIZE, THREADS_Y, 1 )]
+void main( const uint3 group: SV_GroupID, const uint3 thread : SV_GroupThreadID )
+{
+	uint i;
+	// Zero the shared buffers: tileOutput, tile0, and also tile1 unless STREAM_SECOND_MATRIX is enabled
+	for( i = thread.y; i < TILE_SIZE; i += THREADS_Y )
+		tileOutput[ i ][ thread.x ] = 0.0;
+	for( i = thread.y; i < TILE_HEIGHT; i += THREADS_Y )
+	{
+		tile0[ i ][ thread.x ] = 0.0;
+#if !STREAM_SECOND_MATRIX
+		tile1[ i ][ thread.x ] = 0.0;
+#endif
+	}
+
+	const uint2 layer = uint2( group.z % resultSize.z, group.z / resultSize.z ); // group.z encodes both outer dimensions of the result
+
+	uint rsi0 = group.x * arg0panel + layer.x * arg0LayerStrides.x + layer.y * arg0LayerStrides.y;
+	uint rsi1 = group.y * arg1panel + layer.x * arg1LayerStrides.x + layer.y * arg1LayerStrides.y;
+
+	const uint threadOffset = thread.y * TILE_SIZE + thread.x;
+	rsi0 += threadOffset;
+#if STREAM_SECOND_MATRIX
+	rsi1 += thread.y;
+#else
+	rsi1 += threadOffset;
+#endif
+
+	const uint completeTiles = arg0Size.x / TILE_HEIGHT;
+	for( i = 0; i < completeTiles; i++ )
+	{
+		// Load [ TILE_SIZE, TILE_HEIGHT ] block from the source tensors into these groupshared buffers
+		for( uint j = thread.y; j < TILE_HEIGHT; j += THREADS_Y )
+		{
+			tile0[ j ][ thread.x ] = arg0[ rsi0 ];
+			rsi0 += THREADS_Y * TILE_SIZE;
+#if !STREAM_SECOND_MATRIX
+			tile1[ j ][ thread.x ] = arg1[ rsi1 ];
+			rsi1 += THREADS_Y * TILE_SIZE;
+#endif
+		}
+
+		// Wait for all threads in the group to complete these loads
+		GroupMemoryBarrierWithGroupSync();
+
+#if STREAM_SECOND_MATRIX
+		multiplyTiles( thread, rsi1, TILE_HEIGHT );
+		rsi1 += TILE_HEIGHT * TILE_SIZE;
+#else
+		// Multiply + accumulate the elements collected in the groupshared buffers
+		multiplyTiles( thread );
+#endif
+		GroupMemoryBarrierWithGroupSync(); // Nobody may overwrite the tiles while other threads of the group are still multiplying
+	}
+
+	const uint rem = arg0Size.x % TILE_HEIGHT;
+	if( rem != 0 )
+	{
+		// Load [ TILE_SIZE, rem ] block from the source tensors, and zero out the padding elements
+		for( uint j = thread.y; j < TILE_HEIGHT; j += THREADS_Y )
+		{
+			[branch]
+			if( j < rem )
+			{
+				tile0[ j ][ thread.x ] = arg0[ rsi0 ];
+				rsi0 += THREADS_Y * TILE_SIZE;
+#if !STREAM_SECOND_MATRIX
+				tile1[ j ][ thread.x ] = arg1[ rsi1 ];
+				rsi1 += THREADS_Y * TILE_SIZE;
+#endif
+			}
+			else
+			{
+				tile0[ j ][ thread.x ] = 0.0;
+#if !STREAM_SECOND_MATRIX
+				tile1[ j ][ thread.x ] = 0.0;
+#endif
+			}
+		}
+
+		// Wait for all threads in the group to complete these loads
+		GroupMemoryBarrierWithGroupSync();
+
+		// Multiply + accumulate the elements collected in the groupshared buffers
+#if STREAM_SECOND_MATRIX
+		multiplyTiles( thread, rsi1, rem );
+#else
+		multiplyTiles( thread );
+#endif
+		GroupMemoryBarrierWithGroupSync();
+	}
+
+	const uint2 resultPos = group.xy * TILE_SIZE;
+	const uint2 outputSize = min( TILE_SIZE, resultSize.xy - resultPos ); // Smaller than TILE_SIZE when resultSize.xy is not a multiple of the tile size
+	storeTile( thread, uint4( resultPos, layer ), outputSize );
+}
\ No newline at end of file diff --git a/ComputeShaders/norm.hlsl b/ComputeShaders/norm.hlsl new file mode 100644 index 0000000..eeb82b7 --- /dev/null +++ b/ComputeShaders/norm.hlsl @@ -0,0 +1,70 @@
+// Ported from ggml_compute_forward_norm_f32
+// Dispatch [ ne01, ne02, ne03 ] thread groups of this shader; every group of 32 threads normalizes one row
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+	uint4 src0_elements: packoffset( c0 );
+	uint4 src0_strides: packoffset( c1 );
+	uint4 result_strides: packoffset( c3 );
+}
+
+static const float eps = 1e-5f; // TODO: make this a parameter
+
+#include "groupReduce.hlsli"
+
+// Sum of arg0[ i .. i + length ); the per-thread partial sums are combined with horizontalSumBroadcast() from groupReduce.hlsli
+float computeVectorSum( uint i, const uint length, const uint thread )
+{
+	const uint iEnd = i + length;
+	float acc = 0.0;
+	for( i += thread; i < iEnd; i += 32 )
+		acc += arg0[ i ];
+
+	horizontalSumBroadcast( thread, acc );
+	return acc;
+}
+
+// Writes ( arg0[ rsi + k ] - mean ) into result[ rdi + k ] for k < length, and returns the sum of squares of these differences
+float offsetAndComputeSumSquares( uint rsi, uint rdi, const float mean, const uint length, const uint thread )
+{
+	const uint rsiEnd = rsi + length;
+	float acc = 0.0;
+	rdi += thread;
+	for( rsi += thread; rsi < rsiEnd; rsi += 32, rdi += 32 )
+	{
+		const float centered = arg0[ rsi ] - mean;
+		result[ rdi ] = centered;
+		acc = mad( centered, centered, acc );
+	}
+
+	horizontalSumBroadcast( thread, acc );
+	return acc;
+}
+
+// result[ rdi .. rdi + length ) *= scale
+void scaleVector( uint rdi, const float scale, const uint length, const uint thread )
+{
+	const uint rdiEnd = rdi + length;
+	rdi += thread;
+	for( ; rdi < rdiEnd; rdi += 32 )
+		result[ rdi ] = result[ rdi ] * scale;
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+	// group = [ i01, i02, i03 ]; locate the row in the source and in the destination
+	const uint p = group.x * src0_strides[ 1 ] + group.y * src0_strides[ 2 ] + group.z * src0_strides[ 3 ];
+	const uint y = group.x * result_strides[ 1 ] + group.y * result_strides[ 2 ] + group.z * result_strides[ 3 ];
+
+	const uint ne00 = src0_elements[ 0 ];
+	const float count = (float)(int)ne00;
+
+	const float mean = computeVectorSum( p, ne00, thread ) / count;
+	const float sum2 = offsetAndComputeSumSquares( p, y, mean, ne00, thread );
+	const float scale = 1.0 / sqrt( sum2 / count + eps );
+
+	scaleVector( y, scale, ne00, thread );
+}
\ No newline at end of file diff --git a/ComputeShaders/normCompat.hlsl b/ComputeShaders/normCompat.hlsl new file mode 100644 index 0000000..23e1228 --- /dev/null +++ b/ComputeShaders/normCompat.hlsl @@ -0,0 +1,82 @@
+// Ported from ggml_compute_forward_norm_f32; this version handles one row per thread, sequentially, with FP64 accumulators
+// Dispatch [ ( ne01 + 31 ) / 32, ne02, ne03 ] thread groups of this shader
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+	uint4 src0_elements: packoffset( c0 );
+	uint4 src0_strides: packoffset( c1 );
+	uint4 result_strides: packoffset( c3 );
+}
+
+static const double eps = 1e-5; // TODO: make this a parameter
+
+#include "groupReduce.hlsli" // NOTE(review): nothing in this file appears to use this include - confirm before removing
+
+double computeVectorSum( uint i, const uint length ) // FP64 sum of arg0[ i .. i + length ), accumulated in index order
+{
+	double res = 0.0;
+	const uint iEnd = i + length;
+	for( ; i < iEnd; i++ )
+		res += arg0[ i ];
+	return res;
+}
+
+double offsetAndComputeSumSquares( uint rsi, uint rdi, const double mean, const uint length ) // Stores ( arg0 - mean ) into result, returns the FP64 sum of squares of these differences
+{
+	precise double sum2 = 0.0;
+	const uint rsiEnd = rsi + length;
+	for( ; rsi < rsiEnd; rsi++, rdi++ )
+	{
+		double v = arg0[ rsi ];
+		v -= mean;
+		result[ rdi ] = (float)v;
+		double prod = v * v; // The product and the addition are kept as two separate statements, and the accumulator is `precise`
+		sum2 += prod;
+	}
+	return sum2;
+}
+
+void scaleVector( uint rdi, const float scale, const uint length ) // result[ rdi .. rdi + length ) *= scale
+{
+	const uint rdiEnd = rdi + length;
+	for( ; rdi < rdiEnd; rdi++ )
+	{
+		float f = result[ rdi ];
+		f *= scale;
+		result[ rdi ] = f;
+	}
+}
+
+#include "fp64Utils.hlsli"
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 dtid: SV_DispatchThreadID )
+{
+	const uint i03 = dtid.z;
+	const uint i02 = dtid.y;
+	const uint i01 = dtid.x;
+	if( i01 >= src0_elements[ 1 ] ) // The dispatch rounds the row count up to a multiple of 32, the extra threads have nothing to do
+		return;
+
+	const uint nb01 = src0_strides[ 1 ];
+	const uint nb02 = src0_strides[ 2 ];
+	const uint nb03 = src0_strides[ 3 ];
+
+	const uint p = i01 * nb01 + i02 * nb02 + i03 * nb03;
+	const uint ne00 = src0_elements[ 0 ];
+
+	double mean = computeVectorSum( p, ne00 );
+	mean = div64( mean, (double)(int)ne00 );
+
+	const uint nb1 = result_strides[ 1 ];
+	const uint nb2 = result_strides[ 2 ];
+	const uint nb3 = result_strides[ 3 ];
+	const uint y = i01 * nb1 + i02 * nb2 + i03 * nb3;
+
+	const double sum2 = offsetAndComputeSumSquares( p, y, mean, ne00 );
+	const float scale = (float)div64( 1.0, sqrt64( sum2 / (float)(int)ne00 + eps ) ); // NOTE(review): sum2 is divided with the `/` operator while the other FP64 divisions in this file go through div64() - confirm this is intended
+
+	scaleVector( y, scale, ne00 );
+}
\ No newline at end of file diff --git a/ComputeShaders/normFixed.hlsl b/ComputeShaders/normFixed.hlsl new file mode 100644 index 0000000..8f2267f --- /dev/null +++ b/ComputeShaders/normFixed.hlsl @@ -0,0 +1,131 @@
+// Ported from ggml_compute_forward_norm_f32
+// Dispatch [ ne01, ne02, ne03 ] thread groups of this shader
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+	uint4 src0_elements: packoffset( c0 );
+	uint4 src0_strides: packoffset( c1 );
+	uint4 result_strides: packoffset( c3 );
+}
+
+static const float eps = 1e-5f; // TODO: make this a parameter
+
+// #include "groupReduce.hlsli"
+
+#ifndef THREADS
+static const uint THREADS = 32;
+#endif
+// This shader only handles rows of exactly ROW_LENGTH elements; src0_elements is not consulted
+static const uint ROW_LENGTH = 1024;
+groupshared float rowBuffer[ ROW_LENGTH ];
+
+static const uint REDUCTION_BUFFER = 32;
+groupshared float sharedAccumulators[ REDUCTION_BUFFER ];
+
+// Compute horizontal sum of the numbers. The result is only correct on the thread #0 of the group.
+// The caller must guarantee no thread of the group still reads sharedAccumulators when the group enters this function.
+// THREADS is expected to be REDUCTION_BUFFER or 2 * REDUCTION_BUFFER; larger groups would need one more barrier inside the first loop.
+void horizontalSum( const uint thread, inout float sum )
+{
+	if( THREADS > REDUCTION_BUFFER )
+	{
+		for( uint t = REDUCTION_BUFFER; t < THREADS; t += REDUCTION_BUFFER )
+		{
+			// Threads [ t .. t + REDUCTION_BUFFER ] store into the buffer
+			if( thread >= t && thread < t + REDUCTION_BUFFER )
+				sharedAccumulators[ thread - t ] = sum;
+
+			GroupMemoryBarrierWithGroupSync();
+
+			// Threads [ 0 .. REDUCTION_BUFFER ] increment their local sum with the value loaded from the buffer
+			if( thread < REDUCTION_BUFFER )
+				sum += sharedAccumulators[ thread ];
+		}
+	}
+
+	if( thread < REDUCTION_BUFFER )
+		sharedAccumulators[ thread ] = sum;
+
+	for( uint i = REDUCTION_BUFFER / 2; i > 1; i /= 2 )
+	{
+		GroupMemoryBarrierWithGroupSync();
+		if( thread < i )
+		{
+			sum += sharedAccumulators[ thread + i ];
+			sharedAccumulators[ thread ] = sum;
+		}
+	}
+
+	GroupMemoryBarrierWithGroupSync();
+	if( 0 == thread )
+		sum += sharedAccumulators[ 1 ];
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+	const uint i03 = group.z;
+	const uint i02 = group.y;
+	const uint i01 = group.x;
+	const uint ne00 = ROW_LENGTH;
+
+	// First pass: copy the data to local buffer, and compute sum
+	{
+		const uint nb01 = src0_strides[ 1 ];
+		const uint nb02 = src0_strides[ 2 ];
+		const uint nb03 = src0_strides[ 3 ];
+		const uint p = i01 * nb01 + i02 * nb02 + i03 * nb03;
+
+		float sum = 0;
+		for( uint i = thread; i < ne00; i += THREADS )
+		{
+			float f = arg0[ p + i ];
+			rowBuffer[ i ] = f;
+			sum += f;
+		}
+		horizontalSum( thread, sum );
+		if( 0 == thread )
+			sharedAccumulators[ 0 ] = sum / (float)(int)ne00;
+		GroupMemoryBarrierWithGroupSync();
+	}
+
+	// Second pass: offset and compute sum of squares
+	{
+		const float mean = sharedAccumulators[ 0 ];
+		// horizontalSum() below overwrites sharedAccumulators[ 0 ] from another thread of the group.
+		// Every thread must have loaded the mean before that happens, otherwise slower threads may read a partial sum instead.
+		GroupMemoryBarrierWithGroupSync();
+
+		float sum2 = 0;
+		for( uint i = thread; i < ne00; i += THREADS )
+		{
+			float v = rowBuffer[ i ];
+			v -= mean;
+			rowBuffer[ i ] = v;
+			sum2 = mad( v, v, sum2 );
+		}
+		horizontalSum( thread, sum2 );
+		if( 0 == thread )
+			sharedAccumulators[ 0 ] = 1.0 / sqrt( sum2 / (float)(int)ne00 + eps );
+		GroupMemoryBarrierWithGroupSync();
+	}
+
+	// Final pass: apply the scale, and copy from group shared buffer to the destination
+	{
+		const float scale = sharedAccumulators[ 0 ];
+
+		const uint nb1 = result_strides[ 1 ];
+		const uint nb2 = result_strides[ 2 ];
+		const uint nb3 = result_strides[ 3 ];
+		const uint y = i01 * nb1 + i02 * nb2 + i03 * nb3;
+
+		for( uint i = thread; i < ne00; i += THREADS )
+		{
+			float v = rowBuffer[ i ];
+			v *= scale;
+			result[ y + i ] = v;
+		}
+	}
+}
\ No newline at end of file diff --git a/ComputeShaders/normFixed64.hlsl b/ComputeShaders/normFixed64.hlsl new file mode 100644 index 0000000..14aab3d --- /dev/null +++ b/ComputeShaders/normFixed64.hlsl @@ -0,0 +1,2 @@
+#define THREADS 64
+#include "normFixed.hlsl" // Same shader source, compiled with the THREADS override above instead of the default found in normFixed.hlsl
\ No newline at end of file diff --git a/ComputeShaders/repeatUtils.hlsli b/ComputeShaders/repeatUtils.hlsli new file mode 100644 index 0000000..1181501 --- /dev/null +++ b/ComputeShaders/repeatUtils.hlsli @@ -0,0 +1,24 @@
+// Offset of a row: the 3 components of idx are multiplied by the strides of dimensions 1, 2 and 3, and summed
+inline uint rowOffset( uint3 idx, uint4 strides )
+{
+	return idx[ 0 ] * strides[ 1 ] + idx[ 1 ] * strides[ 2 ] + idx[ 2 ] * strides[ 3 ];
+}
+
+// Initial iterator state for a row of the output tensor
+// x = current index, y = index increment, z = end of the index
+// THREADS is not defined in this file, the including shader must define it; presumably it is the thread count of the group
+inline uint3 tensorIteratorState( uint3 group, uint thread, uint4 size, uint4 stride )
+{
+	uint3 res;
+	res.x = rowOffset( group, stride );
+	res.y = THREADS * stride[ 0 ];
+	res.z = res.x + size[ 0 ] * stride[ 0 ];
+	res.x += thread * stride[ 0 ];
+	return res;
+}
+
+// Handle a complete row of output tensor, using the iterator made by tensorIteratorState() function
+#define ROW_LOOP( ts ) for( ; ts.x < ts.z; ts.x += ts.y )
+// Same loop with a custom step: the index advances by len * stride[ 0 ] per iteration, instead of the increment stored in ts.y
+// The len and stride arguments are parenthesized, so that expressions such as `a + b` expand with the intended precedence
+#define ROW_LOOP_EX( ts, len, stride ) for( ; ts.x < ts.z; ts.x += ( len ) * ( stride )[ 0 ] )
\ No newline at end of file diff --git a/ComputeShaders/scaleInPlace.hlsl b/ComputeShaders/scaleInPlace.hlsl new file mode 100644 index 0000000..d0320b4 --- /dev/null +++ b/ComputeShaders/scaleInPlace.hlsl @@ -0,0 +1,19 @@
+RWBuffer<float> buffer: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+	uint4 src0_elements: packoffset( c0 );
+	uint4 src0_strides: packoffset( c1 );
+	float multiplier: packoffset( c2.x );
+}
+
+// Multiplies one row of the buffer by `multiplier`, in place.
+// group.x selects the row; the 32 threads of the group interleave over the elements of that row.
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+	const uint rowBegin = group.x * src0_strides[ 1 ];
+	const uint rowEnd = rowBegin + src0_elements[ 0 ];
+	for( uint idx = rowBegin + thread; idx < rowEnd; idx += 32 )
+		buffer[ idx ] = buffer[ idx ] * multiplier;
+}
\ No newline at end of file diff --git a/ComputeShaders/softMax.hlsl b/ComputeShaders/softMax.hlsl new file mode 100644 index 0000000..6ebd0f2 --- /dev/null +++ b/ComputeShaders/softMax.hlsl @@ -0,0 +1,70 @@
+// Dispatch [ nr, 1, 1 ] thread groups of this shader; every group of 32 threads transforms one row, in place
+RWBuffer<float> result: register( u0 );
+
+// table_exp_f16
+Buffer<uint> lookupTable: register( t0 );
+
+cbuffer Constants: register( b0 )
+{
+	uint4 elements: packoffset( c0 );
+	uint4 strides: packoffset( c1 );
+	uint nr: packoffset( c2.x );
+	float inputScale: packoffset( c2.y );
+}
+
+#include "miscUtils.hlsli"
+#include "groupReduce.hlsli"
+
+static const float negativeInfinity = asfloat( 0xff800000 );
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+	const uint rowBegin = group.x * strides[ 1 ];
+	const uint rowEnd = rowBegin + elements[ 0 ];
+	// The first element of the row handled by this thread
+	const uint first = rowBegin + thread;
+	uint i;
+
+	// Pass 1: maximum over the row
+	float m = negativeInfinity;
+	for( i = first; i < rowEnd; i += 32 )
+		m = max( m, result[ i ] );
+	horizontalMaxBroadcast( thread, m );
+
+	// Pass 2: replace the elements with scaled exponents, and accumulate their sum; negative infinity becomes zero
+	float sum = 0;
+	for( i = first; i < rowEnd; i += 32 )
+	{
+		const float source = result[ i ];
+		float e = 0;
+
+		[branch]
+		if( source != negativeInfinity )
+		{
+			e = ( source - m ) * inputScale;
+#if 1
+			// Computing the exponent is slightly faster than the lookup table, both on Radeon Graphics and on nVidia 1080Ti
+			e = exp( e );
+#else
+			uint s = fp16Rounded( e );
+			s = lookupTable[ s ];
+			e = f16tof32( s );
+#endif
+			sum += e;
+		}
+
+		result[ i ] = e;
+	}
+
+	// Pass 3: thread #0 publishes the reciprocal of the total through the shared buffer, then everyone scales their elements
+	horizontalSum( thread, sum );
+	if( 0 == thread )
+		sharedAccumulators[ 0 ] = 1.0 / sum;
+	GroupMemoryBarrierWithGroupSync();
+	const float scale = sharedAccumulators[ 0 ];
+
+	// ggml_vec_scale_f32
+	for( i = first; i < rowEnd; i += 32 )
+		result[ i ] = result[ i ] * scale;
+}
\ No newline at end of file diff --git a/ComputeShaders/softMax64.hlsl b/ComputeShaders/softMax64.hlsl new file mode 100644 index 0000000..7ecd2ef --- /dev/null +++ b/ComputeShaders/softMax64.hlsl @@ -0,0 +1,70 @@
+// Dispatch [ nr, 1, 1 ] thread groups of this shader; every group of 64 threads transforms one row, in place
+RWBuffer<float> result: register( u0 );
+
+// table_exp_f16
+Buffer<uint> lookupTable: register( t0 );
+
+cbuffer Constants: register( b0 )
+{
+	uint4 elements: packoffset( c0 );
+	uint4 strides: packoffset( c1 );
+	uint nr: packoffset( c2.x );
+	float inputScale: packoffset( c2.y );
+}
+
+#include "miscUtils.hlsli"
+#include "groupReduce64.hlsli"
+
+static const float negativeInfinity = asfloat( 0xff800000 );
+
+[ numthreads( 64, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+	const uint rowBegin = group.x * strides[ 1 ];
+	const uint rowEnd = rowBegin + elements[ 0 ];
+	// The first element of the row handled by this thread
+	const uint first = rowBegin + thread;
+	uint i;
+
+	// Pass 1: maximum over the row
+	float m = negativeInfinity;
+	for( i = first; i < rowEnd; i += 64 )
+		m = max( m, result[ i ] );
+	horizontalMaxBroadcast( thread, m );
+
+	// Pass 2: replace the elements with scaled exponents, and accumulate their sum; negative infinity becomes zero
+	float sum = 0;
+	for( i = first; i < rowEnd; i += 64 )
+	{
+		const float source = result[ i ];
+		float e = 0;
+
+		[branch]
+		if( source != negativeInfinity )
+		{
+			e = ( source - m ) * inputScale;
+#if 1
+			// At least on Radeon Graphics GPU inside Ryzen 7 5700G, computing the exponent is faster than loading it from the lookup table
+			e = exp( e );
+#else
+			uint s = fp16Rounded( e );
+			s = lookupTable[ s ];
+			e = f16tof32( s );
+#endif
+			sum += e;
+		}
+
+		result[ i ] = e;
+	}
+
+	// Pass 3: thread #0 publishes the reciprocal of the total through the shared buffer, then everyone scales their elements
+	horizontalSum( thread, sum );
+	if( 0 == thread )
+		sharedAccumulators[ 0 ] = 1.0 / sum;
+	GroupMemoryBarrierWithGroupSync();
+	const float scale = sharedAccumulators[ 0 ];
+
+	// ggml_vec_scale_f32
+	for( i = first; i < rowEnd; i += 64 )
+		result[ i ] = result[ i ] * scale;
+}
\ No newline at end of file diff --git a/ComputeShaders/softMaxCompat.hlsl b/ComputeShaders/softMaxCompat.hlsl new file mode 100644 index 0000000..2215ebd --- /dev/null +++ b/ComputeShaders/softMaxCompat.hlsl @@ -0,0 +1,60 @@
+// ggml_compute_forward_soft_max_f32
+// Dispatch [ ( nr + 31 ) / 32, 1, 1 ] thread groups of this shader; every thread transforms one complete row, in place
+RWBuffer<float> result: register( u0 );
+
+// table_exp_f16
+Buffer<uint> lookupTable: register( t0 );
+
+cbuffer Constants: register( b0 )
+{
+	uint4 elements: packoffset( c0 );
+	uint4 strides: packoffset( c1 );
+	uint nr: packoffset( c2.x );
+}
+
+#include "miscUtils.hlsli"
+#include "fp64Utils.hlsli"
+
+static const float negativeInfinity = asfloat( 0xff800000 );
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 dtid: SV_DispatchThreadID )
+{
+	const uint row = dtid.x;
+	// The dispatch rounds the row count up to a multiple of 32, the extra threads have nothing to do
+	if( row >= nr )
+		return;
+
+	const uint rowBegin = row * strides[ 1 ];
+	const uint rowEnd = rowBegin + elements[ 0 ];
+	uint i;
+
+	// Maximum over the row
+	float m = negativeInfinity;
+	for( i = rowBegin; i < rowEnd; i++ )
+		m = max( m, result[ i ] );
+
+	// Replace the elements with the values from the lookup table, and accumulate their FP64 sum; negative infinity becomes zero
+	double sum = 0;
+	for( i = rowBegin; i < rowEnd; i++ )
+	{
+		const float source = result[ i ];
+		float e = 0;
+
+		[branch]
+		if( source != negativeInfinity )
+		{
+			uint s = fp16Rounded( source - m );
+			s = lookupTable[ s ];
+			e = f16tof32( s );
+			sum += e;
+		}
+
+		result[ i ] = e;
+	}
+
+	// ggml_vec_scale_f32
+	const float scale = (float)div64( 1.0, sum );
+	for( i = rowBegin; i < rowEnd; i++ )
+		result[ i ] = result[ i ] * scale;
+}
\ No newline at end of file diff --git a/ComputeShaders/softMaxFixed.hlsl b/ComputeShaders/softMaxFixed.hlsl new file mode 100644 index 0000000..7b4add2 --- /dev/null +++ b/ComputeShaders/softMaxFixed.hlsl @@ -0,0 +1,79 @@
// Special softMax shader for matrices with rows of 1500 elements.
// Uses group shared buffer of that length to save global memory bandwidth, more than 2x faster than the original.
// The row length is hard-coded: elements[ 0 ] of the constant buffer is never read here,
// so the calling code is responsible for only using this shader when the rows have exactly ROW_LENGTH elements.
// Dispatch [ nr, 1, 1 ] thread groups of this shader
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	// Not read by this shader, see the comment on the top of the file
	uint4 elements: packoffset( c0 );
	// strides[ 1 ] is the distance, in elements, between the starts of two consecutive rows
	uint4 strides: packoffset( c1 );
	// Not read by this shader
	uint nr: packoffset( c2.x );
	// Multiplier applied to ( value - row maximum ) before taking the exponent
	float inputScale: packoffset( c2.y );
}

#include "miscUtils.hlsli"
#include "groupReduce64.hlsli"

static const uint THREADS = 64;
static const uint ROW_LENGTH = 1500;
// Complete row of the matrix. Every thread only accesses elements [ thread, thread + THREADS, .. ] in all 3 passes,
// i.e. a thread never reads an element written by another thread, and that's why no barriers are used around this buffer.
groupshared float rowBuffer[ ROW_LENGTH ];

static const float negativeInfinity = asfloat( 0xff800000 );

[ numthreads( THREADS, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	const uint p = group.x * strides[ 1 ];
	const uint nc = ROW_LENGTH;
	uint i;

	float m = negativeInfinity;
	// First pass: compute maximum, and copy the row into the group shared buffer
	for( i = thread; i < nc; i += THREADS )
	{
		float f = result[ p + i ];
		m = max( m, f );
		rowBuffer[ i ] = f;
	}
	// NOTE(review): defined in groupReduce64.hlsli; presumably leaves the group-wide maximum in m on every thread - confirm
	horizontalMaxBroadcast( thread, m );

	// Second pass: apply initial scale, compute the exponent, and compute total sum over the row
	float sum = 0;
	for( i = thread; i < nc; i += THREADS )
	{
		float f = rowBuffer[ i ];

		[branch]
		if( f != negativeInfinity )
		{
			f = ( f - m ) * inputScale;
			// At least on Radeon Graphics GPU inside Ryzen 7 5700G, computing exponent instead of loading from the buffer improves the performance.
			// The alternative lookup table code path was removed from this shader: it was reading lookupTable buffer which is not declared here.
			f = exp( f );
			sum += f;
		}
		else
			f = 0;

		rowBuffer[ i ] = f;
	}

	// Thread 0 publishes the reciprocal of the reduced sum, the barrier makes it visible to the rest of the group
	horizontalSum( thread, sum );
	if( 0 == thread )
		sharedAccumulators[ 0 ] = 1.0 / sum;
	GroupMemoryBarrierWithGroupSync();
	const float scale = sharedAccumulators[ 0 ];

	// Final pass: apply the final scale, and copy the row from the group shared buffer back into the global memory
	for( i = thread; i < nc; i += THREADS )
	{
		float f = rowBuffer[ i ];
		f *= scale;
		result[ p + i ] = f;
	}
}
\ No newline at end of file diff --git a/ComputeShaders/zeroMemory.hlsl b/ComputeShaders/zeroMemory.hlsl new file mode 100644 index 0000000..c486636 --- /dev/null +++ b/ComputeShaders/zeroMemory.hlsl @@ -0,0 +1,27 @@
// Fills the first `elements` entries of the buffer with zeros
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	// Count of elements to write
	uint elements: packoffset( c0.x );
}

// Thread group index is 16 bits per coordinate:
// https://learn.microsoft.com/en-us/windows/win32/api/d3d11/nf-d3d11-id3d11devicecontext-dispatch
// We want this shader to support buffers up to 2 GB.
// That's why every thread writes ITERATIONS elements instead of just one.
#ifndef THREADS
static const uint THREADS = 512;
#endif
#ifndef ITERATIONS
static const uint ITERATIONS = 128;
#endif

// Count of elements handled by a single thread group
static const uint itemsPerGroup = THREADS * ITERATIONS;

[numthreads( THREADS, 1, 1 )]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	// The slice of the buffer owned by this group; the last group is clipped to the length of the buffer
	const uint groupBegin = group.x * itemsPerGroup;
	const uint groupEnd = min( groupBegin + itemsPerGroup, elements );

	// On every iteration of the loop, the threads of the group write THREADS consecutive elements
	for( uint idx = groupBegin + thread; idx < groupEnd; idx += THREADS )
		result[ idx ] = 0.0;
}
\ No newline at end of file diff --git a/Examples/MicrophoneCS/CaptureThread.cs b/Examples/MicrophoneCS/CaptureThread.cs new file mode 100644 index 0000000..b76a929 --- /dev/null +++ b/Examples/MicrophoneCS/CaptureThread.cs @@ -0,0 +1,61 @@
using System.Runtime.ExceptionServices;
using Whisper;

namespace MicrophoneCS
{
	/// <summary>Runs the audio capture on a dedicated thread, until the user presses a key or the capture fails.</summary>
	sealed class CaptureThread: CaptureCallbacks
	{
		readonly TranscribeCallbacks callbacks;
		readonly Thread thread;
		readonly Context context;
		readonly iAudioCapture source;

		/// <summary>Set by the thread pool thread which waits for the key, polled through <see cref="shouldCancel" /></summary>
		volatile bool shouldQuit = false;

		/// <summary>Exception captured on the capture thread, if any; re-thrown by <see cref="join" /></summary>
		ExceptionDispatchInfo? edi = null;

		/// <summary>Create the callbacks, and launch the capture thread right away</summary>
		public CaptureThread( CommandLineArgs args, Context context, iAudioCapture source )
		{
			callbacks = new TranscribeCallbacks( args );
			this.context = context;
			this.source = source;

			thread = new Thread( threadMain ) { Name = "Capture Thread" };
			Console.WriteLine( "Press any key to quit" );
			thread.Start();
		}

		/// <summary>Thread pool callback: block until a key is pressed, then request the capture to stop</summary>
		static void readKeyCallback( object? state )
		{
			if( state is not CaptureThread self )
				throw new ApplicationException();
			Console.ReadKey();
			self.shouldQuit = true;
		}

		/// <summary>Wait for the capture thread to exit; if it has failed, re-throw the exception on the calling thread</summary>
		public void join()
		{
			// The key is awaited on a thread pool thread, to keep the calling thread free to join the capture thread
			ThreadPool.QueueUserWorkItem( readKeyCallback, this );
			thread.Join();
			if( edi is not null )
				edi.Throw();
		}

		protected override bool shouldCancel( Context sender )
		{
			return shouldQuit;
		}

		protected override void captureStatusChanged( Context sender, eCaptureStatus status )
		{
			Console.WriteLine( $"CaptureStatusChanged: {status}" );
		}

		/// <summary>Entry point of the capture thread</summary>
		void threadMain()
		{
			try
			{
				context.runCapture( source, callbacks, this );
			}
			catch( Exception ex )
			{
				// Keep the exception along with the original stack trace, join() re-throws it
				edi = ExceptionDispatchInfo.Capture( ex );
			}
		}
	}
}
\ No newline at end of file diff --git a/Examples/MicrophoneCS/CommandLineArgs.cs b/Examples/MicrophoneCS/CommandLineArgs.cs new file mode 100644 index 0000000..be5fbe9 --- /dev/null +++ b/Examples/MicrophoneCS/CommandLineArgs.cs @@ -0,0 +1,145 @@
using System.Globalization;
using System.Reflection;
using Whisper;

namespace MicrophoneCS
{
	/// <summary>Command-line arguments of the microphone example</summary>
	sealed record class CommandLineArgs
	{
		public int n_threads = Environment.ProcessorCount;
		public int offset_t_ms = 0;
		public int offset_n = 0;
		public int duration_ms = 0;
		public int max_context = -1;
		public int max_len = 0;

		public float word_thold = 0.01f;

		public bool speed_up = false;
		public bool translate = false;
		public bool diarize = false;
		public bool output_txt = false;
		public bool print_special = false;
		public bool print_progress = false;
		public bool print_colors = true;
		public bool no_timestamps = false;
		public int[]? prompt = null;
		/// <summary>Zero-based index of the audio capture device to use</summary>
		public int captureDeviceIndex = 0;

		public eLanguage language = eLanguage.English;
		/// <summary>Path to the model file</summary>
		public string model = string.Empty;

		const bool output_wts = false;
		/// <summary>True to print the list of audio capture devices instead of transcribing</summary>
		public bool listDevices = false;

		/// <summary>Copy these arguments into the parameters structure of the library</summary>
		public void apply( ref Parameters p )
		{
			p.setFlag( eFullParamsFlags.PrintRealtime, false );
			p.setFlag( eFullParamsFlags.PrintProgress, print_progress );
			p.setFlag( eFullParamsFlags.PrintTimestamps, !no_timestamps );
			p.setFlag( eFullParamsFlags.PrintSpecial, print_special );
			p.setFlag( eFullParamsFlags.Translate, translate );
			p.language = language;
			p.cpuThreads = n_threads;
			// A negative value means "not specified", in this case the field is left alone
			if( max_context >= 0 )
				p.n_max_text_ctx = max_context;
			p.offset_ms = offset_t_ms;
			p.duration_ms = duration_ms;
			p.setFlag( eFullParamsFlags.TokenTimestamps, output_wts || max_len > 0 );
			p.thold_pt = word_thold;
			p.max_len = output_wts && max_len == 0 ? 60 : max_len;
			p.setFlag( eFullParamsFlags.SpeedupAudio, speed_up );
		}

		/// <summary>Compute the flags for <c>Context.results</c> method: which pieces of the output are needed for printing</summary>
		public eResultFlags resultFlags()
		{
			eResultFlags flags = eResultFlags.None;
			bool wts = output_wts || max_len > 0;
			if( !no_timestamps || wts )
				flags |= eResultFlags.Timestamps;
			if( wts || print_colors )
				flags |= eResultFlags.Tokens;
			return flags;
		}

		static eLanguage parseLanguage( string lang ) =>
			Library.languageFromCode( lang ) ?? throw new ArgumentException( $"Unknown language code \"{lang}\"" );

		/// <summary>Parse the command line</summary>
		/// <exception cref="OperationCanceledException">The help was requested; the usage message is already printed.</exception>
		/// <exception cref="ArgumentException">Unknown switch, a switch without the required value, or the model is not specified.</exception>
		/// <exception cref="FileNotFoundException">The model file doesn't exist.</exception>
		public CommandLineArgs( string[] argv )
		{
			for( int i = 0; i < argv.Length; i++ )
			{
				string arg = argv[ i ];

				// Consume the value which follows the current switch.
				// Fails with a readable message instead of IndexOutOfRangeException when the switch is the last argument.
				string value()
				{
					if( i + 1 >= argv.Length )
						throw new ArgumentException( $"The argument \"{arg}\" requires a value" );
					i++;
					return argv[ i ];
				}
				// Integers are parsed with the invariant culture, same as the floats
				int intValue() => int.Parse( value(), CultureInfo.InvariantCulture );

				switch( arg )
				{
					case "-h":
					case "--help":
						printUsage();
						throw new OperationCanceledException();
					case "-c": case "--capture": captureDeviceIndex = intValue(); break;
					case "-ld": case "--list-devices": listDevices = true; break;
					case "-t": case "--threads": n_threads = intValue(); break;
					case "-ot": case "--offset-t": offset_t_ms = intValue(); break;
					case "-on": case "--offset-n": offset_n = intValue(); break;
					case "-d": case "--duration": duration_ms = intValue(); break;
					case "-mc": case "--max-context": max_context = intValue(); break;
					case "-ml": case "--max-len": max_len = intValue(); break;
					case "-wt": case "--word-thold": word_thold = float.Parse( value(), CultureInfo.InvariantCulture ); break;
					case "-su": case "--speed-up": speed_up = true; break;
					case "-tr": case "--translate": translate = true; break;
					case "-di": case "--diarize": diarize = true; break;
					case "-otxt": case "--output-txt": output_txt = true; break;
					case "-ps": case "--print-special": print_special = true; break;
					case "-nc": case "--no-colors": print_colors = false; break;
					case "-pp": case "--print-progress": print_progress = true; break;
					case "-nt": case "--no-timestamps": no_timestamps = true; break;
					case "-l": case "--language": language = parseLanguage( value() ); break;
					case "--prompt": prompt = parsePrompt( value() ); break;
					case "-m": case "--model": model = value(); break;
					default:
						throw new ArgumentException( $"Unknown argument: \"{arg}\"" );
				}
			}

			// The program lists the capture devices before it loads any model, no need to require one
			if( listDevices )
				return;

			if( string.IsNullOrWhiteSpace( model ) )
				throw new ArgumentException( "The model file is not provided in the arguments" );
			if( !File.Exists( model ) )
				throw new FileNotFoundException( "Model not found", model );
		}

		static string cstr( bool b ) => b.ToString();

		static int[]? parsePrompt( string str )
		{
			if( string.IsNullOrWhiteSpace( str ) )
				return null;
			// TODO: expose whisper_tokenize function, as a method of iModel COM interface
			throw new NotImplementedException();
		}

		/// <summary>Print the usage message; the values in the square brackets are the current values of the fields</summary>
		void printUsage()
		{
			Console.WriteLine();

			// This example captures audio from a device, it doesn't accept any input files
			Console.WriteLine( "usage: {0} [options]", Path.GetFileName( Assembly.GetExecutingAssembly().Location ) );
			Console.WriteLine();
			Console.WriteLine( "options:" );
			Console.WriteLine( " -h, --help [default] show this help message and exit" );
			Console.WriteLine( " -c N, --capture N [{0,-7:D}] index of the audio capture device to use", captureDeviceIndex );
			Console.WriteLine( " -ld, --list-devices [{0,-7}] list audio capture devices and exit", cstr( listDevices ) );
			Console.WriteLine( " -t N, --threads N [{0,-7:D}] number of threads to use during computation", n_threads );
			Console.WriteLine( " -ot N, --offset-t N [{0,-7:D}] time offset in milliseconds", offset_t_ms );
			Console.WriteLine( " -on N, --offset-n N [{0,-7:D}] segment index offset", offset_n );
			Console.WriteLine( " -d N, --duration N [{0,-7:D}] duration of audio to process in milliseconds", duration_ms );
			Console.WriteLine( " -mc N, --max-context N [{0,-7:D}] maximum number of text context tokens to store", max_context );
			Console.WriteLine( " -ml N, --max-len N [{0,-7:D}] maximum segment length in characters", max_len );
			Console.WriteLine( " -wt N, --word-thold N [{0,-7:F2}] word timestamp probability threshold", word_thold );
			Console.WriteLine( " -su, --speed-up [{0,-7}] speed up audio by x2 (reduced accuracy)", cstr( speed_up ) );
			Console.WriteLine( " -tr, --translate [{0,-7}] translate from source language to english", cstr( translate ) );
			Console.WriteLine( " -di, --diarize [{0,-7}] stereo audio diarization", cstr( diarize ) );
			Console.WriteLine( " -otxt, --output-txt [{0,-7}] output result in a text file", cstr( output_txt ) );
			Console.WriteLine( " -ps, --print-special [{0,-7}] print special tokens", cstr( print_special ) );
			Console.WriteLine( " -nc, --no-colors [{0,-7}] do not print colors", cstr( !print_colors ) );
			Console.WriteLine( " -pp, --print-progress [{0,-7}] print progress", cstr( print_progress ) );
			Console.WriteLine( " -nt, --no-timestamps [{0,-7}] do not print timestamps", cstr( no_timestamps ) );
			Console.WriteLine( " -l LANG, --language LANG [{0,-7}] spoken language", language.getCode() );
			Console.WriteLine( " --prompt PROMPT [ ] initial prompt" );
			Console.WriteLine( " -m FNAME, --model FNAME [{0,-7}] model path", model );
		}
	}
}
\ No newline at end of file diff --git a/Examples/MicrophoneCS/MicrophoneCS.cs b/Examples/MicrophoneCS/MicrophoneCS.cs new file mode 100644 index 0000000..c095ee1 --- /dev/null +++ b/Examples/MicrophoneCS/MicrophoneCS.cs @@ -0,0 +1,56 @@
using Whisper;

namespace MicrophoneCS
{
	static class Program
	{
		/// <summary>Entry point: run the example, and translate unhandled exceptions into the exit code of the process</summary>
		static int Main( string[] args )
		{
			try
			{
				return run( args );
			}
			catch( Exception ex )
			{
				// Print the complete exception as opposed to just ex.Message
				Console.WriteLine( ex.ToString() );
				return ex.HResult;
			}
		}

		/// <summary>Parse the arguments, then either list the capture devices, or transcribe the audio captured from the selected one</summary>
		/// <returns>Exit code of the process: 1 when the help was requested, 0 otherwise</returns>
		static int run( string[] args )
		{
			CommandLineArgs cla;
			try
			{
				cla = new CommandLineArgs( args );
			}
			catch( OperationCanceledException )
			{
				// The usage message is already printed
				return 1;
			}

			const eLoggerFlags loggerFlags = eLoggerFlags.UseStandardError | eLoggerFlags.SkipFormatMessage;
			Library.setLogSink( eLogLevel.Debug, loggerFlags );

			using iMediaFoundation mf = Library.initMediaFoundation();
			CaptureDeviceId[] devices = mf.listCaptureDevices() ??
				throw new ApplicationException( "This computer has no audio capture devices" );

			if( cla.listDevices )
			{
				int index = 0;
				foreach( CaptureDeviceId device in devices )
				{
					Console.WriteLine( "#{0}: {1}", index, device.displayName );
					index++;
				}
				return 0;
			}

			int selected = cla.captureDeviceIndex;
			bool validIndex = selected >= 0 && selected < devices.Length;
			if( !validIndex )
				throw new ApplicationException( $"Capture device index is out of range; the valid range is [ 0 .. {devices.Length - 1} ]" );

			using iAudioCapture captureDev = mf.openCaptureDevice( devices[ selected ] );

			using iModel model = Library.loadModel( cla.model );
			using Context context = model.createContext();
			cla.apply( ref context.parameters );

			// The constructor starts the capture, join() blocks until it's stopped or failed
			CaptureThread thread = new CaptureThread( cla, context, captureDev );
			thread.join();

			context.timingsPrint();
			return 0;
		}
	}
}
\ No newline at end of file diff --git a/Examples/MicrophoneCS/MicrophoneCS.csproj b/Examples/MicrophoneCS/MicrophoneCS.csproj new file mode 100644 index 0000000..f417d20 --- /dev/null +++ b/Examples/MicrophoneCS/MicrophoneCS.csproj @@ -0,0 +1,27 @@
<!-- Console example which transcribes the audio captured from a microphone -->
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net6.0-windows</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
    <!-- Integer arithmetic and conversions are overflow-checked by default -->
    <CheckForOverflowUnderflow>true</CheckForOverflowUnderflow>
    <!-- Don't create a sub-folder named after the target framework in the output directory -->
    <AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath>
    <!-- The native Whisper.dll referenced below comes from the x64 build output, the only platform offered here -->
    <Platforms>x64</Platforms>
  </PropertyGroup>

  <ItemGroup>
    <!-- Source file shared with the TranscribeCS example; linked, not copied -->
    <Compile Include="..\TranscribeCS\AnsiCodes.cs" Link="AnsiCodes.cs" />
  </ItemGroup>

  <ItemGroup>
    <!-- Native DLL from the solution's x64 output of the same configuration, deployed next to the executable -->
    <Content Include="..\..\x64\$(Configuration)\Whisper.dll" Link="Whisper.dll">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </Content>
  </ItemGroup>

  <ItemGroup>
    <!-- The managed wrapper, it provides the Whisper namespace used by the example -->
    <ProjectReference Include="..\..\WhisperNet\WhisperNet.csproj" />
  </ItemGroup>

</Project>
\ No newline at end of file diff --git a/Examples/MicrophoneCS/TranscribeCallbacks.cs b/Examples/MicrophoneCS/TranscribeCallbacks.cs new file mode 100644 index 0000000..e4d14f4 --- /dev/null +++ b/Examples/MicrophoneCS/TranscribeCallbacks.cs @@ -0,0 +1,114 @@
using System.Globalization;
using Whisper;

namespace MicrophoneCS
{
	/// <summary>Implementation of Callbacks abstract class, to print these segments as soon as they’re produced by the library.</summary>
	sealed class TranscribeCallbacks: Callbacks
	{
		/// <summary>Escape sequence which resets the color of the terminal</summary>
		const string resetColor = "\x1B[0m";

		readonly CommandLineArgs args;
		readonly eResultFlags resultFlags;

		// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
		// Lowest is red, middle is yellow, highest is green.
		readonly string[] k_colors = new string[]
		{
			"\x1B[38;5;196m", "\x1B[38;5;202m", "\x1B[38;5;208m", "\x1B[38;5;214m", "\x1B[38;5;220m",
			"\x1B[38;5;226m", "\x1B[38;5;190m", "\x1B[38;5;154m", "\x1B[38;5;118m", "\x1B[38;5;82m"
		};

		/// <summary>Cache the arguments and the result flags computed from them, and switch the console output to UTF-8</summary>
		public TranscribeCallbacks( CommandLineArgs args )
		{
			this.args = args;
			resultFlags = args.resultFlags();
			Console.OutputEncoding = System.Text.Encoding.UTF8;
		}

		/// <summary>Index in the <see cref="k_colors" /> array for the token: the cube of the probability, scaled to the length of the array and clamped</summary>
		int colorIndex( in sToken tok )
		{
			float prob = tok.probability;
			float cubed = prob * prob * prob;
			int scaled = (int)( cubed * k_colors.Length );
			return Math.Clamp( scaled, 0, k_colors.Length - 1 );
		}

		/// <summary>Format the time as <c>hh:mm:ss.fff</c></summary>
		public static string printTime( TimeSpan ts )
		{
			return ts.ToString( "hh':'mm':'ss'.'fff", CultureInfo.InvariantCulture );
		}

		/// <summary>Format the time as <c>hh:mm:ss,fff</c></summary>
		public static string printTimeWithComma( TimeSpan ts )
		{
			return ts.ToString( "hh':'mm':'ss','fff", CultureInfo.InvariantCulture );
		}

		/// <summary>Print the last <paramref name="countNew" /> segments of the results</summary>
		protected override void onNewSegment( Context sender, int countNew )
		{
			TranscribeResult res = sender.results( resultFlags );
			ReadOnlySpan<sToken> tokens = res.tokens;

			int firstNew = res.segments.Length - countNew;
			// An empty line before the very first segment
			if( 0 == firstNew )
				Console.WriteLine();

			for( int i = firstNew; i < res.segments.Length; i++ )
			{
				sSegment seg = res.segments[ i ];
				bool colored = args.print_colors && AnsiCodes.enabled;

				if( args.no_timestamps )
				{
					// Without timestamps the text is printed as a continuous stream, with no line breaks between the segments
					if( !colored )
						Console.Write( seg.text );
					else
					{
						foreach( sToken tok in res.getTokens( seg ) )
						{
							bool skip = !args.print_special && tok.hasFlag( eTokenFlags.Special );
							if( skip )
								continue;
							Console.Write( $"{k_colors[ colorIndex( tok ) ]}{tok.text}{resetColor}" );
						}
					}
					Console.Out.Flush();
					continue;
				}

				// The diarization below is not ported from C++, the string stays empty
				string speaker = "";
#if false
				if( args.diarize && pcmf32s.size() == 2 )
				{
					const size_t n_samples = pcmf32s[ 0 ].size();
					const int64_t is0 = SourceAudio::sampleFromTimestamp( seg.time.begin, n_samples );
					const int64_t is1 = SourceAudio::sampleFromTimestamp( seg.time.end, n_samples );

					double energy0 = 0.0f;
					double energy1 = 0.0f;

					for( int64_t j = is0; j < is1; j++ )
					{
						energy0 += fabs( pcmf32s[ 0 ][ j ] );
						energy1 += fabs( pcmf32s[ 1 ][ j ] );
					}

					if( energy0 > 1.1 * energy1 )
						speaker = "(speaker 0)";
					else if( energy1 > 1.1 * energy0 )
						speaker = "(speaker 1)";
					else
						speaker = "(speaker ?)";

					//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
				}
#endif
				string begin = printTime( seg.time.begin );
				string end = printTime( seg.time.end );

				if( !colored )
				{
					Console.WriteLine( $"[{begin} --> {end}] {speaker}{seg.text}" );
					continue;
				}

				// Colored output: the time interval first, then every token in the color of its probability
				Console.Write( $"[{begin} --> {end}] " );
				foreach( sToken tok in res.getTokens( seg ) )
				{
					bool skip = !args.print_special && tok.hasFlag( eTokenFlags.Special );
					if( skip )
						continue;
					Console.Write( $"{speaker}{k_colors[ colorIndex( tok ) ]}{tok.text}{resetColor}" );
				}
				Console.WriteLine();
			}
		}
	}
}
\ No newline at end of file diff --git a/Examples/OldMain/OldMain.vcxproj b/Examples/OldMain/OldMain.vcxproj new file mode 100644 index 0000000..26f2c71 --- /dev/null +++ b/Examples/OldMain/OldMain.vcxproj @@ -0,0 +1,101 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Console application built from main.cpp and the sources in Whisper\source; x64 only -->
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <ItemGroup Label="ProjectConfigurations">
    <ProjectConfiguration Include="Debug|x64">
      <Configuration>Debug</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Release|x64">
      <Configuration>Release</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
  </ItemGroup>
  <PropertyGroup Label="Globals">
    <VCProjectVersion>16.0</VCProjectVersion>
    <Keyword>Win32Proj</Keyword>
    <ProjectGuid>{596f9770-9aeb-49d3-86ca-4200197df12b}</ProjectGuid>
    <RootNamespace>OldMain</RootNamespace>
    <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
  <!-- Both configurations build an executable with the v143 toolset; Release adds whole program optimization -->
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
    <ConfigurationType>Application</ConfigurationType>
    <UseDebugLibraries>true</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
    <ConfigurationType>Application</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
  <ImportGroup Label="ExtensionSettings">
  </ImportGroup>
  <ImportGroup Label="Shared">
  </ImportGroup>
  <!-- Optional per-user property sheets, only imported when the file exists -->
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <PropertyGroup Label="UserMacros" />
  <!-- Headers of the Whisper project are resolved through the include path, same value for both configurations -->
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <IncludePath>$(SolutionDir)Whisper\Source\;$(IncludePath)</IncludePath>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <IncludePath>$(SolutionDir)Whisper\Source\;$(IncludePath)</IncludePath>
  </PropertyGroup>
  <!-- Compiler and linker settings; both configurations require AVX2 and compile as C++20 -->
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
      <EnableEnhancedInstructionSet>AdvancedVectorExtensions2</EnableEnhancedInstructionSet>
      <LanguageStandard>stdcpp20</LanguageStandard>
    </ClCompile>
    <Link>
      <SubSystem>Console</SubSystem>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
      <EnableEnhancedInstructionSet>AdvancedVectorExtensions2</EnableEnhancedInstructionSet>
      <LanguageStandard>stdcpp20</LanguageStandard>
    </ClCompile>
    <Link>
      <SubSystem>Console</SubSystem>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <GenerateDebugInformation>true</GenerateDebugInformation>
      <LinkTimeCodeGeneration>UseLinkTimeCodeGeneration</LinkTimeCodeGeneration>
    </Link>
  </ItemDefinitionGroup>
  <ItemGroup>
    <!-- ggml.c is listed in the project but never compiled on its own. -->
    <!-- NOTE(review): presumably ggmlMsvc.c builds it instead; confirm in that file -->
    <ClCompile Include="..\..\Whisper\source\ggml.c">
      <ExcludedFromBuild>true</ExcludedFromBuild>
    </ClCompile>
    <ClCompile Include="..\..\Whisper\source\ggmlMsvc.c" />
    <ClCompile Include="..\..\Whisper\source\whisper.cpp" />
    <ClCompile Include="main.cpp" />
  </ItemGroup>
  <ItemGroup>
    <ClInclude Include="..\..\Whisper\source\ggml.h" />
    <ClInclude Include="..\..\Whisper\source\whisper.h" />
    <ClInclude Include="dr_wav.h" />
  </ItemGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
  <ImportGroup Label="ExtensionTargets">
  </ImportGroup>
</Project>
\ No newline at end of file diff --git a/Examples/OldMain/OldMain.vcxproj.filters b/Examples/OldMain/OldMain.vcxproj.filters new file mode 100644 index 0000000..78f29f0 --- /dev/null +++ b/Examples/OldMain/OldMain.vcxproj.filters @@ -0,0 +1,14 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Companion of OldMain.vcxproj; no filters are defined, the file only lists the sources and headers of the project -->
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <ItemGroup>
    <ClCompile Include="..\..\Whisper\source\ggmlMsvc.c" />
    <ClCompile Include="..\..\Whisper\source\whisper.cpp" />
    <ClCompile Include="main.cpp" />
    <ClCompile Include="..\..\Whisper\source\ggml.c" />
  </ItemGroup>
  <ItemGroup>
    <ClInclude Include="..\..\Whisper\source\ggml.h" />
    <ClInclude Include="..\..\Whisper\source\whisper.h" />
    <ClInclude Include="dr_wav.h" />
  </ItemGroup>
</Project>
\ No newline at end of file diff --git a/Examples/OldMain/dr_wav.h b/Examples/OldMain/dr_wav.h new file mode 100644 index 0000000..fd3e95b --- /dev/null +++ b/Examples/OldMain/dr_wav.h @@ -0,0 +1,6434 @@ +/* +WAV audio loader and writer. Choice of public domain or MIT-0. See license statements at the end of this file. +dr_wav - v0.12.16 - 2020-12-02 + +David Reid - mackron@gmail.com + +GitHub: https://github.com/mackron/dr_libs +*/ + +/* +RELEASE NOTES - VERSION 0.12 +============================ +Version 0.12 includes breaking changes to custom chunk handling. + + +Changes to Chunk Callback +------------------------- +dr_wav supports the ability to fire a callback when a chunk is encountered (except for WAVE and FMT chunks). The callback has been updated to include both the +container (RIFF or Wave64) and the FMT chunk which contains information about the format of the data in the wave file. + +Previously, there was no direct way to determine the container, and therefore no way to discriminate against the different IDs in the chunk header (RIFF and +Wave64 containers encode chunk IDs differently). The `container` parameter can be used to know which ID to use. + +Sometimes it can be useful to know the data format at the time the chunk callback is fired. A pointer to a `drwav_fmt` object is now passed into the chunk +callback which will give you information about the data format. To determine the sample format, use `drwav_fmt_get_format()`. This will return one of the +`DR_WAVE_FORMAT_*` tokens. +*/ + +/* +Introduction +============ +This is a single file library. To use it, do something like the following in one .c file. + + ```c + #define DR_WAV_IMPLEMENTATION + #include "dr_wav.h" + ``` + +You can then #include this file in other parts of the program as you would with any other header file. Do something like the following to read audio data: + + ```c + drwav wav; + if (!drwav_init_file(&wav, "my_song.wav", NULL)) { + // Error opening WAV file. 
+ } + + drwav_int32* pDecodedInterleavedPCMFrames = malloc(wav.totalPCMFrameCount * wav.channels * sizeof(drwav_int32)); + size_t numberOfSamplesActuallyDecoded = drwav_read_pcm_frames_s32(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames); + + ... + + drwav_uninit(&wav); + ``` + +If you just want to quickly open and read the audio data in a single operation you can do something like this: + + ```c + unsigned int channels; + unsigned int sampleRate; + drwav_uint64 totalPCMFrameCount; + float* pSampleData = drwav_open_file_and_read_pcm_frames_f32("my_song.wav", &channels, &sampleRate, &totalPCMFrameCount, NULL); + if (pSampleData == NULL) { + // Error opening and reading WAV file. + } + + ... + + drwav_free(pSampleData); + ``` + +The examples above use versions of the API that convert the audio data to a consistent format (32-bit signed PCM, in this case), but you can still output the +audio data in its internal format (see notes below for supported formats): + + ```c + size_t framesRead = drwav_read_pcm_frames(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames); + ``` + +You can also read the raw bytes of audio data, which could be useful if dr_wav does not have native support for a particular data format: + + ```c + size_t bytesRead = drwav_read_raw(&wav, bytesToRead, pRawDataBuffer); + ``` + +dr_wav can also be used to output WAV files. This does not currently support compressed formats. To use this, look at `drwav_init_write()`, +`drwav_init_file_write()`, etc. Use `drwav_write_pcm_frames()` to write samples, or `drwav_write_raw()` to write raw data in the "data" chunk. + + ```c + drwav_data_format format; + format.container = drwav_container_riff; // <-- drwav_container_riff = normal WAV files, drwav_container_w64 = Sony Wave64. + format.format = DR_WAVE_FORMAT_PCM; // <-- Any of the DR_WAVE_FORMAT_* codes. 
+ format.channels = 2; + format.sampleRate = 44100; + format.bitsPerSample = 16; + drwav_init_file_write(&wav, "data/recording.wav", &format, NULL); + + ... + + drwav_uint64 framesWritten = drwav_write_pcm_frames(pWav, frameCount, pSamples); + ``` + +dr_wav has seamless support the Sony Wave64 format. The decoder will automatically detect it and it should Just Work without any manual intervention. + + +Build Options +============= +#define these options before including this file. + +#define DR_WAV_NO_CONVERSION_API + Disables conversion APIs such as `drwav_read_pcm_frames_f32()` and `drwav_s16_to_f32()`. + +#define DR_WAV_NO_STDIO + Disables APIs that initialize a decoder from a file such as `drwav_init_file()`, `drwav_init_file_write()`, etc. + + + +Notes +===== +- Samples are always interleaved. +- The default read function does not do any data conversion. Use `drwav_read_pcm_frames_f32()`, `drwav_read_pcm_frames_s32()` and `drwav_read_pcm_frames_s16()` + to read and convert audio data to 32-bit floating point, signed 32-bit integer and signed 16-bit integer samples respectively. Tested and supported internal + formats include the following: + - Unsigned 8-bit PCM + - Signed 12-bit PCM + - Signed 16-bit PCM + - Signed 24-bit PCM + - Signed 32-bit PCM + - IEEE 32-bit floating point + - IEEE 64-bit floating point + - A-law and u-law + - Microsoft ADPCM + - IMA ADPCM (DVI, format code 0x11) +- dr_wav will try to read the WAV file as best it can, even if it's not strictly conformant to the WAV format. +*/ + +#ifndef dr_wav_h +#define dr_wav_h + +#ifdef __cplusplus +extern "C" { +#endif + +#define DRWAV_STRINGIFY(x) #x +#define DRWAV_XSTRINGIFY(x) DRWAV_STRINGIFY(x) + +#define DRWAV_VERSION_MAJOR 0 +#define DRWAV_VERSION_MINOR 12 +#define DRWAV_VERSION_REVISION 16 +#define DRWAV_VERSION_STRING DRWAV_XSTRINGIFY(DRWAV_VERSION_MAJOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_MINOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_REVISION) + +#include <stddef.h> /* For size_t. 
*/ + +/* Sized types. */ +typedef signed char drwav_int8; +typedef unsigned char drwav_uint8; +typedef signed short drwav_int16; +typedef unsigned short drwav_uint16; +typedef signed int drwav_int32; +typedef unsigned int drwav_uint32; +#if defined(_MSC_VER) + typedef signed __int64 drwav_int64; + typedef unsigned __int64 drwav_uint64; +#else + #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wlong-long" + #if defined(__clang__) + #pragma GCC diagnostic ignored "-Wc++11-long-long" + #endif + #endif + typedef signed long long drwav_int64; + typedef unsigned long long drwav_uint64; + #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) + #pragma GCC diagnostic pop + #endif +#endif +#if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__)) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__) + typedef drwav_uint64 drwav_uintptr; +#else + typedef drwav_uint32 drwav_uintptr; +#endif +typedef drwav_uint8 drwav_bool8; +typedef drwav_uint32 drwav_bool32; +#define DRWAV_TRUE 1 +#define DRWAV_FALSE 0 + +#if !defined(DRWAV_API) + #if defined(DRWAV_DLL) + #if defined(_WIN32) + #define DRWAV_DLL_IMPORT __declspec(dllimport) + #define DRWAV_DLL_EXPORT __declspec(dllexport) + #define DRWAV_DLL_PRIVATE static + #else + #if defined(__GNUC__) && __GNUC__ >= 4 + #define DRWAV_DLL_IMPORT __attribute__((visibility("default"))) + #define DRWAV_DLL_EXPORT __attribute__((visibility("default"))) + #define DRWAV_DLL_PRIVATE __attribute__((visibility("hidden"))) + #else + #define DRWAV_DLL_IMPORT + #define DRWAV_DLL_EXPORT + #define DRWAV_DLL_PRIVATE static + #endif + #endif + + #if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION) + #define DRWAV_API DRWAV_DLL_EXPORT + #else + #define DRWAV_API DRWAV_DLL_IMPORT + #endif 
+ #define DRWAV_PRIVATE DRWAV_DLL_PRIVATE + #else + #define DRWAV_API extern + #define DRWAV_PRIVATE static + #endif +#endif + +typedef drwav_int32 drwav_result; +#define DRWAV_SUCCESS 0 +#define DRWAV_ERROR -1 /* A generic error. */ +#define DRWAV_INVALID_ARGS -2 +#define DRWAV_INVALID_OPERATION -3 +#define DRWAV_OUT_OF_MEMORY -4 +#define DRWAV_OUT_OF_RANGE -5 +#define DRWAV_ACCESS_DENIED -6 +#define DRWAV_DOES_NOT_EXIST -7 +#define DRWAV_ALREADY_EXISTS -8 +#define DRWAV_TOO_MANY_OPEN_FILES -9 +#define DRWAV_INVALID_FILE -10 +#define DRWAV_TOO_BIG -11 +#define DRWAV_PATH_TOO_LONG -12 +#define DRWAV_NAME_TOO_LONG -13 +#define DRWAV_NOT_DIRECTORY -14 +#define DRWAV_IS_DIRECTORY -15 +#define DRWAV_DIRECTORY_NOT_EMPTY -16 +#define DRWAV_END_OF_FILE -17 +#define DRWAV_NO_SPACE -18 +#define DRWAV_BUSY -19 +#define DRWAV_IO_ERROR -20 +#define DRWAV_INTERRUPT -21 +#define DRWAV_UNAVAILABLE -22 +#define DRWAV_ALREADY_IN_USE -23 +#define DRWAV_BAD_ADDRESS -24 +#define DRWAV_BAD_SEEK -25 +#define DRWAV_BAD_PIPE -26 +#define DRWAV_DEADLOCK -27 +#define DRWAV_TOO_MANY_LINKS -28 +#define DRWAV_NOT_IMPLEMENTED -29 +#define DRWAV_NO_MESSAGE -30 +#define DRWAV_BAD_MESSAGE -31 +#define DRWAV_NO_DATA_AVAILABLE -32 +#define DRWAV_INVALID_DATA -33 +#define DRWAV_TIMEOUT -34 +#define DRWAV_NO_NETWORK -35 +#define DRWAV_NOT_UNIQUE -36 +#define DRWAV_NOT_SOCKET -37 +#define DRWAV_NO_ADDRESS -38 +#define DRWAV_BAD_PROTOCOL -39 +#define DRWAV_PROTOCOL_UNAVAILABLE -40 +#define DRWAV_PROTOCOL_NOT_SUPPORTED -41 +#define DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED -42 +#define DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED -43 +#define DRWAV_SOCKET_NOT_SUPPORTED -44 +#define DRWAV_CONNECTION_RESET -45 +#define DRWAV_ALREADY_CONNECTED -46 +#define DRWAV_NOT_CONNECTED -47 +#define DRWAV_CONNECTION_REFUSED -48 +#define DRWAV_NO_HOST -49 +#define DRWAV_IN_PROGRESS -50 +#define DRWAV_CANCELLED -51 +#define DRWAV_MEMORY_ALREADY_MAPPED -52 +#define DRWAV_AT_END -53 + +/* Common data formats. 
*/ +#define DR_WAVE_FORMAT_PCM 0x1 +#define DR_WAVE_FORMAT_ADPCM 0x2 +#define DR_WAVE_FORMAT_IEEE_FLOAT 0x3 +#define DR_WAVE_FORMAT_ALAW 0x6 +#define DR_WAVE_FORMAT_MULAW 0x7 +#define DR_WAVE_FORMAT_DVI_ADPCM 0x11 +#define DR_WAVE_FORMAT_EXTENSIBLE 0xFFFE + +/* Constants. */ +#ifndef DRWAV_MAX_SMPL_LOOPS +#define DRWAV_MAX_SMPL_LOOPS 1 +#endif + +/* Flags to pass into drwav_init_ex(), etc. */ +#define DRWAV_SEQUENTIAL 0x00000001 + +DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision); +DRWAV_API const char* drwav_version_string(void); + +typedef enum +{ + drwav_seek_origin_start, + drwav_seek_origin_current +} drwav_seek_origin; + +typedef enum +{ + drwav_container_riff, + drwav_container_w64, + drwav_container_rf64 +} drwav_container; + +typedef struct +{ + union + { + drwav_uint8 fourcc[4]; + drwav_uint8 guid[16]; + } id; + + /* The size in bytes of the chunk. */ + drwav_uint64 sizeInBytes; + + /* + RIFF = 2 byte alignment. + W64 = 8 byte alignment. + */ + unsigned int paddingSize; +} drwav_chunk_header; + +typedef struct +{ + /* + The format tag exactly as specified in the wave file's "fmt" chunk. This can be used by applications + that require support for data formats not natively supported by dr_wav. + */ + drwav_uint16 formatTag; + + /* The number of channels making up the audio data. When this is set to 1 it is mono, 2 is stereo, etc. */ + drwav_uint16 channels; + + /* The sample rate. Usually set to something like 44100. */ + drwav_uint32 sampleRate; + + /* Average bytes per second. You probably don't need this, but it's left here for informational purposes. */ + drwav_uint32 avgBytesPerSec; + + /* Block align. This is equal to the number of channels * bytes per sample. */ + drwav_uint16 blockAlign; + + /* Bits per sample. */ + drwav_uint16 bitsPerSample; + + /* The size of the extended data. Only used internally for validation, but left here for informational purposes. 
*/ + drwav_uint16 extendedSize; + + /* + The number of valid bits per sample. When <formatTag> is equal to WAVE_FORMAT_EXTENSIBLE, <bitsPerSample> + is always rounded up to the nearest multiple of 8. This variable contains information about exactly how + many bits are valid per sample. Mainly used for informational purposes. + */ + drwav_uint16 validBitsPerSample; + + /* The channel mask. Not used at the moment. */ + drwav_uint32 channelMask; + + /* The sub-format, exactly as specified by the wave file. */ + drwav_uint8 subFormat[16]; +} drwav_fmt; + +DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT); + + +/* +Callback for when data is read. Return value is the number of bytes actually read. + +pUserData [in] The user data that was passed to drwav_init() and family. +pBufferOut [out] The output buffer. +bytesToRead [in] The number of bytes to read. + +Returns the number of bytes actually read. + +A return value of less than bytesToRead indicates the end of the stream. Do _not_ return from this callback until +either the entire bytesToRead is filled or you have reached the end of the stream. +*/ +typedef size_t (* drwav_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); + +/* +Callback for when data is written. Return value is the number of bytes actually written. + +pUserData [in] The user data that was passed to drwav_init_write() and family. +pData [in] A pointer to the data to write. +bytesToWrite [in] The number of bytes to write. + +Returns the number of bytes actually written. + +If the return value differs from bytesToWrite, it indicates an error. +*/ +typedef size_t (* drwav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite); + +/* +Callback for when data needs to be seeked. + +pUserData [in] The user data that was passed to drwav_init() and family. +offset [in] The number of bytes to move, relative to the origin. Will never be negative. 
+origin [in] The origin of the seek - the current position or the start of the stream. + +Returns whether or not the seek was successful. + +Whether or not it is relative to the beginning or current position is determined by the "origin" parameter which will be either drwav_seek_origin_start or +drwav_seek_origin_current. +*/ +typedef drwav_bool32 (* drwav_seek_proc)(void* pUserData, int offset, drwav_seek_origin origin); + +/* +Callback for when drwav_init_ex() finds a chunk. + +pChunkUserData [in] The user data that was passed to the pChunkUserData parameter of drwav_init_ex() and family. +onRead [in] A pointer to the function to call when reading. +onSeek [in] A pointer to the function to call when seeking. +pReadSeekUserData [in] The user data that was passed to the pReadSeekUserData parameter of drwav_init_ex() and family. +pChunkHeader [in] A pointer to an object containing basic header information about the chunk. Use this to identify the chunk. +container [in] Whether or not the WAV file is a RIFF or Wave64 container. If you're unsure of the difference, assume RIFF. +pFMT [in] A pointer to the object containing the contents of the "fmt" chunk. + +Returns the number of bytes read + seeked. + +To read data from the chunk, call onRead(), passing in pReadSeekUserData as the first parameter. Do the same for seeking with onSeek(). The return value must +be the total number of bytes you have read _plus_ seeked. + +Use the `container` argument to discriminate the fields in `pChunkHeader->id`. If the container is `drwav_container_riff` or `drwav_container_rf64` you should +use `id.fourcc`, otherwise you should use `id.guid`. + +The `pFMT` parameter can be used to determine the data format of the wave file. Use `drwav_fmt_get_format()` to get the sample format, which will be one of the +`DR_WAVE_FORMAT_*` identifiers. + +The read pointer will be sitting on the first byte after the chunk's header. You must not attempt to read beyond the boundary of the chunk. 
+*/ +typedef drwav_uint64 (* drwav_chunk_proc)(void* pChunkUserData, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_chunk_header* pChunkHeader, drwav_container container, const drwav_fmt* pFMT); + +typedef struct +{ + void* pUserData; + void* (* onMalloc)(size_t sz, void* pUserData); + void* (* onRealloc)(void* p, size_t sz, void* pUserData); + void (* onFree)(void* p, void* pUserData); +} drwav_allocation_callbacks; + +/* Structure for internal use. Only used for loaders opened with drwav_init_memory(). */ +typedef struct +{ + const drwav_uint8* data; + size_t dataSize; + size_t currentReadPos; +} drwav__memory_stream; + +/* Structure for internal use. Only used for writers opened with drwav_init_memory_write(). */ +typedef struct +{ + void** ppData; + size_t* pDataSize; + size_t dataSize; + size_t dataCapacity; + size_t currentWritePos; +} drwav__memory_stream_write; + +typedef struct +{ + drwav_container container; /* RIFF, W64. */ + drwav_uint32 format; /* DR_WAVE_FORMAT_* */ + drwav_uint32 channels; + drwav_uint32 sampleRate; + drwav_uint32 bitsPerSample; +} drwav_data_format; + + +/* See the following for details on the 'smpl' chunk: https://sites.google.com/site/musicgapi/technical-documents/wav-file-format#smpl */ +typedef struct +{ + drwav_uint32 cuePointId; + drwav_uint32 type; + drwav_uint32 start; + drwav_uint32 end; + drwav_uint32 fraction; + drwav_uint32 playCount; +} drwav_smpl_loop; + + typedef struct +{ + drwav_uint32 manufacturer; + drwav_uint32 product; + drwav_uint32 samplePeriod; + drwav_uint32 midiUnityNotes; + drwav_uint32 midiPitchFraction; + drwav_uint32 smpteFormat; + drwav_uint32 smpteOffset; + drwav_uint32 numSampleLoops; + drwav_uint32 samplerData; + drwav_smpl_loop loops[DRWAV_MAX_SMPL_LOOPS]; +} drwav_smpl; + +typedef struct +{ + /* A pointer to the function to call when more data is needed. */ + drwav_read_proc onRead; + + /* A pointer to the function to call when data needs to be written. 
Only used when the drwav object is opened in write mode. */ + drwav_write_proc onWrite; + + /* A pointer to the function to call when the wav file needs to be seeked. */ + drwav_seek_proc onSeek; + + /* The user data to pass to callbacks. */ + void* pUserData; + + /* Allocation callbacks. */ + drwav_allocation_callbacks allocationCallbacks; + + + /* Whether or not the WAV file is formatted as a standard RIFF file or W64. */ + drwav_container container; + + + /* Structure containing format information exactly as specified by the wav file. */ + drwav_fmt fmt; + + /* The sample rate. Will be set to something like 44100. */ + drwav_uint32 sampleRate; + + /* The number of channels. This will be set to 1 for monaural streams, 2 for stereo, etc. */ + drwav_uint16 channels; + + /* The bits per sample. Will be set to something like 16, 24, etc. */ + drwav_uint16 bitsPerSample; + + /* Equal to fmt.formatTag, or the value specified by fmt.subFormat if fmt.formatTag is equal to 65534 (WAVE_FORMAT_EXTENSIBLE). */ + drwav_uint16 translatedFormatTag; + + /* The total number of PCM frames making up the audio data. */ + drwav_uint64 totalPCMFrameCount; + + + /* The size in bytes of the data chunk. */ + drwav_uint64 dataChunkDataSize; + + /* The position in the stream of the first byte of the data chunk. This is used for seeking. */ + drwav_uint64 dataChunkDataPos; + + /* The number of bytes remaining in the data chunk. */ + drwav_uint64 bytesRemaining; + + + /* + Only used in sequential write mode. Keeps track of the desired size of the "data" chunk at the point of initialization time. Always + set to 0 for non-sequential writes and when the drwav object is opened in read mode. Used for validation. + */ + drwav_uint64 dataChunkDataSizeTargetWrite; + + /* Keeps track of whether or not the wav writer was initialized in sequential mode. */ + drwav_bool32 isSequentialWrite; + + + /* smpl chunk. 
*/ + drwav_smpl smpl; + + + /* A hack to avoid a DRWAV_MALLOC() when opening a decoder with drwav_init_memory(). */ + drwav__memory_stream memoryStream; + drwav__memory_stream_write memoryStreamWrite; + + /* Generic data for compressed formats. This data is shared across all block-compressed formats. */ + struct + { + drwav_uint64 iCurrentPCMFrame; /* The index of the next PCM frame that will be read by drwav_read_*(). This is used with "totalPCMFrameCount" to ensure we don't read excess samples at the end of the last block. */ + } compressed; + + /* Microsoft ADPCM specific data. */ + struct + { + drwav_uint32 bytesRemainingInBlock; + drwav_uint16 predictor[2]; + drwav_int32 delta[2]; + drwav_int32 cachedFrames[4]; /* Samples are stored in this cache during decoding. */ + drwav_uint32 cachedFrameCount; + drwav_int32 prevFrames[2][2]; /* The previous 2 samples for each channel (2 channels at most). */ + } msadpcm; + + /* IMA ADPCM specific data. */ + struct + { + drwav_uint32 bytesRemainingInBlock; + drwav_int32 predictor[2]; + drwav_int32 stepIndex[2]; + drwav_int32 cachedFrames[16]; /* Samples are stored in this cache during decoding. */ + drwav_uint32 cachedFrameCount; + } ima; +} drwav; + + +/* +Initializes a pre-allocated drwav object for reading. + +pWav [out] A pointer to the drwav object being initialized. +onRead [in] The function to call when data needs to be read from the client. +onSeek [in] The function to call when the read position of the client data needs to move. +onChunk [in, optional] The function to call when a chunk is enumerated at initialization time. +pUserData, pReadSeekUserData [in, optional] A pointer to application defined data that will be passed to onRead and onSeek. +pChunkUserData [in, optional] A pointer to application defined data that will be passed to onChunk. +flags [in, optional] A set of flags for controlling how things are loaded. + +Returns true if successful; false otherwise. + +Close the loader with drwav_uninit(). 
+ +This is the lowest level function for initializing a WAV file. You can also use drwav_init_file() and drwav_init_memory() +to open the stream from a file or from a block of memory respectively. + +Possible values for flags: + DRWAV_SEQUENTIAL: Never perform a backwards seek while loading. This disables the chunk callback and will cause this function + to return as soon as the data chunk is found. Any chunks after the data chunk will be ignored. + +drwav_init() is equivalent to "drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks);". + +The onChunk callback is not called for the WAVE or FMT chunks. The contents of the FMT chunk can be read from pWav->fmt +after the function returns. + +See also: drwav_init_file(), drwav_init_memory(), drwav_uninit() +*/ +DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Initializes a pre-allocated drwav object for writing. + +onWrite [in] The function to call when data needs to be written. +onSeek [in] The function to call when the write position needs to move. +pUserData [in, optional] A pointer to application defined data that will be passed to onWrite and onSeek. + +Returns true if successful; false otherwise. + +Close the writer with drwav_uninit(). + +This is the lowest level function for initializing a WAV file. You can also use drwav_init_file_write() and drwav_init_memory_write() +to open the stream from a file or from a block of memory respectively. + +If the total sample count is known, you can use drwav_init_write_sequential(). 
This avoids the need for dr_wav to perform +a post-processing step for storing the total sample count and the size of the data chunk which requires a backwards seek. + +See also: drwav_init_file_write(), drwav_init_memory_write(), drwav_uninit() +*/ +DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Utility function to determine the target size of the entire data to be written (including all headers and chunks). + +Returns the target size in bytes. + +Useful if the application needs to know the size to allocate. + +Only writing to the RIFF chunk and one data chunk is currently supported. + +See also: drwav_init_write(), drwav_init_file_write(), drwav_init_memory_write() +*/ +DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount); + +/* +Uninitializes the given drwav object. + +Use this only for objects initialized with drwav_init*() functions (drwav_init(), drwav_init_ex(), drwav_init_write(), drwav_init_write_sequential()). +*/ +DRWAV_API drwav_result drwav_uninit(drwav* pWav); + + +/* +Reads raw audio data. + +This is the lowest level function for reading audio data. It simply reads the given number of +bytes of the raw internal sample data. 
+ +Consider using drwav_read_pcm_frames_s16(), drwav_read_pcm_frames_s32() or drwav_read_pcm_frames_f32() for +reading sample data in a consistent format. + +pBufferOut can be NULL in which case a seek will be performed. + +Returns the number of bytes actually read. +*/ +DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut); + +/* +Reads up to the specified number of PCM frames from the WAV file. + +The output data will be in the file's internal format, converted to native-endian byte order. Use +drwav_read_pcm_frames_s16/f32/s32() to read data in a specific format. + +If the return value is less than <framesToRead> it means the end of the file has been reached or +you have requested more PCM frames than can possibly fit in the output buffer. + +This function will only work when sample data is of a fixed size and uncompressed. If you are +using a compressed format consider using drwav_read_raw() or drwav_read_pcm_frames_s16/s32/f32(). + +pBufferOut can be NULL in which case a seek will be performed. +*/ +DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut); + +/* +Seeks to the given PCM frame. + +Returns true if successful; false otherwise. +*/ +DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex); + + +/* +Writes raw audio data. + +Returns the number of bytes actually written. If this differs from bytesToWrite, it indicates an error. +*/ +DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData); + +/* +Writes PCM frames. + +Returns the number of PCM frames written. + +Input samples need to be in native-endian byte order. On big-endian architectures the input data will be converted to +little-endian. 
Use drwav_write_raw() to write raw audio data without performing any conversion. +*/ +DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData); +DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData); +DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData); + + +/* Conversion Utilities */ +#ifndef DR_WAV_NO_CONVERSION_API + +/* +Reads a chunk of audio data and converts it to signed 16-bit PCM samples. + +pBufferOut can be NULL in which case a seek will be performed. + +Returns the number of PCM frames actually read. + +If the return value is less than <framesToRead> it means the end of the file has been reached. +*/ +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut); + +/* Low-level function for converting unsigned 8-bit PCM samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 24-bit PCM samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 32-bit PCM samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 32-bit floating point samples to signed 16-bit PCM samples. 
*/ +DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 64-bit floating point samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount); + +/* Low-level function for converting A-law samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting u-law samples to signed 16-bit PCM samples. */ +DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount); + + +/* +Reads a chunk of audio data and converts it to IEEE 32-bit floating point samples. + +pBufferOut can be NULL in which case a seek will be performed. + +Returns the number of PCM frames actually read. + +If the return value is less than <framesToRead> it means the end of the file has been reached. +*/ +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut); + +/* Low-level function for converting unsigned 8-bit PCM samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 16-bit PCM samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount); + +/* Low-level function for converting signed 24-bit PCM samples to IEEE 32-bit floating point samples. 
*/ +DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 32-bit PCM samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 64-bit floating point samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount); + +/* Low-level function for converting A-law samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting u-law samples to IEEE 32-bit floating point samples. */ +DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount); + + +/* +Reads a chunk of audio data and converts it to signed 32-bit PCM samples. + +pBufferOut can be NULL in which case a seek will be performed. + +Returns the number of PCM frames actually read. + +If the return value is less than <framesToRead> it means the end of the file has been reached. +*/ +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); + +/* Low-level function for converting unsigned 8-bit PCM samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 16-bit PCM samples to signed 32-bit PCM samples. 
*/ +DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount); + +/* Low-level function for converting signed 24-bit PCM samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 32-bit floating point samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 64-bit floating point samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount); + +/* Low-level function for converting A-law samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting u-law samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +#endif /* DR_WAV_NO_CONVERSION_API */ + + +/* High-Level Convenience Helpers */ + +#ifndef DR_WAV_NO_STDIO +/* +Helper for initializing a wave file for reading using stdio. + +This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav +objects because the operating system may restrict the number of file handles an application can have open at +any given time. 
+*/ +DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Helper for initializing a wave file for writing using stdio. + +This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav +objects because the operating system may restrict the number of file handles an application can have open at +any given time. +*/ +DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* 
pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif /* DR_WAV_NO_STDIO */ + +/* +Helper for initializing a loader from a pre-allocated memory buffer. + +This does not create a copy of the data. It is up to the application to ensure the buffer remains valid for +the lifetime of the drwav object. + +The buffer should contain the contents of the entire wave file, not just the sample data. +*/ +DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Helper for initializing a writer which outputs data to a memory buffer. + +dr_wav will manage the memory allocations, however it is up to the caller to free the data with drwav_free(). + +The buffer will remain allocated even after drwav_uninit() is called. The buffer should not be considered valid +until after drwav_uninit() has been called. 
+*/ +DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); + + +#ifndef DR_WAV_NO_CONVERSION_API +/* +Opens and reads an entire wav file in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. +*/ +DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#ifndef DR_WAV_NO_STDIO +/* +Opens and decodes an entire wav file in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. 
+*/ +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif +/* +Opens and decodes an entire wav file from a block of memory in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. 
+*/ +DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif + +/* Frees data that was allocated internally by dr_wav. */ +DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* Converts bytes from a wav stream to a sized type of native endian. */ +DRWAV_API drwav_uint16 drwav_bytes_to_u16(const drwav_uint8* data); +DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data); +DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data); +DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data); +DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data); +DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data); + +/* Compares a GUID for the purpose of checking the type of a Wave64 chunk. */ +DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]); + +/* Compares a four-character-code for the purpose of checking the type of a RIFF chunk. 
*/ +DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b); + +#ifdef __cplusplus +} +#endif +#endif /* dr_wav_h */ + + +/************************************************************************************************************************************************************ + ************************************************************************************************************************************************************ + + IMPLEMENTATION + + ************************************************************************************************************************************************************ + ************************************************************************************************************************************************************/ +#if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION) +#ifndef dr_wav_c +#define dr_wav_c + +#include <stdlib.h> +#include <string.h> /* For memcpy(), memset() */ +#include <limits.h> /* For INT_MAX */ + +#ifndef DR_WAV_NO_STDIO +#include <stdio.h> +#include <wchar.h> +#endif + +/* Standard library stuff. */ +#ifndef DRWAV_ASSERT +#include <assert.h> +#define DRWAV_ASSERT(expression) assert(expression) +#endif +#ifndef DRWAV_MALLOC +#define DRWAV_MALLOC(sz) malloc((sz)) +#endif +#ifndef DRWAV_REALLOC +#define DRWAV_REALLOC(p, sz) realloc((p), (sz)) +#endif +#ifndef DRWAV_FREE +#define DRWAV_FREE(p) free((p)) +#endif +#ifndef DRWAV_COPY_MEMORY +#define DRWAV_COPY_MEMORY(dst, src, sz) memcpy((dst), (src), (sz)) +#endif +#ifndef DRWAV_ZERO_MEMORY +#define DRWAV_ZERO_MEMORY(p, sz) memset((p), 0, (sz)) +#endif +#ifndef DRWAV_ZERO_OBJECT +#define DRWAV_ZERO_OBJECT(p) DRWAV_ZERO_MEMORY((p), sizeof(*p)) +#endif + +#define drwav_countof(x) (sizeof(x) / sizeof(x[0])) +#define drwav_align(x, a) ((((x) + (a) - 1) / (a)) * (a)) +#define drwav_min(a, b) (((a) < (b)) ? (a) : (b)) +#define drwav_max(a, b) (((a) > (b)) ? 
(a) : (b)) +#define drwav_clamp(x, lo, hi) (drwav_max((lo), drwav_min((hi), (x)))) + +#define DRWAV_MAX_SIMD_VECTOR_SIZE 64 /* 64 for AVX-512 in the future. */ + +/* CPU architecture. */ +#if defined(__x86_64__) || defined(_M_X64) + #define DRWAV_X64 +#elif defined(__i386) || defined(_M_IX86) + #define DRWAV_X86 +#elif defined(__arm__) || defined(_M_ARM) + #define DRWAV_ARM +#endif + +#ifdef _MSC_VER + #define DRWAV_INLINE __forceinline +#elif defined(__GNUC__) + /* + I've had a bug report where GCC is emitting warnings about functions possibly not being inlineable. This warning happens when + the __attribute__((always_inline)) attribute is defined without an "inline" statement. I think therefore there must be some + case where "__inline__" is not always defined, thus the compiler emitting these warnings. When using -std=c89 or -ansi on the + command line, we cannot use the "inline" keyword and instead need to use "__inline__". In an attempt to work around this issue + I am using "__inline__" only when we're compiling in strict ANSI mode. 
+ */ + #if defined(__STRICT_ANSI__) + #define DRWAV_INLINE __inline__ __attribute__((always_inline)) + #else + #define DRWAV_INLINE inline __attribute__((always_inline)) + #endif +#elif defined(__WATCOMC__) + #define DRWAV_INLINE __inline +#else + #define DRWAV_INLINE +#endif + +#if defined(SIZE_MAX) + #define DRWAV_SIZE_MAX SIZE_MAX +#else + #if defined(_WIN64) || defined(_LP64) || defined(__LP64__) + #define DRWAV_SIZE_MAX ((drwav_uint64)0xFFFFFFFFFFFFFFFF) + #else + #define DRWAV_SIZE_MAX 0xFFFFFFFF + #endif +#endif + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #define DRWAV_HAS_BYTESWAP64_INTRINSIC +#elif defined(__clang__) + #if defined(__has_builtin) + #if __has_builtin(__builtin_bswap16) + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #endif + #if __has_builtin(__builtin_bswap32) + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #endif + #if __has_builtin(__builtin_bswap64) + #define DRWAV_HAS_BYTESWAP64_INTRINSIC + #endif + #endif +#elif defined(__GNUC__) + #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3)) + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #define DRWAV_HAS_BYTESWAP64_INTRINSIC + #endif + #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)) + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #endif +#endif + +DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision) +{ + if (pMajor) { + *pMajor = DRWAV_VERSION_MAJOR; + } + + if (pMinor) { + *pMinor = DRWAV_VERSION_MINOR; + } + + if (pRevision) { + *pRevision = DRWAV_VERSION_REVISION; + } +} + +DRWAV_API const char* drwav_version_string(void) +{ + return DRWAV_VERSION_STRING; +} + +/* +These limits are used for basic validation when initializing the decoder. If you exceed these limits, first of all: what on Earth are +you doing?! (Let me know, I'd be curious!) Second, you can adjust these by #define-ing them before the dr_wav implementation. 
+*/ +#ifndef DRWAV_MAX_SAMPLE_RATE +#define DRWAV_MAX_SAMPLE_RATE 384000 +#endif +#ifndef DRWAV_MAX_CHANNELS +#define DRWAV_MAX_CHANNELS 256 +#endif +#ifndef DRWAV_MAX_BITS_PER_SAMPLE +#define DRWAV_MAX_BITS_PER_SAMPLE 64 +#endif + +static const drwav_uint8 drwavGUID_W64_RIFF[16] = {0x72,0x69,0x66,0x66, 0x2E,0x91, 0xCF,0x11, 0xA5,0xD6, 0x28,0xDB,0x04,0xC1,0x00,0x00}; /* 66666972-912E-11CF-A5D6-28DB04C10000 */ +static const drwav_uint8 drwavGUID_W64_WAVE[16] = {0x77,0x61,0x76,0x65, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 65766177-ACF3-11D3-8CD1-00C04F8EDB8A */ +/*static const drwav_uint8 drwavGUID_W64_JUNK[16] = {0x6A,0x75,0x6E,0x6B, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A};*/ /* 6B6E756A-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_FMT [16] = {0x66,0x6D,0x74,0x20, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 20746D66-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_FACT[16] = {0x66,0x61,0x63,0x74, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 74636166-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_DATA[16] = {0x64,0x61,0x74,0x61, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 61746164-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_SMPL[16] = {0x73,0x6D,0x70,0x6C, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 6C706D73-ACF3-11D3-8CD1-00C04F8EDB8A */ + +static DRWAV_INLINE drwav_bool32 drwav__guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]) +{ + int i; + for (i = 0; i < 16; i += 1) { + if (a[i] != b[i]) { + return DRWAV_FALSE; + } + } + + return DRWAV_TRUE; +} + +static DRWAV_INLINE drwav_bool32 drwav__fourcc_equal(const drwav_uint8* a, const char* b) +{ + return + a[0] == b[0] && + a[1] == b[1] && + a[2] == b[2] && + a[3] == b[3]; +} + + + +static DRWAV_INLINE int drwav__is_little_endian(void) +{ +#if defined(DRWAV_X86) || 
defined(DRWAV_X64) + return DRWAV_TRUE; +#elif defined(__BYTE_ORDER) && defined(__LITTLE_ENDIAN) && __BYTE_ORDER == __LITTLE_ENDIAN + return DRWAV_TRUE; +#else + int n = 1; + return (*(char*)&n) == 1; +#endif +} + +static DRWAV_INLINE drwav_uint16 drwav__bytes_to_u16(const drwav_uint8* data) +{ + return (data[0] << 0) | (data[1] << 8); +} + +static DRWAV_INLINE drwav_int16 drwav__bytes_to_s16(const drwav_uint8* data) +{ + return (short)drwav__bytes_to_u16(data); +} + +static DRWAV_INLINE drwav_uint32 drwav__bytes_to_u32(const drwav_uint8* data) +{ + return (data[0] << 0) | (data[1] << 8) | (data[2] << 16) | (data[3] << 24); +} + +static DRWAV_INLINE drwav_int32 drwav__bytes_to_s32(const drwav_uint8* data) +{ + return (drwav_int32)drwav__bytes_to_u32(data); +} + +static DRWAV_INLINE drwav_uint64 drwav__bytes_to_u64(const drwav_uint8* data) +{ + return + ((drwav_uint64)data[0] << 0) | ((drwav_uint64)data[1] << 8) | ((drwav_uint64)data[2] << 16) | ((drwav_uint64)data[3] << 24) | + ((drwav_uint64)data[4] << 32) | ((drwav_uint64)data[5] << 40) | ((drwav_uint64)data[6] << 48) | ((drwav_uint64)data[7] << 56); +} + +static DRWAV_INLINE drwav_int64 drwav__bytes_to_s64(const drwav_uint8* data) +{ + return (drwav_int64)drwav__bytes_to_u64(data); +} + +static DRWAV_INLINE void drwav__bytes_to_guid(const drwav_uint8* data, drwav_uint8* guid) +{ + int i; + for (i = 0; i < 16; ++i) { + guid[i] = data[i]; + } +} + + +static DRWAV_INLINE drwav_uint16 drwav__bswap16(drwav_uint16 n) +{ +#ifdef DRWAV_HAS_BYTESWAP16_INTRINSIC + #if defined(_MSC_VER) + return _byteswap_ushort(n); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_bswap16(n); + #else + #error "This compiler does not support the byte swap intrinsic." 
+ #endif +#else + return ((n & 0xFF00) >> 8) | + ((n & 0x00FF) << 8); +#endif +} + +static DRWAV_INLINE drwav_uint32 drwav__bswap32(drwav_uint32 n) +{ +#ifdef DRWAV_HAS_BYTESWAP32_INTRINSIC + #if defined(_MSC_VER) + return _byteswap_ulong(n); + #elif defined(__GNUC__) || defined(__clang__) + #if defined(DRWAV_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 6) && !defined(DRWAV_64BIT) /* <-- 64-bit inline assembly has not been tested, so disabling for now. */ + /* Inline assembly optimized implementation for ARM. In my testing, GCC does not generate optimized code with __builtin_bswap32(). */ + drwav_uint32 r; + __asm__ __volatile__ ( + #if defined(DRWAV_64BIT) + "rev %w[out], %w[in]" : [out]"=r"(r) : [in]"r"(n) /* <-- This is untested. If someone in the community could test this, that would be appreciated! */ + #else + "rev %[out], %[in]" : [out]"=r"(r) : [in]"r"(n) + #endif + ); + return r; + #else + return __builtin_bswap32(n); + #endif + #else + #error "This compiler does not support the byte swap intrinsic." + #endif +#else + return ((n & 0xFF000000) >> 24) | + ((n & 0x00FF0000) >> 8) | + ((n & 0x0000FF00) << 8) | + ((n & 0x000000FF) << 24); +#endif +} + +static DRWAV_INLINE drwav_uint64 drwav__bswap64(drwav_uint64 n) +{ +#ifdef DRWAV_HAS_BYTESWAP64_INTRINSIC + #if defined(_MSC_VER) + return _byteswap_uint64(n); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_bswap64(n); + #else + #error "This compiler does not support the byte swap intrinsic." + #endif +#else + /* Weird "<< 32" bitshift is required for C89 because it doesn't support 64-bit constants. Should be optimized out by a good compiler. 
*/ + return ((n & ((drwav_uint64)0xFF000000 << 32)) >> 56) | + ((n & ((drwav_uint64)0x00FF0000 << 32)) >> 40) | + ((n & ((drwav_uint64)0x0000FF00 << 32)) >> 24) | + ((n & ((drwav_uint64)0x000000FF << 32)) >> 8) | + ((n & ((drwav_uint64)0xFF000000 )) << 8) | + ((n & ((drwav_uint64)0x00FF0000 )) << 24) | + ((n & ((drwav_uint64)0x0000FF00 )) << 40) | + ((n & ((drwav_uint64)0x000000FF )) << 56); +#endif +} + + +static DRWAV_INLINE drwav_int16 drwav__bswap_s16(drwav_int16 n) +{ + return (drwav_int16)drwav__bswap16((drwav_uint16)n); +} + +static DRWAV_INLINE void drwav__bswap_samples_s16(drwav_int16* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_s16(pSamples[iSample]); + } +} + + +static DRWAV_INLINE void drwav__bswap_s24(drwav_uint8* p) +{ + drwav_uint8 t; + t = p[0]; + p[0] = p[2]; + p[2] = t; +} + +static DRWAV_INLINE void drwav__bswap_samples_s24(drwav_uint8* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + drwav_uint8* pSample = pSamples + (iSample*3); + drwav__bswap_s24(pSample); + } +} + + +static DRWAV_INLINE drwav_int32 drwav__bswap_s32(drwav_int32 n) +{ + return (drwav_int32)drwav__bswap32((drwav_uint32)n); +} + +static DRWAV_INLINE void drwav__bswap_samples_s32(drwav_int32* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_s32(pSamples[iSample]); + } +} + + +static DRWAV_INLINE float drwav__bswap_f32(float n) +{ + union { + drwav_uint32 i; + float f; + } x; + x.f = n; + x.i = drwav__bswap32(x.i); + + return x.f; +} + +static DRWAV_INLINE void drwav__bswap_samples_f32(float* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_f32(pSamples[iSample]); + } +} + + +static 
DRWAV_INLINE double drwav__bswap_f64(double n) +{ + union { + drwav_uint64 i; + double f; + } x; + x.f = n; + x.i = drwav__bswap64(x.i); + + return x.f; +} + +static DRWAV_INLINE void drwav__bswap_samples_f64(double* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_f64(pSamples[iSample]); + } +} + + +static DRWAV_INLINE void drwav__bswap_samples_pcm(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample) +{ + /* Assumes integer PCM. Floating point PCM is done in drwav__bswap_samples_ieee(). */ + switch (bytesPerSample) + { + case 2: /* s16, s12 (loosely packed) */ + { + drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount); + } break; + case 3: /* s24 */ + { + drwav__bswap_samples_s24((drwav_uint8*)pSamples, sampleCount); + } break; + case 4: /* s32 */ + { + drwav__bswap_samples_s32((drwav_int32*)pSamples, sampleCount); + } break; + default: + { + /* Unsupported format. */ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + +static DRWAV_INLINE void drwav__bswap_samples_ieee(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample) +{ + switch (bytesPerSample) + { + #if 0 /* Contributions welcome for f16 support. */ + case 2: /* f16 */ + { + drwav__bswap_samples_f16((drwav_float16*)pSamples, sampleCount); + } break; + #endif + case 4: /* f32 */ + { + drwav__bswap_samples_f32((float*)pSamples, sampleCount); + } break; + case 8: /* f64 */ + { + drwav__bswap_samples_f64((double*)pSamples, sampleCount); + } break; + default: + { + /* Unsupported format. 
*/ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + +static DRWAV_INLINE void drwav__bswap_samples(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample, drwav_uint16 format) +{ + switch (format) + { + case DR_WAVE_FORMAT_PCM: + { + drwav__bswap_samples_pcm(pSamples, sampleCount, bytesPerSample); + } break; + + case DR_WAVE_FORMAT_IEEE_FLOAT: + { + drwav__bswap_samples_ieee(pSamples, sampleCount, bytesPerSample); + } break; + + case DR_WAVE_FORMAT_ALAW: + case DR_WAVE_FORMAT_MULAW: + { + drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount); + } break; + + case DR_WAVE_FORMAT_ADPCM: + case DR_WAVE_FORMAT_DVI_ADPCM: + default: + { + /* Unsupported format. */ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + + +static void* drwav__malloc_default(size_t sz, void* pUserData) +{ + (void)pUserData; + return DRWAV_MALLOC(sz); +} + +static void* drwav__realloc_default(void* p, size_t sz, void* pUserData) +{ + (void)pUserData; + return DRWAV_REALLOC(p, sz); +} + +static void drwav__free_default(void* p, void* pUserData) +{ + (void)pUserData; + DRWAV_FREE(p); +} + + +static void* drwav__malloc_from_callbacks(size_t sz, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks == NULL) { + return NULL; + } + + if (pAllocationCallbacks->onMalloc != NULL) { + return pAllocationCallbacks->onMalloc(sz, pAllocationCallbacks->pUserData); + } + + /* Try using realloc(). */ + if (pAllocationCallbacks->onRealloc != NULL) { + return pAllocationCallbacks->onRealloc(NULL, sz, pAllocationCallbacks->pUserData); + } + + return NULL; +} + +static void* drwav__realloc_from_callbacks(void* p, size_t szNew, size_t szOld, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks == NULL) { + return NULL; + } + + if (pAllocationCallbacks->onRealloc != NULL) { + return pAllocationCallbacks->onRealloc(p, szNew, pAllocationCallbacks->pUserData); + } + + /* Try emulating realloc() in terms of malloc()/free(). 
*/ + if (pAllocationCallbacks->onMalloc != NULL && pAllocationCallbacks->onFree != NULL) { + void* p2; + + p2 = pAllocationCallbacks->onMalloc(szNew, pAllocationCallbacks->pUserData); + if (p2 == NULL) { + return NULL; + } + + if (p != NULL) { + DRWAV_COPY_MEMORY(p2, p, szOld); + pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); + } + + return p2; + } + + return NULL; +} + +static void drwav__free_from_callbacks(void* p, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (p == NULL || pAllocationCallbacks == NULL) { + return; + } + + if (pAllocationCallbacks->onFree != NULL) { + pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); + } +} + + +static drwav_allocation_callbacks drwav_copy_allocation_callbacks_or_defaults(const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks != NULL) { + /* Copy. */ + return *pAllocationCallbacks; + } else { + /* Defaults. */ + drwav_allocation_callbacks allocationCallbacks; + allocationCallbacks.pUserData = NULL; + allocationCallbacks.onMalloc = drwav__malloc_default; + allocationCallbacks.onRealloc = drwav__realloc_default; + allocationCallbacks.onFree = drwav__free_default; + return allocationCallbacks; + } +} + + +static DRWAV_INLINE drwav_bool32 drwav__is_compressed_format_tag(drwav_uint16 formatTag) +{ + return + formatTag == DR_WAVE_FORMAT_ADPCM || + formatTag == DR_WAVE_FORMAT_DVI_ADPCM; +} + +static unsigned int drwav__chunk_padding_size_riff(drwav_uint64 chunkSize) +{ + return (unsigned int)(chunkSize % 2); +} + +static unsigned int drwav__chunk_padding_size_w64(drwav_uint64 chunkSize) +{ + return (unsigned int)(chunkSize % 8); +} + +static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut); +static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut); +static drwav_bool32 drwav_init_write__internal(drwav* pWav, const 
drwav_data_format* pFormat, drwav_uint64 totalSampleCount); + +static drwav_result drwav__read_chunk_header(drwav_read_proc onRead, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_chunk_header* pHeaderOut) +{ + if (container == drwav_container_riff || container == drwav_container_rf64) { + drwav_uint8 sizeInBytes[4]; + + if (onRead(pUserData, pHeaderOut->id.fourcc, 4) != 4) { + return DRWAV_AT_END; + } + + if (onRead(pUserData, sizeInBytes, 4) != 4) { + return DRWAV_INVALID_FILE; + } + + pHeaderOut->sizeInBytes = drwav__bytes_to_u32(sizeInBytes); + pHeaderOut->paddingSize = drwav__chunk_padding_size_riff(pHeaderOut->sizeInBytes); + *pRunningBytesReadOut += 8; + } else { + drwav_uint8 sizeInBytes[8]; + + if (onRead(pUserData, pHeaderOut->id.guid, 16) != 16) { + return DRWAV_AT_END; + } + + if (onRead(pUserData, sizeInBytes, 8) != 8) { + return DRWAV_INVALID_FILE; + } + + pHeaderOut->sizeInBytes = drwav__bytes_to_u64(sizeInBytes) - 24; /* <-- Subtract 24 because w64 includes the size of the header. */ + pHeaderOut->paddingSize = drwav__chunk_padding_size_w64(pHeaderOut->sizeInBytes); + *pRunningBytesReadOut += 24; + } + + return DRWAV_SUCCESS; +} + +static drwav_bool32 drwav__seek_forward(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData) +{ + drwav_uint64 bytesRemainingToSeek = offset; + while (bytesRemainingToSeek > 0) { + if (bytesRemainingToSeek > 0x7FFFFFFF) { + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + bytesRemainingToSeek -= 0x7FFFFFFF; + } else { + if (!onSeek(pUserData, (int)bytesRemainingToSeek, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + bytesRemainingToSeek = 0; + } + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav__seek_from_start(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData) +{ + if (offset <= 0x7FFFFFFF) { + return onSeek(pUserData, (int)offset, drwav_seek_origin_start); + } + + /* Larger than 32-bit seek. 
*/ + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_start)) { + return DRWAV_FALSE; + } + offset -= 0x7FFFFFFF; + + for (;;) { + if (offset <= 0x7FFFFFFF) { + return onSeek(pUserData, (int)offset, drwav_seek_origin_current); + } + + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + offset -= 0x7FFFFFFF; + } + + /* Should never get here. */ + /*return DRWAV_TRUE; */ +} + + +static drwav_bool32 drwav__read_fmt(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_fmt* fmtOut) +{ + drwav_chunk_header header; + drwav_uint8 fmt[16]; + + if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + + /* Skip non-fmt chunks. */ + while (((container == drwav_container_riff || container == drwav_container_rf64) && !drwav__fourcc_equal(header.id.fourcc, "fmt ")) || (container == drwav_container_w64 && !drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT))) { + if (!drwav__seek_forward(onSeek, header.sizeInBytes + header.paddingSize, pUserData)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += header.sizeInBytes + header.paddingSize; + + /* Try the next header. */ + if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + } + + + /* Validation. 
*/ + if (container == drwav_container_riff || container == drwav_container_rf64) { + if (!drwav__fourcc_equal(header.id.fourcc, "fmt ")) { + return DRWAV_FALSE; + } + } else { + if (!drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT)) { + return DRWAV_FALSE; + } + } + + + if (onRead(pUserData, fmt, sizeof(fmt)) != sizeof(fmt)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += sizeof(fmt); + + fmtOut->formatTag = drwav__bytes_to_u16(fmt + 0); + fmtOut->channels = drwav__bytes_to_u16(fmt + 2); + fmtOut->sampleRate = drwav__bytes_to_u32(fmt + 4); + fmtOut->avgBytesPerSec = drwav__bytes_to_u32(fmt + 8); + fmtOut->blockAlign = drwav__bytes_to_u16(fmt + 12); + fmtOut->bitsPerSample = drwav__bytes_to_u16(fmt + 14); + + fmtOut->extendedSize = 0; + fmtOut->validBitsPerSample = 0; + fmtOut->channelMask = 0; + memset(fmtOut->subFormat, 0, sizeof(fmtOut->subFormat)); + + if (header.sizeInBytes > 16) { + drwav_uint8 fmt_cbSize[2]; + int bytesReadSoFar = 0; + + if (onRead(pUserData, fmt_cbSize, sizeof(fmt_cbSize)) != sizeof(fmt_cbSize)) { + return DRWAV_FALSE; /* Expecting more data. */ + } + *pRunningBytesReadOut += sizeof(fmt_cbSize); + + bytesReadSoFar = 18; + + fmtOut->extendedSize = drwav__bytes_to_u16(fmt_cbSize); + if (fmtOut->extendedSize > 0) { + /* Simple validation. */ + if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + if (fmtOut->extendedSize != 22) { + return DRWAV_FALSE; + } + } + + if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + drwav_uint8 fmtext[22]; + if (onRead(pUserData, fmtext, fmtOut->extendedSize) != fmtOut->extendedSize) { + return DRWAV_FALSE; /* Expecting more data. 
*/ + } + + fmtOut->validBitsPerSample = drwav__bytes_to_u16(fmtext + 0); + fmtOut->channelMask = drwav__bytes_to_u32(fmtext + 2); + drwav__bytes_to_guid(fmtext + 6, fmtOut->subFormat); + } else { + if (!onSeek(pUserData, fmtOut->extendedSize, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + } + *pRunningBytesReadOut += fmtOut->extendedSize; + + bytesReadSoFar += fmtOut->extendedSize; + } + + /* Seek past any leftover bytes. For w64 the leftover will be defined based on the chunk size. */ + if (!onSeek(pUserData, (int)(header.sizeInBytes - bytesReadSoFar), drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += (header.sizeInBytes - bytesReadSoFar); + } + + if (header.paddingSize > 0) { + if (!onSeek(pUserData, header.paddingSize, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += header.paddingSize; + } + + return DRWAV_TRUE; +} + + +static size_t drwav__on_read(drwav_read_proc onRead, void* pUserData, void* pBufferOut, size_t bytesToRead, drwav_uint64* pCursor) +{ + size_t bytesRead; + + DRWAV_ASSERT(onRead != NULL); + DRWAV_ASSERT(pCursor != NULL); + + bytesRead = onRead(pUserData, pBufferOut, bytesToRead); + *pCursor += bytesRead; + return bytesRead; +} + +#if 0 +static drwav_bool32 drwav__on_seek(drwav_seek_proc onSeek, void* pUserData, int offset, drwav_seek_origin origin, drwav_uint64* pCursor) +{ + DRWAV_ASSERT(onSeek != NULL); + DRWAV_ASSERT(pCursor != NULL); + + if (!onSeek(pUserData, offset, origin)) { + return DRWAV_FALSE; + } + + if (origin == drwav_seek_origin_start) { + *pCursor = offset; + } else { + *pCursor += offset; + } + + return DRWAV_TRUE; +} +#endif + + + +static drwav_uint32 drwav_get_bytes_per_pcm_frame(drwav* pWav) +{ + /* + The bytes per frame is a bit ambiguous. It can be either be based on the bits per sample, or the block align. 
The way I'm doing it here + is that if the bits per sample is a multiple of 8, use floor(bitsPerSample*channels/8), otherwise fall back to the block align. + */ + if ((pWav->bitsPerSample & 0x7) == 0) { + /* Bits per sample is a multiple of 8. */ + return (pWav->bitsPerSample * pWav->fmt.channels) >> 3; + } else { + return pWav->fmt.blockAlign; + } +} + +DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT) +{ + if (pFMT == NULL) { + return 0; + } + + if (pFMT->formatTag != DR_WAVE_FORMAT_EXTENSIBLE) { + return pFMT->formatTag; + } else { + return drwav__bytes_to_u16(pFMT->subFormat); /* Only the first two bytes are required. */ + } +} + +static drwav_bool32 drwav_preinit(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pWav == NULL || onRead == NULL || onSeek == NULL) { + return DRWAV_FALSE; + } + + DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav)); + pWav->onRead = onRead; + pWav->onSeek = onSeek; + pWav->pUserData = pReadSeekUserData; + pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); + + if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { + return DRWAV_FALSE; /* Invalid allocation callbacks. */ + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init__internal(drwav* pWav, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags) +{ + /* This function assumes drwav_preinit() has been called beforehand. */ + + drwav_uint64 cursor; /* <-- Keeps track of the byte position so we can seek to specific locations. */ + drwav_bool32 sequential; + drwav_uint8 riff[4]; + drwav_fmt fmt; + unsigned short translatedFormatTag; + drwav_bool32 foundDataChunk; + drwav_uint64 dataChunkSize = 0; /* <-- Important! Don't explicitly set this to 0 anywhere else. 
Calculation of the size of the data chunk is performed in different paths depending on the container. */ + drwav_uint64 sampleCountFromFactChunk = 0; /* Same as dataChunkSize - make sure this is the only place this is initialized to 0. */ + drwav_uint64 chunkSize; + + cursor = 0; + sequential = (flags & DRWAV_SEQUENTIAL) != 0; + + /* The first 4 bytes should be the RIFF identifier. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, riff, sizeof(riff), &cursor) != sizeof(riff)) { + return DRWAV_FALSE; + } + + /* + The first 4 bytes can be used to identify the container. For RIFF files it will start with "RIFF" and for + w64 it will start with "riff". + */ + if (drwav__fourcc_equal(riff, "RIFF")) { + pWav->container = drwav_container_riff; + } else if (drwav__fourcc_equal(riff, "riff")) { + int i; + drwav_uint8 riff2[12]; + + pWav->container = drwav_container_w64; + + /* Check the rest of the GUID for validity. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, riff2, sizeof(riff2), &cursor) != sizeof(riff2)) { + return DRWAV_FALSE; + } + + for (i = 0; i < 12; ++i) { + if (riff2[i] != drwavGUID_W64_RIFF[i+4]) { + return DRWAV_FALSE; + } + } + } else if (drwav__fourcc_equal(riff, "RF64")) { + pWav->container = drwav_container_rf64; + } else { + return DRWAV_FALSE; /* Unknown or unsupported container. */ + } + + + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + drwav_uint8 chunkSizeBytes[4]; + drwav_uint8 wave[4]; + + /* RIFF/WAVE */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) { + return DRWAV_FALSE; + } + + if (pWav->container == drwav_container_riff) { + if (drwav__bytes_to_u32(chunkSizeBytes) < 36) { + return DRWAV_FALSE; /* Chunk size should always be at least 36 bytes. */ + } + } else { + if (drwav__bytes_to_u32(chunkSizeBytes) != 0xFFFFFFFF) { + return DRWAV_FALSE; /* Chunk size should always be set to -1/0xFFFFFFFF for RF64. 
The actual size is retrieved later. */ + } + } + + if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) { + return DRWAV_FALSE; + } + + if (!drwav__fourcc_equal(wave, "WAVE")) { + return DRWAV_FALSE; /* Expecting "WAVE". */ + } + } else { + drwav_uint8 chunkSizeBytes[8]; + drwav_uint8 wave[16]; + + /* W64 */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) { + return DRWAV_FALSE; + } + + if (drwav__bytes_to_u64(chunkSizeBytes) < 80) { + return DRWAV_FALSE; + } + + if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) { + return DRWAV_FALSE; + } + + if (!drwav__guid_equal(wave, drwavGUID_W64_WAVE)) { + return DRWAV_FALSE; + } + } + + + /* For RF64, the "ds64" chunk must come next, before the "fmt " chunk. */ + if (pWav->container == drwav_container_rf64) { + drwav_uint8 sizeBytes[8]; + drwav_uint64 bytesRemainingInChunk; + drwav_chunk_header header; + drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header); + if (result != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + if (!drwav__fourcc_equal(header.id.fourcc, "ds64")) { + return DRWAV_FALSE; /* Expecting "ds64". */ + } + + bytesRemainingInChunk = header.sizeInBytes + header.paddingSize; + + /* We don't care about the size of the RIFF chunk - skip it. */ + if (!drwav__seek_forward(pWav->onSeek, 8, pWav->pUserData)) { + return DRWAV_FALSE; + } + bytesRemainingInChunk -= 8; + cursor += 8; + + + /* Next 8 bytes is the size of the "data" chunk. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) { + return DRWAV_FALSE; + } + bytesRemainingInChunk -= 8; + dataChunkSize = drwav__bytes_to_u64(sizeBytes); + + + /* Next 8 bytes is the same count which we would usually derived from the FACT chunk if it was available. 
*/ + if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) { + return DRWAV_FALSE; + } + bytesRemainingInChunk -= 8; + sampleCountFromFactChunk = drwav__bytes_to_u64(sizeBytes); + + + /* Skip over everything else. */ + if (!drwav__seek_forward(pWav->onSeek, bytesRemainingInChunk, pWav->pUserData)) { + return DRWAV_FALSE; + } + cursor += bytesRemainingInChunk; + } + + + /* The next bytes should be the "fmt " chunk. */ + if (!drwav__read_fmt(pWav->onRead, pWav->onSeek, pWav->pUserData, pWav->container, &cursor, &fmt)) { + return DRWAV_FALSE; /* Failed to read the "fmt " chunk. */ + } + + /* Basic validation. */ + if ((fmt.sampleRate == 0 || fmt.sampleRate > DRWAV_MAX_SAMPLE_RATE) || + (fmt.channels == 0 || fmt.channels > DRWAV_MAX_CHANNELS) || + (fmt.bitsPerSample == 0 || fmt.bitsPerSample > DRWAV_MAX_BITS_PER_SAMPLE) || + fmt.blockAlign == 0) { + return DRWAV_FALSE; /* Probably an invalid WAV file. */ + } + + + /* Translate the internal format. */ + translatedFormatTag = fmt.formatTag; + if (translatedFormatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + translatedFormatTag = drwav__bytes_to_u16(fmt.subFormat + 0); + } + + + /* + We need to enumerate over each chunk for two reasons: + 1) The "data" chunk may not be the next one + 2) We may want to report each chunk back to the client + + In order to correctly report each chunk back to the client we will need to keep looping until the end of the file. + */ + foundDataChunk = DRWAV_FALSE; + + /* The next chunk we care about is the "data" chunk. This is not necessarily the next chunk so we'll need to loop. */ + for (;;) + { + drwav_chunk_header header; + drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header); + if (result != DRWAV_SUCCESS) { + if (!foundDataChunk) { + return DRWAV_FALSE; + } else { + break; /* Probably at the end of the file. Get out of the loop. */ + } + } + + /* Tell the client about this chunk. 
*/ + if (!sequential && onChunk != NULL) { + drwav_uint64 callbackBytesRead = onChunk(pChunkUserData, pWav->onRead, pWav->onSeek, pWav->pUserData, &header, pWav->container, &fmt); + + /* + dr_wav may need to read the contents of the chunk, so we now need to seek back to the position before + we called the callback. + */ + if (callbackBytesRead > 0) { + if (!drwav__seek_from_start(pWav->onSeek, cursor, pWav->pUserData)) { + return DRWAV_FALSE; + } + } + } + + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + + chunkSize = header.sizeInBytes; + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + if (drwav__fourcc_equal(header.id.fourcc, "data")) { + foundDataChunk = DRWAV_TRUE; + if (pWav->container != drwav_container_rf64) { /* The data chunk size for RF64 will always be set to 0xFFFFFFFF here. It was set to it's true value earlier. */ + dataChunkSize = chunkSize; + } + } + } else { + if (drwav__guid_equal(header.id.guid, drwavGUID_W64_DATA)) { + foundDataChunk = DRWAV_TRUE; + dataChunkSize = chunkSize; + } + } + + /* + If at this point we have found the data chunk and we're running in sequential mode, we need to break out of this loop. The reason for + this is that we would otherwise require a backwards seek which sequential mode forbids. + */ + if (foundDataChunk && sequential) { + break; + } + + /* Optional. Get the total sample count from the FACT chunk. This is useful for compressed formats. */ + if (pWav->container == drwav_container_riff) { + if (drwav__fourcc_equal(header.id.fourcc, "fact")) { + drwav_uint32 sampleCount; + if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCount, 4, &cursor) != 4) { + return DRWAV_FALSE; + } + chunkSize -= 4; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + + /* + The sample count in the "fact" chunk is either unreliable, or I'm not understanding it properly. For now I am only enabling this + for Microsoft ADPCM formats. 
+ */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + sampleCountFromFactChunk = sampleCount; + } else { + sampleCountFromFactChunk = 0; + } + } + } else if (pWav->container == drwav_container_w64) { + if (drwav__guid_equal(header.id.guid, drwavGUID_W64_FACT)) { + if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCountFromFactChunk, 8, &cursor) != 8) { + return DRWAV_FALSE; + } + chunkSize -= 8; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + } + } else if (pWav->container == drwav_container_rf64) { + /* We retrieved the sample count from the ds64 chunk earlier so no need to do that here. */ + } + + /* "smpl" chunk. */ + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + if (drwav__fourcc_equal(header.id.fourcc, "smpl")) { + drwav_uint8 smplHeaderData[36]; /* 36 = size of the smpl header section, not including the loop data. */ + if (chunkSize >= sizeof(smplHeaderData)) { + drwav_uint64 bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplHeaderData, sizeof(smplHeaderData), &cursor); + chunkSize -= bytesJustRead; + + if (bytesJustRead == sizeof(smplHeaderData)) { + drwav_uint32 iLoop; + + pWav->smpl.manufacturer = drwav__bytes_to_u32(smplHeaderData+0); + pWav->smpl.product = drwav__bytes_to_u32(smplHeaderData+4); + pWav->smpl.samplePeriod = drwav__bytes_to_u32(smplHeaderData+8); + pWav->smpl.midiUnityNotes = drwav__bytes_to_u32(smplHeaderData+12); + pWav->smpl.midiPitchFraction = drwav__bytes_to_u32(smplHeaderData+16); + pWav->smpl.smpteFormat = drwav__bytes_to_u32(smplHeaderData+20); + pWav->smpl.smpteOffset = drwav__bytes_to_u32(smplHeaderData+24); + pWav->smpl.numSampleLoops = drwav__bytes_to_u32(smplHeaderData+28); + pWav->smpl.samplerData = drwav__bytes_to_u32(smplHeaderData+32); + + for (iLoop = 0; iLoop < pWav->smpl.numSampleLoops && iLoop < drwav_countof(pWav->smpl.loops); ++iLoop) { + drwav_uint8 smplLoopData[24]; /* 24 = size of a loop section in the smpl chunk. 
*/ + bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplLoopData, sizeof(smplLoopData), &cursor); + chunkSize -= bytesJustRead; + + if (bytesJustRead == sizeof(smplLoopData)) { + pWav->smpl.loops[iLoop].cuePointId = drwav__bytes_to_u32(smplLoopData+0); + pWav->smpl.loops[iLoop].type = drwav__bytes_to_u32(smplLoopData+4); + pWav->smpl.loops[iLoop].start = drwav__bytes_to_u32(smplLoopData+8); + pWav->smpl.loops[iLoop].end = drwav__bytes_to_u32(smplLoopData+12); + pWav->smpl.loops[iLoop].fraction = drwav__bytes_to_u32(smplLoopData+16); + pWav->smpl.loops[iLoop].playCount = drwav__bytes_to_u32(smplLoopData+20); + } else { + break; /* Break from the smpl loop for loop. */ + } + } + } + } else { + /* Looks like invalid data. Ignore the chunk. */ + } + } + } else { + if (drwav__guid_equal(header.id.guid, drwavGUID_W64_SMPL)) { + /* + This path will be hit when a W64 WAV file contains a smpl chunk. I don't have a sample file to test this path, so a contribution + is welcome to add support for this. + */ + } + } + + /* Make sure we seek past the padding. */ + chunkSize += header.paddingSize; + if (!drwav__seek_forward(pWav->onSeek, chunkSize, pWav->pUserData)) { + break; + } + cursor += chunkSize; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + } + + /* If we haven't found a data chunk, return an error. */ + if (!foundDataChunk) { + return DRWAV_FALSE; + } + + /* We may have moved passed the data chunk. If so we need to move back. If running in sequential mode we can assume we are already sitting on the data chunk. */ + if (!sequential) { + if (!drwav__seek_from_start(pWav->onSeek, pWav->dataChunkDataPos, pWav->pUserData)) { + return DRWAV_FALSE; + } + cursor = pWav->dataChunkDataPos; + } + + + /* At this point we should be sitting on the first byte of the raw audio data. 
*/ + + pWav->fmt = fmt; + pWav->sampleRate = fmt.sampleRate; + pWav->channels = fmt.channels; + pWav->bitsPerSample = fmt.bitsPerSample; + pWav->bytesRemaining = dataChunkSize; + pWav->translatedFormatTag = translatedFormatTag; + pWav->dataChunkDataSize = dataChunkSize; + + if (sampleCountFromFactChunk != 0) { + pWav->totalPCMFrameCount = sampleCountFromFactChunk; + } else { + pWav->totalPCMFrameCount = dataChunkSize / drwav_get_bytes_per_pcm_frame(pWav); + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + drwav_uint64 totalBlockHeaderSizeInBytes; + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + + /* Make sure any trailing partial block is accounted for. */ + if ((blockCount * fmt.blockAlign) < dataChunkSize) { + blockCount += 1; + } + + /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */ + totalBlockHeaderSizeInBytes = blockCount * (6*fmt.channels); + pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels; + } + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + drwav_uint64 totalBlockHeaderSizeInBytes; + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + + /* Make sure any trailing partial block is accounted for. */ + if ((blockCount * fmt.blockAlign) < dataChunkSize) { + blockCount += 1; + } + + /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */ + totalBlockHeaderSizeInBytes = blockCount * (4*fmt.channels); + pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels; + + /* The header includes a decoded sample for each channel which acts as the initial predictor sample. */ + pWav->totalPCMFrameCount += blockCount; + } + } + + /* Some formats only support a certain number of channels. 
*/ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM || pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + if (pWav->channels > 2) { + return DRWAV_FALSE; + } + } + +#ifdef DR_WAV_LIBSNDFILE_COMPAT + /* + I use libsndfile as a benchmark for testing, however in the version I'm using (from the Windows installer on the libsndfile website), + it appears the total sample count libsndfile uses for MS-ADPCM is incorrect. It would seem they are computing the total sample count + from the number of blocks, however this results in the inclusion of extra silent samples at the end of the last block. The correct + way to know the total sample count is to inspect the "fact" chunk, which should always be present for compressed formats, and should + always include the sample count. This little block of code below is only used to emulate the libsndfile logic so I can properly run my + correctness tests against libsndfile, and is disabled by default. + */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (6*pWav->channels))) * 2)) / fmt.channels; /* x2 because two samples per byte. 
*/ + } + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (4*pWav->channels))) * 2) + (blockCount * pWav->channels)) / fmt.channels; + } +#endif + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit(pWav, onRead, onSeek, pReadSeekUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init__internal(pWav, onChunk, pChunkUserData, flags); +} + + +static drwav_uint32 drwav__riff_chunk_size_riff(drwav_uint64 dataChunkSize) +{ + drwav_uint64 chunkSize = 4 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 24 = "fmt " chunk. */ + if (chunkSize > 0xFFFFFFFFUL) { + chunkSize = 0xFFFFFFFFUL; + } + + return (drwav_uint32)chunkSize; /* Safe cast due to the clamp above. */ +} + +static drwav_uint32 drwav__data_chunk_size_riff(drwav_uint64 dataChunkSize) +{ + if (dataChunkSize <= 0xFFFFFFFFUL) { + return (drwav_uint32)dataChunkSize; + } else { + return 0xFFFFFFFFUL; + } +} + +static drwav_uint64 drwav__riff_chunk_size_w64(drwav_uint64 dataChunkSize) +{ + drwav_uint64 dataSubchunkPaddingSize = drwav__chunk_padding_size_w64(dataChunkSize); + + return 80 + 24 + dataChunkSize + dataSubchunkPaddingSize; /* +24 because W64 includes the size of the GUID and size fields. 
*/ +} + +static drwav_uint64 drwav__data_chunk_size_w64(drwav_uint64 dataChunkSize) +{ + return 24 + dataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ +} + +static drwav_uint64 drwav__riff_chunk_size_rf64(drwav_uint64 dataChunkSize) +{ + drwav_uint64 chunkSize = 4 + 36 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 36 = "ds64" chunk. 24 = "fmt " chunk. */ + if (chunkSize > 0xFFFFFFFFUL) { + chunkSize = 0xFFFFFFFFUL; + } + + return chunkSize; +} + +static drwav_uint64 drwav__data_chunk_size_rf64(drwav_uint64 dataChunkSize) +{ + return dataChunkSize; +} + + +static size_t drwav__write(drwav* pWav, const void* pData, size_t dataSize) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + /* Generic write. Assumes no byte reordering required. */ + return pWav->onWrite(pWav->pUserData, pData, dataSize); +} + +static size_t drwav__write_u16ne_to_le(drwav* pWav, drwav_uint16 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap16(value); + } + + return drwav__write(pWav, &value, 2); +} + +static size_t drwav__write_u32ne_to_le(drwav* pWav, drwav_uint32 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap32(value); + } + + return drwav__write(pWav, &value, 4); +} + +static size_t drwav__write_u64ne_to_le(drwav* pWav, drwav_uint64 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap64(value); + } + + return drwav__write(pWav, &value, 8); +} + + +static drwav_bool32 drwav_preinit_write(drwav* pWav, const drwav_data_format* pFormat, drwav_bool32 isSequential, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pWav == NULL || onWrite == NULL) { + 
return DRWAV_FALSE; + } + + if (!isSequential && onSeek == NULL) { + return DRWAV_FALSE; /* <-- onSeek is required when in non-sequential mode. */ + } + + /* Not currently supporting compressed formats. Will need to add support for the "fact" chunk before we enable this. */ + if (pFormat->format == DR_WAVE_FORMAT_EXTENSIBLE) { + return DRWAV_FALSE; + } + if (pFormat->format == DR_WAVE_FORMAT_ADPCM || pFormat->format == DR_WAVE_FORMAT_DVI_ADPCM) { + return DRWAV_FALSE; + } + + DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav)); + pWav->onWrite = onWrite; + pWav->onSeek = onSeek; + pWav->pUserData = pUserData; + pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); + + if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { + return DRWAV_FALSE; /* Invalid allocation callbacks. */ + } + + pWav->fmt.formatTag = (drwav_uint16)pFormat->format; + pWav->fmt.channels = (drwav_uint16)pFormat->channels; + pWav->fmt.sampleRate = pFormat->sampleRate; + pWav->fmt.avgBytesPerSec = (drwav_uint32)((pFormat->bitsPerSample * pFormat->sampleRate * pFormat->channels) / 8); + pWav->fmt.blockAlign = (drwav_uint16)((pFormat->channels * pFormat->bitsPerSample) / 8); + pWav->fmt.bitsPerSample = (drwav_uint16)pFormat->bitsPerSample; + pWav->fmt.extendedSize = 0; + pWav->isSequentialWrite = isSequential; + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount) +{ + /* The function assumes drwav_preinit_write() was called beforehand. */ + + size_t runningPos = 0; + drwav_uint64 initialDataChunkSize = 0; + drwav_uint64 chunkSizeFMT; + + /* + The initial values for the "RIFF" and "data" chunks depends on whether or not we are initializing in sequential mode or not. 
In + sequential mode we set this to its final values straight away since they can be calculated from the total sample count. In non- + sequential mode we initialize it all to zero and fill it out in drwav_uninit() using a backwards seek. + */ + if (pWav->isSequentialWrite) { + initialDataChunkSize = (totalSampleCount * pWav->fmt.bitsPerSample) / 8; + + /* + The RIFF container has a limit on the number of samples. drwav is not allowing this. There's no practical limits for Wave64 + so for the sake of simplicity I'm not doing any validation for that. + */ + if (pFormat->container == drwav_container_riff) { + if (initialDataChunkSize > (0xFFFFFFFFUL - 36)) { + return DRWAV_FALSE; /* Not enough room to store every sample. */ + } + } + } + + pWav->dataChunkDataSizeTargetWrite = initialDataChunkSize; + + + /* "RIFF" chunk. */ + if (pFormat->container == drwav_container_riff) { + drwav_uint32 chunkSizeRIFF = 28 + (drwav_uint32)initialDataChunkSize; /* +28 = "WAVE" + [sizeof "fmt " chunk] */ + runningPos += drwav__write(pWav, "RIFF", 4); + runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeRIFF); + runningPos += drwav__write(pWav, "WAVE", 4); + } else if (pFormat->container == drwav_container_w64) { + drwav_uint64 chunkSizeRIFF = 80 + 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ + runningPos += drwav__write(pWav, drwavGUID_W64_RIFF, 16); + runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeRIFF); + runningPos += drwav__write(pWav, drwavGUID_W64_WAVE, 16); + } else if (pFormat->container == drwav_container_rf64) { + runningPos += drwav__write(pWav, "RF64", 4); + runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always 0xFFFFFFFF for RF64. Set to a proper value in the "ds64" chunk. */ + runningPos += drwav__write(pWav, "WAVE", 4); + } + + + /* "ds64" chunk (RF64 only). 
*/ + if (pFormat->container == drwav_container_rf64) { + drwav_uint32 initialds64ChunkSize = 28; /* 28 = [Size of RIFF (8 bytes)] + [Size of DATA (8 bytes)] + [Sample Count (8 bytes)] + [Table Length (4 bytes)]. Table length always set to 0. */ + drwav_uint64 initialRiffChunkSize = 8 + initialds64ChunkSize + initialDataChunkSize; /* +8 for the ds64 header. */ + + runningPos += drwav__write(pWav, "ds64", 4); + runningPos += drwav__write_u32ne_to_le(pWav, initialds64ChunkSize); /* Size of ds64. */ + runningPos += drwav__write_u64ne_to_le(pWav, initialRiffChunkSize); /* Size of RIFF. Set to true value at the end. */ + runningPos += drwav__write_u64ne_to_le(pWav, initialDataChunkSize); /* Size of DATA. Set to true value at the end. */ + runningPos += drwav__write_u64ne_to_le(pWav, totalSampleCount); /* Sample count. */ + runningPos += drwav__write_u32ne_to_le(pWav, 0); /* Table length. Always set to zero in our case since we're not doing any other chunks than "DATA". */ + } + + + /* "fmt " chunk. */ + if (pFormat->container == drwav_container_riff || pFormat->container == drwav_container_rf64) { + chunkSizeFMT = 16; + runningPos += drwav__write(pWav, "fmt ", 4); + runningPos += drwav__write_u32ne_to_le(pWav, (drwav_uint32)chunkSizeFMT); + } else if (pFormat->container == drwav_container_w64) { + chunkSizeFMT = 40; + runningPos += drwav__write(pWav, drwavGUID_W64_FMT, 16); + runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeFMT); + } + + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.formatTag); + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.channels); + runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.sampleRate); + runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.avgBytesPerSec); + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.blockAlign); + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.bitsPerSample); + + pWav->dataChunkDataPos = runningPos; + + /* "data" chunk. 
*/ + if (pFormat->container == drwav_container_riff) { + drwav_uint32 chunkSizeDATA = (drwav_uint32)initialDataChunkSize; + runningPos += drwav__write(pWav, "data", 4); + runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeDATA); + } else if (pFormat->container == drwav_container_w64) { + drwav_uint64 chunkSizeDATA = 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ + runningPos += drwav__write(pWav, drwavGUID_W64_DATA, 16); + runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeDATA); + } else if (pFormat->container == drwav_container_rf64) { + runningPos += drwav__write(pWav, "data", 4); + runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always set to 0xFFFFFFFF for RF64. The true size of the data chunk is specified in the ds64 chunk. */ + } + + /* + The runningPos variable is incremented in the section above but is left unused which is causing some static analysis tools to detect it + as a dead store. I'm leaving this as-is for safety just in case I want to expand this function later to include other tags and want to + keep track of the running position for whatever reason. The line below should silence the static analysis tools. + */ + (void)runningPos; + + /* Set some properties for the client's convenience. 
*/ + pWav->container = pFormat->container; + pWav->channels = (drwav_uint16)pFormat->channels; + pWav->sampleRate = pFormat->sampleRate; + pWav->bitsPerSample = (drwav_uint16)pFormat->bitsPerSample; + pWav->translatedFormatTag = (drwav_uint16)pFormat->format; + + return DRWAV_TRUE; +} + + +DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit_write(pWav, pFormat, DRWAV_FALSE, onWrite, onSeek, pUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init_write__internal(pWav, pFormat, 0); /* DRWAV_FALSE = Not Sequential */ +} + +DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit_write(pWav, pFormat, DRWAV_TRUE, onWrite, NULL, pUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init_write__internal(pWav, pFormat, totalSampleCount); /* DRWAV_TRUE = Sequential */ +} + +DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_write_sequential(pWav, pFormat, totalPCMFrameCount*pFormat->channels, onWrite, pUserData, pAllocationCallbacks); +} + +DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount) +{ + /* Casting totalSampleCount to drwav_int64 for VC6 compatibility. No issues in practice because nobody is going to exhaust the whole 63 bits. 
*/ + drwav_uint64 targetDataSizeBytes = (drwav_uint64)((drwav_int64)totalSampleCount * pFormat->channels * pFormat->bitsPerSample/8.0); + drwav_uint64 riffChunkSizeBytes; + drwav_uint64 fileSizeBytes = 0; + + if (pFormat->container == drwav_container_riff) { + riffChunkSizeBytes = drwav__riff_chunk_size_riff(targetDataSizeBytes); + fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */ + } else if (pFormat->container == drwav_container_w64) { + riffChunkSizeBytes = drwav__riff_chunk_size_w64(targetDataSizeBytes); + fileSizeBytes = riffChunkSizeBytes; + } else if (pFormat->container == drwav_container_rf64) { + riffChunkSizeBytes = drwav__riff_chunk_size_rf64(targetDataSizeBytes); + fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */ + } + + return fileSizeBytes; +} + + +#ifndef DR_WAV_NO_STDIO + +/* drwav_result_from_errno() is only used for fopen() and wfopen() so putting it inside DR_WAV_NO_STDIO for now. If something else needs this later we can move it out. 
*/ +#include <errno.h> +static drwav_result drwav_result_from_errno(int e) +{ + switch (e) + { + case 0: return DRWAV_SUCCESS; + #ifdef EPERM + case EPERM: return DRWAV_INVALID_OPERATION; + #endif + #ifdef ENOENT + case ENOENT: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef ESRCH + case ESRCH: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef EINTR + case EINTR: return DRWAV_INTERRUPT; + #endif + #ifdef EIO + case EIO: return DRWAV_IO_ERROR; + #endif + #ifdef ENXIO + case ENXIO: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef E2BIG + case E2BIG: return DRWAV_INVALID_ARGS; + #endif + #ifdef ENOEXEC + case ENOEXEC: return DRWAV_INVALID_FILE; + #endif + #ifdef EBADF + case EBADF: return DRWAV_INVALID_FILE; + #endif + #ifdef ECHILD + case ECHILD: return DRWAV_ERROR; + #endif + #ifdef EAGAIN + case EAGAIN: return DRWAV_UNAVAILABLE; + #endif + #ifdef ENOMEM + case ENOMEM: return DRWAV_OUT_OF_MEMORY; + #endif + #ifdef EACCES + case EACCES: return DRWAV_ACCESS_DENIED; + #endif + #ifdef EFAULT + case EFAULT: return DRWAV_BAD_ADDRESS; + #endif + #ifdef ENOTBLK + case ENOTBLK: return DRWAV_ERROR; + #endif + #ifdef EBUSY + case EBUSY: return DRWAV_BUSY; + #endif + #ifdef EEXIST + case EEXIST: return DRWAV_ALREADY_EXISTS; + #endif + #ifdef EXDEV + case EXDEV: return DRWAV_ERROR; + #endif + #ifdef ENODEV + case ENODEV: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef ENOTDIR + case ENOTDIR: return DRWAV_NOT_DIRECTORY; + #endif + #ifdef EISDIR + case EISDIR: return DRWAV_IS_DIRECTORY; + #endif + #ifdef EINVAL + case EINVAL: return DRWAV_INVALID_ARGS; + #endif + #ifdef ENFILE + case ENFILE: return DRWAV_TOO_MANY_OPEN_FILES; + #endif + #ifdef EMFILE + case EMFILE: return DRWAV_TOO_MANY_OPEN_FILES; + #endif + #ifdef ENOTTY + case ENOTTY: return DRWAV_INVALID_OPERATION; + #endif + #ifdef ETXTBSY + case ETXTBSY: return DRWAV_BUSY; + #endif + #ifdef EFBIG + case EFBIG: return DRWAV_TOO_BIG; + #endif + #ifdef ENOSPC + case ENOSPC: return DRWAV_NO_SPACE; + #endif + #ifdef ESPIPE + case ESPIPE: 
return DRWAV_BAD_SEEK; + #endif + #ifdef EROFS + case EROFS: return DRWAV_ACCESS_DENIED; + #endif + #ifdef EMLINK + case EMLINK: return DRWAV_TOO_MANY_LINKS; + #endif + #ifdef EPIPE + case EPIPE: return DRWAV_BAD_PIPE; + #endif + #ifdef EDOM + case EDOM: return DRWAV_OUT_OF_RANGE; + #endif + #ifdef ERANGE + case ERANGE: return DRWAV_OUT_OF_RANGE; + #endif + #ifdef EDEADLK + case EDEADLK: return DRWAV_DEADLOCK; + #endif + #ifdef ENAMETOOLONG + case ENAMETOOLONG: return DRWAV_PATH_TOO_LONG; + #endif + #ifdef ENOLCK + case ENOLCK: return DRWAV_ERROR; + #endif + #ifdef ENOSYS + case ENOSYS: return DRWAV_NOT_IMPLEMENTED; + #endif + #ifdef ENOTEMPTY + case ENOTEMPTY: return DRWAV_DIRECTORY_NOT_EMPTY; + #endif + #ifdef ELOOP + case ELOOP: return DRWAV_TOO_MANY_LINKS; + #endif + #ifdef ENOMSG + case ENOMSG: return DRWAV_NO_MESSAGE; + #endif + #ifdef EIDRM + case EIDRM: return DRWAV_ERROR; + #endif + #ifdef ECHRNG + case ECHRNG: return DRWAV_ERROR; + #endif + #ifdef EL2NSYNC + case EL2NSYNC: return DRWAV_ERROR; + #endif + #ifdef EL3HLT + case EL3HLT: return DRWAV_ERROR; + #endif + #ifdef EL3RST + case EL3RST: return DRWAV_ERROR; + #endif + #ifdef ELNRNG + case ELNRNG: return DRWAV_OUT_OF_RANGE; + #endif + #ifdef EUNATCH + case EUNATCH: return DRWAV_ERROR; + #endif + #ifdef ENOCSI + case ENOCSI: return DRWAV_ERROR; + #endif + #ifdef EL2HLT + case EL2HLT: return DRWAV_ERROR; + #endif + #ifdef EBADE + case EBADE: return DRWAV_ERROR; + #endif + #ifdef EBADR + case EBADR: return DRWAV_ERROR; + #endif + #ifdef EXFULL + case EXFULL: return DRWAV_ERROR; + #endif + #ifdef ENOANO + case ENOANO: return DRWAV_ERROR; + #endif + #ifdef EBADRQC + case EBADRQC: return DRWAV_ERROR; + #endif + #ifdef EBADSLT + case EBADSLT: return DRWAV_ERROR; + #endif + #ifdef EBFONT + case EBFONT: return DRWAV_INVALID_FILE; + #endif + #ifdef ENOSTR + case ENOSTR: return DRWAV_ERROR; + #endif + #ifdef ENODATA + case ENODATA: return DRWAV_NO_DATA_AVAILABLE; + #endif + #ifdef ETIME + case ETIME: return 
DRWAV_TIMEOUT; + #endif + #ifdef ENOSR + case ENOSR: return DRWAV_NO_DATA_AVAILABLE; + #endif + #ifdef ENONET + case ENONET: return DRWAV_NO_NETWORK; + #endif + #ifdef ENOPKG + case ENOPKG: return DRWAV_ERROR; + #endif + #ifdef EREMOTE + case EREMOTE: return DRWAV_ERROR; + #endif + #ifdef ENOLINK + case ENOLINK: return DRWAV_ERROR; + #endif + #ifdef EADV + case EADV: return DRWAV_ERROR; + #endif + #ifdef ESRMNT + case ESRMNT: return DRWAV_ERROR; + #endif + #ifdef ECOMM + case ECOMM: return DRWAV_ERROR; + #endif + #ifdef EPROTO + case EPROTO: return DRWAV_ERROR; + #endif + #ifdef EMULTIHOP + case EMULTIHOP: return DRWAV_ERROR; + #endif + #ifdef EDOTDOT + case EDOTDOT: return DRWAV_ERROR; + #endif + #ifdef EBADMSG + case EBADMSG: return DRWAV_BAD_MESSAGE; + #endif + #ifdef EOVERFLOW + case EOVERFLOW: return DRWAV_TOO_BIG; + #endif + #ifdef ENOTUNIQ + case ENOTUNIQ: return DRWAV_NOT_UNIQUE; + #endif + #ifdef EBADFD + case EBADFD: return DRWAV_ERROR; + #endif + #ifdef EREMCHG + case EREMCHG: return DRWAV_ERROR; + #endif + #ifdef ELIBACC + case ELIBACC: return DRWAV_ACCESS_DENIED; + #endif + #ifdef ELIBBAD + case ELIBBAD: return DRWAV_INVALID_FILE; + #endif + #ifdef ELIBSCN + case ELIBSCN: return DRWAV_INVALID_FILE; + #endif + #ifdef ELIBMAX + case ELIBMAX: return DRWAV_ERROR; + #endif + #ifdef ELIBEXEC + case ELIBEXEC: return DRWAV_ERROR; + #endif + #ifdef EILSEQ + case EILSEQ: return DRWAV_INVALID_DATA; + #endif + #ifdef ERESTART + case ERESTART: return DRWAV_ERROR; + #endif + #ifdef ESTRPIPE + case ESTRPIPE: return DRWAV_ERROR; + #endif + #ifdef EUSERS + case EUSERS: return DRWAV_ERROR; + #endif + #ifdef ENOTSOCK + case ENOTSOCK: return DRWAV_NOT_SOCKET; + #endif + #ifdef EDESTADDRREQ + case EDESTADDRREQ: return DRWAV_NO_ADDRESS; + #endif + #ifdef EMSGSIZE + case EMSGSIZE: return DRWAV_TOO_BIG; + #endif + #ifdef EPROTOTYPE + case EPROTOTYPE: return DRWAV_BAD_PROTOCOL; + #endif + #ifdef ENOPROTOOPT + case ENOPROTOOPT: return DRWAV_PROTOCOL_UNAVAILABLE; + #endif + 
#ifdef EPROTONOSUPPORT + case EPROTONOSUPPORT: return DRWAV_PROTOCOL_NOT_SUPPORTED; + #endif + #ifdef ESOCKTNOSUPPORT + case ESOCKTNOSUPPORT: return DRWAV_SOCKET_NOT_SUPPORTED; + #endif + #ifdef EOPNOTSUPP + case EOPNOTSUPP: return DRWAV_INVALID_OPERATION; + #endif + #ifdef EPFNOSUPPORT + case EPFNOSUPPORT: return DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED; + #endif + #ifdef EAFNOSUPPORT + case EAFNOSUPPORT: return DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED; + #endif + #ifdef EADDRINUSE + case EADDRINUSE: return DRWAV_ALREADY_IN_USE; + #endif + #ifdef EADDRNOTAVAIL + case EADDRNOTAVAIL: return DRWAV_ERROR; + #endif + #ifdef ENETDOWN + case ENETDOWN: return DRWAV_NO_NETWORK; + #endif + #ifdef ENETUNREACH + case ENETUNREACH: return DRWAV_NO_NETWORK; + #endif + #ifdef ENETRESET + case ENETRESET: return DRWAV_NO_NETWORK; + #endif + #ifdef ECONNABORTED + case ECONNABORTED: return DRWAV_NO_NETWORK; + #endif + #ifdef ECONNRESET + case ECONNRESET: return DRWAV_CONNECTION_RESET; + #endif + #ifdef ENOBUFS + case ENOBUFS: return DRWAV_NO_SPACE; + #endif + #ifdef EISCONN + case EISCONN: return DRWAV_ALREADY_CONNECTED; + #endif + #ifdef ENOTCONN + case ENOTCONN: return DRWAV_NOT_CONNECTED; + #endif + #ifdef ESHUTDOWN + case ESHUTDOWN: return DRWAV_ERROR; + #endif + #ifdef ETOOMANYREFS + case ETOOMANYREFS: return DRWAV_ERROR; + #endif + #ifdef ETIMEDOUT + case ETIMEDOUT: return DRWAV_TIMEOUT; + #endif + #ifdef ECONNREFUSED + case ECONNREFUSED: return DRWAV_CONNECTION_REFUSED; + #endif + #ifdef EHOSTDOWN + case EHOSTDOWN: return DRWAV_NO_HOST; + #endif + #ifdef EHOSTUNREACH + case EHOSTUNREACH: return DRWAV_NO_HOST; + #endif + #ifdef EALREADY + case EALREADY: return DRWAV_IN_PROGRESS; + #endif + #ifdef EINPROGRESS + case EINPROGRESS: return DRWAV_IN_PROGRESS; + #endif + #ifdef ESTALE + case ESTALE: return DRWAV_INVALID_FILE; + #endif + #ifdef EUCLEAN + case EUCLEAN: return DRWAV_ERROR; + #endif + #ifdef ENOTNAM + case ENOTNAM: return DRWAV_ERROR; + #endif + #ifdef ENAVAIL + case ENAVAIL: return 
DRWAV_ERROR; + #endif + #ifdef EISNAM + case EISNAM: return DRWAV_ERROR; + #endif + #ifdef EREMOTEIO + case EREMOTEIO: return DRWAV_IO_ERROR; + #endif + #ifdef EDQUOT + case EDQUOT: return DRWAV_NO_SPACE; + #endif + #ifdef ENOMEDIUM + case ENOMEDIUM: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef EMEDIUMTYPE + case EMEDIUMTYPE: return DRWAV_ERROR; + #endif + #ifdef ECANCELED + case ECANCELED: return DRWAV_CANCELLED; + #endif + #ifdef ENOKEY + case ENOKEY: return DRWAV_ERROR; + #endif + #ifdef EKEYEXPIRED + case EKEYEXPIRED: return DRWAV_ERROR; + #endif + #ifdef EKEYREVOKED + case EKEYREVOKED: return DRWAV_ERROR; + #endif + #ifdef EKEYREJECTED + case EKEYREJECTED: return DRWAV_ERROR; + #endif + #ifdef EOWNERDEAD + case EOWNERDEAD: return DRWAV_ERROR; + #endif + #ifdef ENOTRECOVERABLE + case ENOTRECOVERABLE: return DRWAV_ERROR; + #endif + #ifdef ERFKILL + case ERFKILL: return DRWAV_ERROR; + #endif + #ifdef EHWPOISON + case EHWPOISON: return DRWAV_ERROR; + #endif + default: return DRWAV_ERROR; + } +} + +static drwav_result drwav_fopen(FILE** ppFile, const char* pFilePath, const char* pOpenMode) +{ +#if _MSC_VER && _MSC_VER >= 1400 + errno_t err; +#endif + + if (ppFile != NULL) { + *ppFile = NULL; /* Safety. */ + } + + if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) { + return DRWAV_INVALID_ARGS; + } + +#if _MSC_VER && _MSC_VER >= 1400 + err = fopen_s(ppFile, pFilePath, pOpenMode); + if (err != 0) { + return drwav_result_from_errno(err); + } +#else +#if defined(_WIN32) || defined(__APPLE__) + *ppFile = fopen(pFilePath, pOpenMode); +#else + #if defined(_FILE_OFFSET_BITS) && _FILE_OFFSET_BITS == 64 && defined(_LARGEFILE64_SOURCE) + *ppFile = fopen64(pFilePath, pOpenMode); + #else + *ppFile = fopen(pFilePath, pOpenMode); + #endif +#endif + if (*ppFile == NULL) { + drwav_result result = drwav_result_from_errno(errno); + if (result == DRWAV_SUCCESS) { + result = DRWAV_ERROR; /* Just a safety check to make sure we never ever return success when pFile == NULL. 
*/ + } + + return result; + } +#endif + + return DRWAV_SUCCESS; +} + +/* +_wfopen() isn't always available in all compilation environments. + + * Windows only. + * MSVC seems to support it universally as far back as VC6 from what I can tell (haven't checked further back). + * MinGW-64 (both 32- and 64-bit) seems to support it. + * MinGW wraps it in !defined(__STRICT_ANSI__). + * OpenWatcom wraps it in !defined(_NO_EXT_KEYS). + +This can be reviewed as compatibility issues arise. The preference is to use _wfopen_s() and _wfopen() as opposed to the wcsrtombs() +fallback, so if you notice your compiler not detecting this properly I'm happy to look at adding support. +*/ +#if defined(_WIN32) + #if defined(_MSC_VER) || defined(__MINGW64__) || (!defined(__STRICT_ANSI__) && !defined(_NO_EXT_KEYS)) + #define DRWAV_HAS_WFOPEN + #endif +#endif + +static drwav_result drwav_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_t* pOpenMode, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (ppFile != NULL) { + *ppFile = NULL; /* Safety. */ + } + + if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) { + return DRWAV_INVALID_ARGS; + } + +#if defined(DRWAV_HAS_WFOPEN) + { + /* Use _wfopen() on Windows. */ + #if defined(_MSC_VER) && _MSC_VER >= 1400 + errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); + if (err != 0) { + return drwav_result_from_errno(err); + } + #else + *ppFile = _wfopen(pFilePath, pOpenMode); + if (*ppFile == NULL) { + return drwav_result_from_errno(errno); + } + #endif + (void)pAllocationCallbacks; + } +#else + /* + Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can + think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for + maintaining state. 
I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. + */ + { + mbstate_t mbs; + size_t lenMB; + const wchar_t* pFilePathTemp = pFilePath; + char* pFilePathMB = NULL; + char pOpenModeMB[32] = {0}; + + /* Get the length first. */ + DRWAV_ZERO_OBJECT(&mbs); + lenMB = wcsrtombs(NULL, &pFilePathTemp, 0, &mbs); + if (lenMB == (size_t)-1) { + return drwav_result_from_errno(errno); + } + + pFilePathMB = (char*)drwav__malloc_from_callbacks(lenMB + 1, pAllocationCallbacks); + if (pFilePathMB == NULL) { + return DRWAV_OUT_OF_MEMORY; + } + + pFilePathTemp = pFilePath; + DRWAV_ZERO_OBJECT(&mbs); + wcsrtombs(pFilePathMB, &pFilePathTemp, lenMB + 1, &mbs); + + /* The open mode should always consist of ASCII characters so we should be able to do a trivial conversion. */ + { + size_t i = 0; + for (;;) { + if (pOpenMode[i] == 0) { + pOpenModeMB[i] = '\0'; + break; + } + + pOpenModeMB[i] = (char)pOpenMode[i]; + i += 1; + } + } + + *ppFile = fopen(pFilePathMB, pOpenModeMB); + + drwav__free_from_callbacks(pFilePathMB, pAllocationCallbacks); + } + + if (*ppFile == NULL) { + return DRWAV_ERROR; + } +#endif + + return DRWAV_SUCCESS; +} + + +static size_t drwav__on_read_stdio(void* pUserData, void* pBufferOut, size_t bytesToRead) +{ + return fread(pBufferOut, 1, bytesToRead, (FILE*)pUserData); +} + +static size_t drwav__on_write_stdio(void* pUserData, const void* pData, size_t bytesToWrite) +{ + return fwrite(pData, 1, bytesToWrite, (FILE*)pUserData); +} + +static drwav_bool32 drwav__on_seek_stdio(void* pUserData, int offset, drwav_seek_origin origin) +{ + return fseek((FILE*)pUserData, offset, (origin == drwav_seek_origin_current) ? 
SEEK_CUR : SEEK_SET) == 0; +} + +DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_ex(pWav, filename, NULL, NULL, 0, pAllocationCallbacks); +} + + +static drwav_bool32 drwav_init_file__internal_FILE(drwav* pWav, FILE* pFile, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav_bool32 result; + + result = drwav_preinit(pWav, drwav__on_read_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + result = drwav_init__internal(pWav, onChunk, pChunkUserData, flags); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_fopen(&pFile, filename, "rb") != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. */ + return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_ex_w(pWav, filename, NULL, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_wfopen(&pFile, filename, L"rb", pAllocationCallbacks) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. 
*/ + return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks); +} + + +static drwav_bool32 drwav_init_file_write__internal_FILE(drwav* pWav, FILE* pFile, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav_bool32 result; + + result = drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + result = drwav_init_write__internal(pWav, pFormat, totalSampleCount); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init_file_write__internal(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_fopen(&pFile, filename, "wb") != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. */ + return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks); +} + +static drwav_bool32 drwav_init_file_write_w__internal(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_wfopen(&pFile, filename, L"wb", pAllocationCallbacks) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. 
*/ + return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_file_write_sequential(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write_w__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write_w__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const 
drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_file_write_sequential_w(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} +#endif /* DR_WAV_NO_STDIO */ + + +/* Memory read callback. pUserData is the drwav object. Copies at most the bytes left in pWav->memoryStream and returns the number of bytes copied. */ +static size_t drwav__on_read_memory(void* pUserData, void* pBufferOut, size_t bytesToRead) +{ + drwav* pWav = (drwav*)pUserData; + size_t bytesRemaining; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->memoryStream.dataSize >= pWav->memoryStream.currentReadPos); + + bytesRemaining = pWav->memoryStream.dataSize - pWav->memoryStream.currentReadPos; + if (bytesToRead > bytesRemaining) { + bytesToRead = bytesRemaining; + } + + if (bytesToRead > 0) { + DRWAV_COPY_MEMORY(pBufferOut, pWav->memoryStream.data + pWav->memoryStream.currentReadPos, bytesToRead); + pWav->memoryStream.currentReadPos += bytesToRead; + } + + return bytesToRead; +} + +/* Memory seek callback for reading. pUserData is the drwav object. Returns DRWAV_FALSE, leaving the read position untouched, when the target is outside [0, dataSize]. */ +static drwav_bool32 drwav__on_seek_memory(void* pUserData, int offset, drwav_seek_origin origin) +{ + drwav* pWav = (drwav*)pUserData; + DRWAV_ASSERT(pWav != NULL); + + if (origin == drwav_seek_origin_current) { + if (offset > 0) { + if (pWav->memoryStream.currentReadPos + offset > pWav->memoryStream.dataSize) { + return DRWAV_FALSE; /* Trying to seek too far forward. */ + } + } else { + if (pWav->memoryStream.currentReadPos < (size_t)-offset) { + return DRWAV_FALSE; /* Trying to seek too far backwards. */ + } + } + + /* This will never underflow thanks to the clamps above. */ + pWav->memoryStream.currentReadPos += offset; + } else { + /* Reject negative offsets explicitly. Casting to a 32-bit unsigned value lets them pass the range check once dataSize reaches 2GB. */ + if (offset >= 0 && (size_t)offset <= pWav->memoryStream.dataSize) { + pWav->memoryStream.currentReadPos = (size_t)offset; + } else { + return DRWAV_FALSE; /* Negative offset, or trying to seek too far forward.
*/ + } + } + + return DRWAV_TRUE; +} + +static size_t drwav__on_write_memory(void* pUserData, const void* pDataIn, size_t bytesToWrite) +{ + drwav* pWav = (drwav*)pUserData; + size_t bytesRemaining; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->memoryStreamWrite.dataCapacity >= pWav->memoryStreamWrite.currentWritePos); + + bytesRemaining = pWav->memoryStreamWrite.dataCapacity - pWav->memoryStreamWrite.currentWritePos; + if (bytesRemaining < bytesToWrite) { + /* Need to reallocate. */ + void* pNewData; + size_t newDataCapacity = (pWav->memoryStreamWrite.dataCapacity == 0) ? 256 : pWav->memoryStreamWrite.dataCapacity * 2; + + /* If doubling wasn't enough, just make it the minimum required size to write the data. */ + if ((newDataCapacity - pWav->memoryStreamWrite.currentWritePos) < bytesToWrite) { + newDataCapacity = pWav->memoryStreamWrite.currentWritePos + bytesToWrite; + } + + pNewData = drwav__realloc_from_callbacks(*pWav->memoryStreamWrite.ppData, newDataCapacity, pWav->memoryStreamWrite.dataCapacity, &pWav->allocationCallbacks); + if (pNewData == NULL) { + return 0; + } + + *pWav->memoryStreamWrite.ppData = pNewData; + pWav->memoryStreamWrite.dataCapacity = newDataCapacity; + } + + DRWAV_COPY_MEMORY(((drwav_uint8*)(*pWav->memoryStreamWrite.ppData)) + pWav->memoryStreamWrite.currentWritePos, pDataIn, bytesToWrite); + + pWav->memoryStreamWrite.currentWritePos += bytesToWrite; + if (pWav->memoryStreamWrite.dataSize < pWav->memoryStreamWrite.currentWritePos) { + pWav->memoryStreamWrite.dataSize = pWav->memoryStreamWrite.currentWritePos; + } + + *pWav->memoryStreamWrite.pDataSize = pWav->memoryStreamWrite.dataSize; + + return bytesToWrite; +} + +static drwav_bool32 drwav__on_seek_memory_write(void* pUserData, int offset, drwav_seek_origin origin) +{ + drwav* pWav = (drwav*)pUserData; + DRWAV_ASSERT(pWav != NULL); + + if (origin == drwav_seek_origin_current) { + if (offset > 0) { + if (pWav->memoryStreamWrite.currentWritePos + offset > 
pWav->memoryStreamWrite.dataSize) { + offset = (int)(pWav->memoryStreamWrite.dataSize - pWav->memoryStreamWrite.currentWritePos); /* Trying to seek too far forward. */ + } + } else { + if (pWav->memoryStreamWrite.currentWritePos < (size_t)-offset) { + offset = -(int)pWav->memoryStreamWrite.currentWritePos; /* Trying to seek too far backwards. */ + } + } + + /* This will never underflow thanks to the clamps above. */ + pWav->memoryStreamWrite.currentWritePos += offset; + } else { + /* Absolute seek. Anything outside [0, dataSize] is clamped to the end of the written data. Negative offsets are tested explicitly because a 32-bit unsigned cast lets them through once dataSize reaches 2GB. */ + if (offset >= 0 && (size_t)offset <= pWav->memoryStreamWrite.dataSize) { + pWav->memoryStreamWrite.currentWritePos = (size_t)offset; + } else { + pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; /* Out of range. Clamp to the end. */ + } + } + + return DRWAV_TRUE; +} + +/* Initializes pWav for reading from a block of memory. Same as drwav_init_memory_ex() with no chunk callback and no flags. */ +DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_ex(pWav, data, dataSize, NULL, NULL, 0, pAllocationCallbacks); +} + +/* Initializes pWav for reading from a block of memory. Fails when data is NULL or dataSize is 0. The data pointer is stored, not copied. */ +DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (data == NULL || dataSize == 0) { + return DRWAV_FALSE; + } + + if (!drwav_preinit(pWav, drwav__on_read_memory, drwav__on_seek_memory, pWav, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + pWav->memoryStream.data = (const drwav_uint8*)data; + pWav->memoryStream.dataSize = dataSize; + pWav->memoryStream.currentReadPos = 0; + + return drwav_init__internal(pWav, onChunk, pChunkUserData, flags); +} + + +static drwav_bool32 drwav_init_memory_write__internal(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (ppData == NULL || pDataSize == NULL) { + return DRWAV_FALSE; + } + + *ppData = NULL; /* Important
because we're using realloc()! */ + *pDataSize = 0; + + if (!drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_memory, drwav__on_seek_memory_write, pWav, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + pWav->memoryStreamWrite.ppData = ppData; + pWav->memoryStreamWrite.pDataSize = pDataSize; + pWav->memoryStreamWrite.dataSize = 0; + pWav->memoryStreamWrite.dataCapacity = 0; + pWav->memoryStreamWrite.currentWritePos = 0; + + return drwav_init_write__internal(pWav, pFormat, totalSampleCount); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_memory_write_sequential(pWav, ppData, pDataSize, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} + + + +DRWAV_API drwav_result drwav_uninit(drwav* pWav) +{ + drwav_result result = DRWAV_SUCCESS; + + if (pWav == NULL) { + return DRWAV_INVALID_ARGS; + } + + /* + If the drwav object was opened in write mode we'll need to finalize a few things: + - Make sure the "data" chunk is aligned to 16-bits for RIFF containers, or 64 bits for W64 containers. 
+ - Set the size of the "data" chunk. + */ + if (pWav->onWrite != NULL) { + drwav_uint32 paddingSize = 0; + + /* Padding. Do not adjust pWav->dataChunkDataSize - this should not include the padding. */ + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + paddingSize = drwav__chunk_padding_size_riff(pWav->dataChunkDataSize); + } else { + paddingSize = drwav__chunk_padding_size_w64(pWav->dataChunkDataSize); + } + + if (paddingSize > 0) { + drwav_uint64 paddingData = 0; + drwav__write(pWav, &paddingData, paddingSize); /* Byte order does not matter for this. */ + } + + /* + Chunk sizes. When using sequential mode, these will have been filled in at initialization time. We only need + to do this when using non-sequential mode. + */ + if (pWav->onSeek && !pWav->isSequentialWrite) { + if (pWav->container == drwav_container_riff) { + /* The "RIFF" chunk size. */ + if (pWav->onSeek(pWav->pUserData, 4, drwav_seek_origin_start)) { + drwav_uint32 riffChunkSize = drwav__riff_chunk_size_riff(pWav->dataChunkDataSize); + drwav__write_u32ne_to_le(pWav, riffChunkSize); + } + + /* the "data" chunk size. */ + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 4, drwav_seek_origin_start)) { + drwav_uint32 dataChunkSize = drwav__data_chunk_size_riff(pWav->dataChunkDataSize); + drwav__write_u32ne_to_le(pWav, dataChunkSize); + } + } else if (pWav->container == drwav_container_w64) { + /* The "RIFF" chunk size. */ + if (pWav->onSeek(pWav->pUserData, 16, drwav_seek_origin_start)) { + drwav_uint64 riffChunkSize = drwav__riff_chunk_size_w64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, riffChunkSize); + } + + /* The "data" chunk size. 
*/ + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 16, drwav_seek_origin_start)) { + drwav_uint64 dataChunkSize = drwav__data_chunk_size_w64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, dataChunkSize); + } + } else if (pWav->container == drwav_container_rf64) { + /* We only need to update the ds64 chunk. The "RIFF" and "data" chunks always have their sizes set to 0xFFFFFFFF for RF64. */ + int ds64BodyPos = 12 + 8; + + /* The "RIFF" chunk size. */ + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, drwav_seek_origin_start)) { + drwav_uint64 riffChunkSize = drwav__riff_chunk_size_rf64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, riffChunkSize); + } + + /* The "data" chunk size. */ + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, drwav_seek_origin_start)) { + drwav_uint64 dataChunkSize = drwav__data_chunk_size_rf64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, dataChunkSize); + } + } + } + + /* Validation for sequential mode. The number of bytes written must match the size that was declared up front. */ + if (pWav->isSequentialWrite) { + if (pWav->dataChunkDataSize != pWav->dataChunkDataSizeTargetWrite) { + result = DRWAV_INVALID_FILE; + } + } + } + +#ifndef DR_WAV_NO_STDIO + /* + If the drwav object was initialized through one of the stdio-based drwav_init_file*() functions we own the FILE handle and need to close + it. We can tell by checking whether the onRead or onWrite callback is the stdio implementation. + */ + if (pWav->onRead == drwav__on_read_stdio || pWav->onWrite == drwav__on_write_stdio) { + fclose((FILE*)pWav->pUserData); + } +#endif + + return result; +} + + + +/* + Reads up to bytesToRead bytes of raw data, clamped to pWav->bytesRemaining. When pBufferOut is NULL the bytes are skipped instead of being + copied. Returns the number of bytes consumed, which is also subtracted from pWav->bytesRemaining. + */ +DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut) +{ + size_t bytesRead; + + if (pWav == NULL || bytesToRead == 0) { + return 0; + } + + if (bytesToRead > pWav->bytesRemaining) { + bytesToRead = (size_t)pWav->bytesRemaining; + } + + if (pBufferOut != NULL) { + bytesRead = pWav->onRead(pWav->pUserData, pBufferOut, bytesToRead); + } else { + /* We need to seek.
If we fail, we need to read-and-discard to make sure we get a good byte count. */ + bytesRead = 0; + while (bytesRead < bytesToRead) { + size_t bytesToSeek = (bytesToRead - bytesRead); + if (bytesToSeek > 0x7FFFFFFF) { + bytesToSeek = 0x7FFFFFFF; + } + + if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, drwav_seek_origin_current) == DRWAV_FALSE) { + break; + } + + bytesRead += bytesToSeek; + } + + /* When we get here we may need to read-and-discard some data. */ + while (bytesRead < bytesToRead) { + drwav_uint8 buffer[4096]; + size_t bytesSeeked; + size_t bytesToSeek = (bytesToRead - bytesRead); + if (bytesToSeek > sizeof(buffer)) { + bytesToSeek = sizeof(buffer); + } + + bytesSeeked = pWav->onRead(pWav->pUserData, buffer, bytesToSeek); + bytesRead += bytesSeeked; + + if (bytesSeeked < bytesToSeek) { + break; /* Reached the end. */ + } + } + } + + pWav->bytesRemaining -= bytesRead; + return bytesRead; +} + + + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut) +{ + drwav_uint32 bytesPerFrame; + drwav_uint64 bytesToRead; /* Intentionally uint64 instead of size_t so we can do a check that we're not reading too much on 32-bit builds. */ + + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + /* Cannot use this function for compressed formats. */ + if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) { + return 0; + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + /* Don't try to read more samples than can potentially fit in the output buffer. */ + bytesToRead = framesToRead * bytesPerFrame; + if (bytesToRead > DRWAV_SIZE_MAX) { + bytesToRead = (DRWAV_SIZE_MAX / bytesPerFrame) * bytesPerFrame; /* Round the number of bytes to read to a clean frame boundary. */ + } + + /* + Doing an explicit check here just to make it clear that we don't want to be attempt to read anything if there's no bytes to read. 
There + *could* be a time where it evaluates to 0 due to overflowing. + */ + if (bytesToRead == 0) { + return 0; + } + + return drwav_read_raw(pWav, (size_t)bytesToRead, pBufferOut) / bytesPerFrame; +} + +/* Reads PCM frames like drwav_read_pcm_frames_le() and then byte swaps the samples that were read, in place. Returns the number of frames read. */ +DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut); + + /* + Only swap when something was actually read. drwav_read_pcm_frames_le() returns 0 for a NULL pWav, so this check also keeps us from + dereferencing a NULL pWav below. + */ + if (pBufferOut != NULL && framesRead > 0) { + drwav__bswap_samples(pBufferOut, framesRead*pWav->channels, drwav_get_bytes_per_pcm_frame(pWav)/pWav->channels, pWav->translatedFormatTag); + } + + return framesRead; +} + +/* Reads PCM frames through the little-endian or big-endian path depending on drwav__is_little_endian(). */ +DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut) +{ + if (drwav__is_little_endian()) { + return drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut); + } else { + return drwav_read_pcm_frames_be(pWav, framesToRead, pBufferOut); + } +} + + + +/* Seeks back to pWav->dataChunkDataPos and resets the read state. Returns DRWAV_FALSE for invalid arguments, in write mode, or when the seek fails. */ +DRWAV_API drwav_bool32 drwav_seek_to_first_pcm_frame(drwav* pWav) +{ + /* Same argument validation as drwav_seek_to_pcm_frame(). */ + if (pWav == NULL || pWav->onSeek == NULL) { + return DRWAV_FALSE; + } + + if (pWav->onWrite != NULL) { + return DRWAV_FALSE; /* No seeking in write mode. */ + } + + if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, drwav_seek_origin_start)) { + return DRWAV_FALSE; + } + + if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) { + pWav->compressed.iCurrentPCMFrame = 0; + + /* Cached data needs to be cleared for compressed formats. */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + DRWAV_ZERO_OBJECT(&pWav->msadpcm); + } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + DRWAV_ZERO_OBJECT(&pWav->ima); + } else { + DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */ + } + } + + pWav->bytesRemaining = pWav->dataChunkDataSize; + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex) +{ + /* Seeking should be compatible with wave files > 2GB.
*/ + + if (pWav == NULL || pWav->onSeek == NULL) { + return DRWAV_FALSE; + } + + /* No seeking in write mode. */ + if (pWav->onWrite != NULL) { + return DRWAV_FALSE; + } + + /* If there are no samples, just return DRWAV_TRUE without doing anything. */ + if (pWav->totalPCMFrameCount == 0) { + return DRWAV_TRUE; + } + + /* Make sure the sample is clamped. */ + if (targetFrameIndex >= pWav->totalPCMFrameCount) { + targetFrameIndex = pWav->totalPCMFrameCount - 1; + } + + /* + For compressed formats we just use a slow generic seek. If we are seeking forward we just seek forward. If we are going backwards we need + to seek back to the start. + */ + if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) { + /* TODO: This can be optimized. */ + + /* + If we're seeking forward it's simple - just keep reading samples until we hit the sample we're requesting. If we're seeking backwards, + we first need to seek back to the start and then just do the same thing as a forward seek. + */ + if (targetFrameIndex < pWav->compressed.iCurrentPCMFrame) { + if (!drwav_seek_to_first_pcm_frame(pWav)) { + return DRWAV_FALSE; + } + } + + if (targetFrameIndex > pWav->compressed.iCurrentPCMFrame) { + drwav_uint64 offsetInFrames = targetFrameIndex - pWav->compressed.iCurrentPCMFrame; + + drwav_int16 devnull[2048]; + while (offsetInFrames > 0) { + drwav_uint64 framesRead = 0; + drwav_uint64 framesToRead = offsetInFrames; + if (framesToRead > drwav_countof(devnull)/pWav->channels) { + framesToRead = drwav_countof(devnull)/pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + framesRead = drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, devnull); + } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + framesRead = drwav_read_pcm_frames_s16__ima(pWav, framesToRead, devnull); + } else { + DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. 
*/ + } + + if (framesRead != framesToRead) { + return DRWAV_FALSE; + } + + offsetInFrames -= framesRead; + } + } + } else { + drwav_uint64 totalSizeInBytes; + drwav_uint64 currentBytePos; + drwav_uint64 targetBytePos; + drwav_uint64 offset; + + totalSizeInBytes = pWav->totalPCMFrameCount * drwav_get_bytes_per_pcm_frame(pWav); + DRWAV_ASSERT(totalSizeInBytes >= pWav->bytesRemaining); + + currentBytePos = totalSizeInBytes - pWav->bytesRemaining; + targetBytePos = targetFrameIndex * drwav_get_bytes_per_pcm_frame(pWav); + + if (currentBytePos < targetBytePos) { + /* Offset forwards. */ + offset = (targetBytePos - currentBytePos); + } else { + /* Offset backwards. */ + if (!drwav_seek_to_first_pcm_frame(pWav)) { + return DRWAV_FALSE; + } + offset = targetBytePos; + } + + while (offset > 0) { + int offset32 = ((offset > INT_MAX) ? INT_MAX : (int)offset); + if (!pWav->onSeek(pWav->pUserData, offset32, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + + pWav->bytesRemaining -= offset32; + offset -= offset32; + } + } + + return DRWAV_TRUE; +} + + +DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData) +{ + size_t bytesWritten; + + if (pWav == NULL || bytesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesWritten = pWav->onWrite(pWav->pUserData, pData, bytesToWrite); + pWav->dataChunkDataSize += bytesWritten; + + return bytesWritten; +} + + +DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + drwav_uint64 bytesToWrite; + drwav_uint64 bytesWritten; + const drwav_uint8* pRunningData; + + if (pWav == NULL || framesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8); + if (bytesToWrite > DRWAV_SIZE_MAX) { + return 0; + } + + bytesWritten = 0; + pRunningData = (const drwav_uint8*)pData; + + while (bytesToWrite > 0) { + size_t bytesJustWritten; + drwav_uint64 bytesToWriteThisIteration; + + 
bytesToWriteThisIteration = bytesToWrite; + DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. */ + + bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, pRunningData); + if (bytesJustWritten == 0) { + break; + } + + bytesToWrite -= bytesJustWritten; + bytesWritten += bytesJustWritten; + pRunningData += bytesJustWritten; + } + + return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels; +} + +DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + drwav_uint64 bytesToWrite; + drwav_uint64 bytesWritten; + drwav_uint32 bytesPerSample; + const drwav_uint8* pRunningData; + + if (pWav == NULL || framesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8); + if (bytesToWrite > DRWAV_SIZE_MAX) { + return 0; + } + + bytesWritten = 0; + pRunningData = (const drwav_uint8*)pData; + + bytesPerSample = drwav_get_bytes_per_pcm_frame(pWav) / pWav->channels; + + while (bytesToWrite > 0) { + drwav_uint8 temp[4096]; + drwav_uint32 sampleCount; + size_t bytesJustWritten; + drwav_uint64 bytesToWriteThisIteration; + + bytesToWriteThisIteration = bytesToWrite; + DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. */ + + /* + WAV files are always little-endian. We need to byte swap on big-endian architectures. Since our input buffer is read-only we need + to use an intermediary buffer for the conversion. 
+ */ + sampleCount = sizeof(temp)/bytesPerSample; + + if (bytesToWriteThisIteration > ((drwav_uint64)sampleCount)*bytesPerSample) { + bytesToWriteThisIteration = ((drwav_uint64)sampleCount)*bytesPerSample; + } + + DRWAV_COPY_MEMORY(temp, pRunningData, (size_t)bytesToWriteThisIteration); + drwav__bswap_samples(temp, sampleCount, bytesPerSample, pWav->translatedFormatTag); + + bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, temp); + if (bytesJustWritten == 0) { + break; + } + + bytesToWrite -= bytesJustWritten; + bytesWritten += bytesJustWritten; + pRunningData += bytesJustWritten; + } + + return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels; +} + +DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + if (drwav__is_little_endian()) { + return drwav_write_pcm_frames_le(pWav, framesToWrite, pData); + } else { + return drwav_write_pcm_frames_be(pWav, framesToWrite, pData); + } +} + + +static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead = 0; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(framesToRead > 0); + + /* TODO: Lots of room for optimization here. */ + + while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + /* If there are no cached frames we need to load a new block. */ + if (pWav->msadpcm.cachedFrameCount == 0 && pWav->msadpcm.bytesRemainingInBlock == 0) { + if (pWav->channels == 1) { + /* Mono. 
*/ + drwav_uint8 header[7]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + pWav->msadpcm.predictor[0] = header[0]; + pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 1); + pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 3); + pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 5); + pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][0]; + pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.cachedFrameCount = 2; + } else { + /* Stereo. */ + drwav_uint8 header[14]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + pWav->msadpcm.predictor[0] = header[0]; + pWav->msadpcm.predictor[1] = header[1]; + pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 2); + pWav->msadpcm.delta[1] = drwav__bytes_to_s16(header + 4); + pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 6); + pWav->msadpcm.prevFrames[1][1] = (drwav_int32)drwav__bytes_to_s16(header + 8); + pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 10); + pWav->msadpcm.prevFrames[1][0] = (drwav_int32)drwav__bytes_to_s16(header + 12); + + pWav->msadpcm.cachedFrames[0] = pWav->msadpcm.prevFrames[0][0]; + pWav->msadpcm.cachedFrames[1] = pWav->msadpcm.prevFrames[1][0]; + pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1]; + pWav->msadpcm.cachedFrameCount = 2; + } + } + + /* Output anything that's cached. 
*/ + while (framesToRead > 0 && pWav->msadpcm.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + if (pBufferOut != NULL) { + drwav_uint32 iSample = 0; + for (iSample = 0; iSample < pWav->channels; iSample += 1) { + pBufferOut[iSample] = (drwav_int16)pWav->msadpcm.cachedFrames[(drwav_countof(pWav->msadpcm.cachedFrames) - (pWav->msadpcm.cachedFrameCount*pWav->channels)) + iSample]; + } + + pBufferOut += pWav->channels; + } + + framesToRead -= 1; + totalFramesRead += 1; + pWav->compressed.iCurrentPCMFrame += 1; + pWav->msadpcm.cachedFrameCount -= 1; + } + + if (framesToRead == 0) { + return totalFramesRead; + } + + + /* + If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next + loop iteration which will trigger the loading of a new block. + */ + if (pWav->msadpcm.cachedFrameCount == 0) { + if (pWav->msadpcm.bytesRemainingInBlock == 0) { + continue; + } else { + static drwav_int32 adaptationTable[] = { + 230, 230, 230, 230, 307, 409, 512, 614, + 768, 614, 512, 409, 307, 230, 230, 230 + }; + static drwav_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; + static drwav_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; + + drwav_uint8 nibbles; + drwav_int32 nibble0; + drwav_int32 nibble1; + + if (pWav->onRead(pWav->pUserData, &nibbles, 1) != 1) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock -= 1; + + /* TODO: Optimize away these if statements. */ + nibble0 = ((nibbles & 0xF0) >> 4); if ((nibbles & 0x80)) { nibble0 |= 0xFFFFFFF0UL; } + nibble1 = ((nibbles & 0x0F) >> 0); if ((nibbles & 0x08)) { nibble1 |= 0xFFFFFFF0UL; } + + if (pWav->channels == 1) { + /* Mono. 
*/ + drwav_int32 newSample0; + drwav_int32 newSample1; + + newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample0 += nibble0 * pWav->msadpcm.delta[0]; + newSample0 = drwav_clamp(newSample0, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample0; + + + newSample1 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample1 += nibble1 * pWav->msadpcm.delta[0]; + newSample1 = drwav_clamp(newSample1, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample1; + + + pWav->msadpcm.cachedFrames[2] = newSample0; + pWav->msadpcm.cachedFrames[3] = newSample1; + pWav->msadpcm.cachedFrameCount = 2; + } else { + /* Stereo. */ + drwav_int32 newSample0; + drwav_int32 newSample1; + + /* Left. 
*/ + newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample0 += nibble0 * pWav->msadpcm.delta[0]; + newSample0 = drwav_clamp(newSample0, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample0; + + + /* Right. */ + newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8; + newSample1 += nibble1 * pWav->msadpcm.delta[1]; + newSample1 = drwav_clamp(newSample1, -32768, 32767); + + pWav->msadpcm.delta[1] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[1]) >> 8; + if (pWav->msadpcm.delta[1] < 16) { + pWav->msadpcm.delta[1] = 16; + } + + pWav->msadpcm.prevFrames[1][0] = pWav->msadpcm.prevFrames[1][1]; + pWav->msadpcm.prevFrames[1][1] = newSample1; + + pWav->msadpcm.cachedFrames[2] = newSample0; + pWav->msadpcm.cachedFrames[3] = newSample1; + pWav->msadpcm.cachedFrameCount = 1; + } + } + } + } + + return totalFramesRead; +} + + +static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead = 0; + drwav_uint32 iChannel; + + static drwav_int32 indexTable[16] = { + -1, -1, -1, -1, 2, 4, 6, 8, + -1, -1, -1, -1, 2, 4, 6, 8 + }; + + static drwav_int32 stepTable[89] = { + 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, + 19, 21, 23, 25, 28, 31, 34, 37, 41, 45, + 50, 55, 60, 66, 73, 80, 88, 97, 107, 118, + 130, 143, 157, 173, 190, 209, 230, 253, 279, 307, + 337, 371, 408, 449, 494, 544, 598, 658, 724, 796, + 876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066, + 2272, 2499, 2749, 3024, 3327, 
3660, 4026, 4428, 4871, 5358, + 5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899, + 15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767 + }; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(framesToRead > 0); + + /* TODO: Lots of room for optimization here. */ + + while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + /* If there are no cached samples we need to load a new block. */ + if (pWav->ima.cachedFrameCount == 0 && pWav->ima.bytesRemainingInBlock == 0) { + if (pWav->channels == 1) { + /* Mono. */ + drwav_uint8 header[4]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + if (header[2] >= drwav_countof(stepTable)) { + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current); + pWav->ima.bytesRemainingInBlock = 0; + return totalFramesRead; /* Invalid data. */ + } + + pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0); + pWav->ima.stepIndex[0] = header[2]; + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[0]; + pWav->ima.cachedFrameCount = 1; + } else { + /* Stereo. */ + drwav_uint8 header[8]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + if (header[2] >= drwav_countof(stepTable) || header[6] >= drwav_countof(stepTable)) { + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current); + pWav->ima.bytesRemainingInBlock = 0; + return totalFramesRead; /* Invalid data. 
*/ + } + + pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0); + pWav->ima.stepIndex[0] = header[2]; + pWav->ima.predictor[1] = drwav__bytes_to_s16(header + 4); + pWav->ima.stepIndex[1] = header[6]; + + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 2] = pWav->ima.predictor[0]; + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[1]; + pWav->ima.cachedFrameCount = 1; + } + } + + /* Output anything that's cached. */ + while (framesToRead > 0 && pWav->ima.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + if (pBufferOut != NULL) { + drwav_uint32 iSample; + for (iSample = 0; iSample < pWav->channels; iSample += 1) { + pBufferOut[iSample] = (drwav_int16)pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + iSample]; + } + pBufferOut += pWav->channels; + } + + framesToRead -= 1; + totalFramesRead += 1; + pWav->compressed.iCurrentPCMFrame += 1; + pWav->ima.cachedFrameCount -= 1; + } + + if (framesToRead == 0) { + return totalFramesRead; + } + + /* + If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next + loop iteration which will trigger the loading of a new block. + */ + if (pWav->ima.cachedFrameCount == 0) { + if (pWav->ima.bytesRemainingInBlock == 0) { + continue; + } else { + /* + From what I can tell with stereo streams, it looks like every 4 bytes (8 samples) is for one channel. So it goes 4 bytes for the + left channel, 4 bytes for the right channel. 
+ */ + pWav->ima.cachedFrameCount = 8; + for (iChannel = 0; iChannel < pWav->channels; ++iChannel) { + drwav_uint32 iByte; + drwav_uint8 nibbles[4]; + if (pWav->onRead(pWav->pUserData, &nibbles, 4) != 4) { + pWav->ima.cachedFrameCount = 0; + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock -= 4; + + for (iByte = 0; iByte < 4; ++iByte) { + drwav_uint8 nibble0 = ((nibbles[iByte] & 0x0F) >> 0); + drwav_uint8 nibble1 = ((nibbles[iByte] & 0xF0) >> 4); + + drwav_int32 step = stepTable[pWav->ima.stepIndex[iChannel]]; + drwav_int32 predictor = pWav->ima.predictor[iChannel]; + + drwav_int32 diff = step >> 3; + if (nibble0 & 1) diff += step >> 2; + if (nibble0 & 2) diff += step >> 1; + if (nibble0 & 4) diff += step; + if (nibble0 & 8) diff = -diff; + + predictor = drwav_clamp(predictor + diff, -32768, 32767); + pWav->ima.predictor[iChannel] = predictor; + pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble0], 0, (drwav_int32)drwav_countof(stepTable)-1); + pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+0)*pWav->channels + iChannel] = predictor; + + + step = stepTable[pWav->ima.stepIndex[iChannel]]; + predictor = pWav->ima.predictor[iChannel]; + + diff = step >> 3; + if (nibble1 & 1) diff += step >> 2; + if (nibble1 & 2) diff += step >> 1; + if (nibble1 & 4) diff += step; + if (nibble1 & 8) diff = -diff; + + predictor = drwav_clamp(predictor + diff, -32768, 32767); + pWav->ima.predictor[iChannel] = predictor; + pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble1], 0, (drwav_int32)drwav_countof(stepTable)-1); + pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+1)*pWav->channels + iChannel] = predictor; + } + } + } + } + } + + return totalFramesRead; +} + + +#ifndef DR_WAV_NO_CONVERSION_API +static unsigned short g_drwavAlawTable[256] = { + 
0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580, + 0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0, + 0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600, + 0xD500, 0xD700, 0xD100, 0xD300, 0xDD00, 0xDF00, 0xD900, 0xDB00, 0xC500, 0xC700, 0xC100, 0xC300, 0xCD00, 0xCF00, 0xC900, 0xCB00, + 0xFEA8, 0xFEB8, 0xFE88, 0xFE98, 0xFEE8, 0xFEF8, 0xFEC8, 0xFED8, 0xFE28, 0xFE38, 0xFE08, 0xFE18, 0xFE68, 0xFE78, 0xFE48, 0xFE58, + 0xFFA8, 0xFFB8, 0xFF88, 0xFF98, 0xFFE8, 0xFFF8, 0xFFC8, 0xFFD8, 0xFF28, 0xFF38, 0xFF08, 0xFF18, 0xFF68, 0xFF78, 0xFF48, 0xFF58, + 0xFAA0, 0xFAE0, 0xFA20, 0xFA60, 0xFBA0, 0xFBE0, 0xFB20, 0xFB60, 0xF8A0, 0xF8E0, 0xF820, 0xF860, 0xF9A0, 0xF9E0, 0xF920, 0xF960, + 0xFD50, 0xFD70, 0xFD10, 0xFD30, 0xFDD0, 0xFDF0, 0xFD90, 0xFDB0, 0xFC50, 0xFC70, 0xFC10, 0xFC30, 0xFCD0, 0xFCF0, 0xFC90, 0xFCB0, + 0x1580, 0x1480, 0x1780, 0x1680, 0x1180, 0x1080, 0x1380, 0x1280, 0x1D80, 0x1C80, 0x1F80, 0x1E80, 0x1980, 0x1880, 0x1B80, 0x1A80, + 0x0AC0, 0x0A40, 0x0BC0, 0x0B40, 0x08C0, 0x0840, 0x09C0, 0x0940, 0x0EC0, 0x0E40, 0x0FC0, 0x0F40, 0x0CC0, 0x0C40, 0x0DC0, 0x0D40, + 0x5600, 0x5200, 0x5E00, 0x5A00, 0x4600, 0x4200, 0x4E00, 0x4A00, 0x7600, 0x7200, 0x7E00, 0x7A00, 0x6600, 0x6200, 0x6E00, 0x6A00, + 0x2B00, 0x2900, 0x2F00, 0x2D00, 0x2300, 0x2100, 0x2700, 0x2500, 0x3B00, 0x3900, 0x3F00, 0x3D00, 0x3300, 0x3100, 0x3700, 0x3500, + 0x0158, 0x0148, 0x0178, 0x0168, 0x0118, 0x0108, 0x0138, 0x0128, 0x01D8, 0x01C8, 0x01F8, 0x01E8, 0x0198, 0x0188, 0x01B8, 0x01A8, + 0x0058, 0x0048, 0x0078, 0x0068, 0x0018, 0x0008, 0x0038, 0x0028, 0x00D8, 0x00C8, 0x00F8, 0x00E8, 0x0098, 0x0088, 0x00B8, 0x00A8, + 0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0, + 0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 
0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350 +}; + +static unsigned short g_drwavMulawTable[256] = { + 0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84, + 0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84, + 0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004, + 0xF0C4, 0xF144, 0xF1C4, 0xF244, 0xF2C4, 0xF344, 0xF3C4, 0xF444, 0xF4C4, 0xF544, 0xF5C4, 0xF644, 0xF6C4, 0xF744, 0xF7C4, 0xF844, + 0xF8A4, 0xF8E4, 0xF924, 0xF964, 0xF9A4, 0xF9E4, 0xFA24, 0xFA64, 0xFAA4, 0xFAE4, 0xFB24, 0xFB64, 0xFBA4, 0xFBE4, 0xFC24, 0xFC64, + 0xFC94, 0xFCB4, 0xFCD4, 0xFCF4, 0xFD14, 0xFD34, 0xFD54, 0xFD74, 0xFD94, 0xFDB4, 0xFDD4, 0xFDF4, 0xFE14, 0xFE34, 0xFE54, 0xFE74, + 0xFE8C, 0xFE9C, 0xFEAC, 0xFEBC, 0xFECC, 0xFEDC, 0xFEEC, 0xFEFC, 0xFF0C, 0xFF1C, 0xFF2C, 0xFF3C, 0xFF4C, 0xFF5C, 0xFF6C, 0xFF7C, + 0xFF88, 0xFF90, 0xFF98, 0xFFA0, 0xFFA8, 0xFFB0, 0xFFB8, 0xFFC0, 0xFFC8, 0xFFD0, 0xFFD8, 0xFFE0, 0xFFE8, 0xFFF0, 0xFFF8, 0x0000, + 0x7D7C, 0x797C, 0x757C, 0x717C, 0x6D7C, 0x697C, 0x657C, 0x617C, 0x5D7C, 0x597C, 0x557C, 0x517C, 0x4D7C, 0x497C, 0x457C, 0x417C, + 0x3E7C, 0x3C7C, 0x3A7C, 0x387C, 0x367C, 0x347C, 0x327C, 0x307C, 0x2E7C, 0x2C7C, 0x2A7C, 0x287C, 0x267C, 0x247C, 0x227C, 0x207C, + 0x1EFC, 0x1DFC, 0x1CFC, 0x1BFC, 0x1AFC, 0x19FC, 0x18FC, 0x17FC, 0x16FC, 0x15FC, 0x14FC, 0x13FC, 0x12FC, 0x11FC, 0x10FC, 0x0FFC, + 0x0F3C, 0x0EBC, 0x0E3C, 0x0DBC, 0x0D3C, 0x0CBC, 0x0C3C, 0x0BBC, 0x0B3C, 0x0ABC, 0x0A3C, 0x09BC, 0x093C, 0x08BC, 0x083C, 0x07BC, + 0x075C, 0x071C, 0x06DC, 0x069C, 0x065C, 0x061C, 0x05DC, 0x059C, 0x055C, 0x051C, 0x04DC, 0x049C, 0x045C, 0x041C, 0x03DC, 0x039C, + 0x036C, 0x034C, 0x032C, 0x030C, 0x02EC, 0x02CC, 0x02AC, 0x028C, 0x026C, 0x024C, 0x022C, 0x020C, 0x01EC, 0x01CC, 0x01AC, 0x018C, + 0x0174, 0x0164, 0x0154, 0x0144, 0x0134, 
0x0124, 0x0114, 0x0104, 0x00F4, 0x00E4, 0x00D4, 0x00C4, 0x00B4, 0x00A4, 0x0094, 0x0084, + 0x0078, 0x0070, 0x0068, 0x0060, 0x0058, 0x0050, 0x0048, 0x0040, 0x0038, 0x0030, 0x0028, 0x0020, 0x0018, 0x0010, 0x0008, 0x0000 +}; + +static DRWAV_INLINE drwav_int16 drwav__alaw_to_s16(drwav_uint8 sampleIn) +{ + return (short)g_drwavAlawTable[sampleIn]; +} + +static DRWAV_INLINE drwav_int16 drwav__mulaw_to_s16(drwav_uint8 sampleIn) +{ + return (short)g_drwavMulawTable[sampleIn]; +} + + + +static void drwav__pcm_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. */ + if (bytesPerSample == 1) { + drwav_u8_to_s16(pOut, pIn, totalSampleCount); + return; + } + + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + for (i = 0; i < totalSampleCount; ++i) { + *pOut++ = ((const drwav_int16*)pIn)[i]; + } + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_s16(pOut, pIn, totalSampleCount); + return; + } + if (bytesPerSample == 4) { + drwav_s32_to_s16(pOut, (const drwav_int32*)pIn, totalSampleCount); + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. 
*/ + for (i = 0; i < totalSampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (drwav_int16)((drwav_int64)sample >> 48); + } +} + +static void drwav__ieee_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + drwav_f32_to_s16(pOut, (const float*)pIn, totalSampleCount); + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_s16(pOut, (const double*)pIn, totalSampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */ + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } +} + +static drwav_uint64 drwav_read_pcm_frames_s16__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint32 bytesPerFrame; + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + /* Fast path. 
*/ + if ((pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 16) || pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = 
drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_mulaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. 
*/ + if (framesToRead * pWav->channels * sizeof(drwav_int16) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int16) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_s16__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_s16__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_s16__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_s16__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_s16__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + +DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = pIn[i]; + r = x << 8; + r = r - 32768; + pOut[i] = 
(short)r; + } +} + +DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = ((int)(((unsigned int)(((const drwav_uint8*)pIn)[i*3+0]) << 8) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+1]) << 16) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+2])) << 24)) >> 8; + r = x >> 8; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = pIn[i]; + r = x >> 16; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + float x = pIn[i]; + float c; + c = ((x < -1) ? -1 : ((x > 1) ? 1 : x)); + c = c + 1; + r = (int)(c * 32767.5f); + r = r - 32768; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + double x = pIn[i]; + double c; + c = ((x < -1) ? -1 : ((x > 1) ? 1 : x)); + c = c + 1; + r = (int)(c * 32767.5); + r = r - 32768; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + for (i = 0; i < sampleCount; ++i) { + pOut[i] = drwav__alaw_to_s16(pIn[i]); + } +} + +DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + for (i = 0; i < sampleCount; ++i) { + pOut[i] = drwav__mulaw_to_s16(pIn[i]); + } +} + + + +static void drwav__pcm_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. 
*/ + if (bytesPerSample == 1) { + drwav_u8_to_f32(pOut, pIn, sampleCount); + return; + } + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + drwav_s16_to_f32(pOut, (const drwav_int16*)pIn, sampleCount); + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_f32(pOut, pIn, sampleCount); + return; + } + if (bytesPerSample == 4) { + drwav_s32_to_f32(pOut, (const drwav_int32*)pIn, sampleCount); + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. */ + for (i = 0; i < sampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (float)((drwav_int64)sample / 9223372036854775807.0); + } +} + +static void drwav__ieee_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + unsigned int i; + for (i = 0; i < sampleCount; ++i) { + *pOut++ = ((const float*)pIn)[i]; + } + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_f32(pOut, (const double*)pIn, sampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. 
*/ + DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut)); + return; + } +} + + +static drwav_uint64 drwav_read_pcm_frames_f32__pcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_f32(pBufferOut, sampleData, (size_t)framesRead*pWav->channels, bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. 
*/ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* IMA-ADPCM to f32: reads s16 samples through drwav_read_pcm_frames_s16() in chunks sized to samples16 and converts each chunk with drwav_s16_to_f32(). Returns the number of frames read. */+static drwav_uint64 drwav_read_pcm_frames_f32__ima(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* IEEE float to f32: 32-bit float data is read straight into pBufferOut; any other width is staged in sampleData and converted with drwav__ieee_to_f32(). */+static drwav_uint64 drwav_read_pcm_frames_f32__ieee(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + /* Fast path. 
*/ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT && pWav->bitsPerSample == 32) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* A-law to f32: stages raw frames in the 4096-byte sampleData buffer and expands each chunk with drwav_alaw_to_f32(). Returns the number of frames read, or 0 when the reported frame size is 0. */+static drwav_uint64 drwav_read_pcm_frames_f32__alaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* mu-law to f32: stages raw frames in the 4096-byte sampleData buffer and expands each chunk with drwav_mulaw_to_f32(). Returns the number of frames read, or 0 when the reported frame size is 0. */+static drwav_uint64 drwav_read_pcm_frames_f32__mulaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + 
break; + } + + drwav_mulaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* Public entry point: reads up to framesToRead frames as f32 by dispatching on pWav->translatedFormatTag. Returns 0 for a NULL pWav, a zero framesToRead or a format tag none of the branches match; a NULL pBufferOut is forwarded to drwav_read_pcm_frames() with a NULL buffer. */+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. */ + if (framesToRead * pWav->channels * sizeof(float) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(float) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_f32__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_f32__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_f32__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_f32__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_f32__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_f32__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + /* Little-endian variant: calls drwav_read_pcm_frames_f32() and byte-swaps the output when drwav__is_little_endian() reports DRWAV_FALSE. NOTE(review): pWav->channels is read without a NULL check on that path - confirm callers never pass a NULL pWav together with a non-NULL buffer. */+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 
drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ /* big-endian variant: the f32 output of drwav_read_pcm_frames_f32() is byte-swapped when drwav__is_little_endian() reports DRWAV_TRUE */+ drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + /* Converts sampleCount unsigned 8-bit samples to f32; a NULL pOut or pIn makes the call a no-op. The exact mapping depends on whether DR_WAV_LIBSNDFILE_COMPAT is defined. */+DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + +#ifdef DR_WAV_LIBSNDFILE_COMPAT + /* + It appears libsndfile uses slightly different logic for the u8 -> f32 conversion to dr_wav, which in my opinion is incorrect. It appears + libsndfile performs the conversion something like "f32 = (u8 / 256) * 2 - 1", however I think it should be "f32 = (u8 / 255) * 2 - 1" (note + the divisor of 256 vs 255). I use libsndfile as a benchmark for testing, so I'm therefore leaving this block here just for my automated + correctness testing. This is disabled by default. 
+ */ + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (pIn[i] / 256.0f) * 2 - 1; + } +#else + for (i = 0; i < sampleCount; ++i) { + float x = pIn[i]; + x = x * 0.00784313725490196078f; /* 0..255 to 0..2 */ + x = x - 1; /* 0..2 to -1..1 */ + + *pOut++ = x; + } +#endif +} + /* s16 to f32: multiplies each sample by 0.000030517578125 (1/32768). A NULL pOut or pIn makes the call a no-op. */+DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = pIn[i] * 0.000030517578125f; + } +} + /* s24 to f32: packs the three bytes of each sample into bits 8..31 of a 32-bit word, shifts right by 8 as a signed value to recover the sign, then scales by 2^-23. */+DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + double x; + drwav_uint32 a = ((drwav_uint32)(pIn[i*3+0]) << 8); + drwav_uint32 b = ((drwav_uint32)(pIn[i*3+1]) << 16); + drwav_uint32 c = ((drwav_uint32)(pIn[i*3+2]) << 24); + + x = (double)((drwav_int32)(a | b | c) >> 8); + *pOut++ = (float)(x * 0.00000011920928955078125); + } +} + /* s32 to f32: divides each sample by 2147483648.0 in double precision before narrowing to float. */+DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount) +{ + size_t i; + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (float)(pIn[i] / 2147483648.0); + } +} + +DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (float)pIn[i]; + } +} + /* A-law to f32: decodes each byte with drwav__alaw_to_s16() and divides the result by 32768. */+DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = drwav__alaw_to_s16(pIn[i]) / 32768.0f; + } +} + /* mu-law to f32: decodes each byte with drwav__mulaw_to_s16() and divides the result by 32768. */+DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = drwav__mulaw_to_s16(pIn[i]) / 32768.0f; + } 
+} + + + /* Converts integer PCM samples of bytesPerSample bytes each to s32: widths 1 to 4 use dedicated paths, widths above 8 produce zeros, and the remaining widths go through the generic loop. */+static void drwav__pcm_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + unsigned int i; /* NOTE(review): i is unsigned int while totalSampleCount is size_t; where size_t is wider, a count above UINT_MAX would keep the loops below from terminating - confirm such counts cannot reach here */+ + /* Special case for 8-bit sample data because it's treated as unsigned. */ + if (bytesPerSample == 1) { + drwav_u8_to_s32(pOut, pIn, totalSampleCount); + return; + } + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + drwav_s16_to_s32(pOut, (const drwav_int16*)pIn, totalSampleCount); + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_s32(pOut, pIn, totalSampleCount); + return; + } + if (bytesPerSample == 4) { + for (i = 0; i < totalSampleCount; ++i) { + *pOut++ = ((const drwav_int32*)pIn)[i]; + } + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. */ + for (i = 0; i < totalSampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (drwav_int32)((drwav_int64)sample >> 32); /* keep the upper 32 bits of the top-aligned 64-bit sample */+ } +} + /* Converts IEEE-float samples to s32 by width: 4-byte samples use drwav_f32_to_s32(), 8-byte samples use drwav_f64_to_s32(), and every other width takes the final else branch. */+static void drwav__ieee_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + drwav_f32_to_s32(pOut, (const float*)pIn, totalSampleCount); + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_s32(pOut, (const double*)pIn, totalSampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. 
*/ + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } +} + + /* PCM to s32: 32-bit PCM is read straight into pBufferOut; other widths are staged in the 4096-byte sampleData buffer and converted with drwav__pcm_to_s32(). Returns the number of frames read. */+static drwav_uint64 drwav_read_pcm_frames_s32__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + /* Fast path. */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 32) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* MS-ADPCM to s32: obtains s16 samples through drwav_read_pcm_frames_s16() in chunks sized to samples16, then widens each chunk with drwav_s16_to_s32(). */+static drwav_uint64 drwav_read_pcm_frames_s32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. 
*/ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* IMA-ADPCM to s32: reads s16 samples through drwav_read_pcm_frames_s16() in chunks sized to samples16 and widens each chunk with drwav_s16_to_s32(). Returns the number of frames read. */+static drwav_uint64 drwav_read_pcm_frames_s32__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* IEEE float to s32: stages raw frames in the 4096-byte sampleData buffer and converts each chunk with drwav__ieee_to_s32(). Returns the number of frames read, or 0 when the reported frame size is 0. */+static drwav_uint64 drwav_read_pcm_frames_s32__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* A-law to s32: stages raw frames in sampleData and expands each chunk with drwav_alaw_to_s32(). */+static drwav_uint64 drwav_read_pcm_frames_s32__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 
bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* mu-law to s32: stages raw frames in the 4096-byte sampleData buffer and expands each chunk with drwav_mulaw_to_s32(). Returns the number of frames read, or 0 when the reported frame size is 0. */+static drwav_uint64 drwav_read_pcm_frames_s32__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_mulaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + /* Public entry point: reads up to framesToRead frames as s32 by dispatching on pWav->translatedFormatTag. Returns 0 for a NULL pWav, a zero framesToRead or a format tag none of the branches match; a NULL pBufferOut is forwarded to drwav_read_pcm_frames() with a NULL buffer. */+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. 
*/ + if (framesToRead * pWav->channels * sizeof(drwav_int32) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int32) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_s32__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_s32__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_s32__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_s32__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_s32__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_s32__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + /* Little-endian variant: calls drwav_read_pcm_frames_s32() and byte-swaps the output when drwav__is_little_endian() reports DRWAV_FALSE. */+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + /* Big-endian variant: calls drwav_read_pcm_frames_s32() and byte-swaps the output when drwav__is_little_endian() reports DRWAV_TRUE. */+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + /* Converts sampleCount unsigned 8-bit samples to s32; a NULL pOut or pIn makes the call a no-op. */+DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = 
((int)pIn[i] - 128) << 24; /* re-centre the unsigned byte around zero, then move it into the top byte */+ } +} + /* s16 to s32: shifts each sample left by 16 bits. NOTE(review): left-shifting a negative pIn[i] is undefined behaviour in ISO C - confirm the targeted compilers treat it as a plain shift. */+DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = pIn[i] << 16; + } +} + /* s24 to s32: places the three bytes of each sample in bits 8..31 and leaves the low byte zero. */+DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + unsigned int s0 = pIn[i*3 + 0]; + unsigned int s1 = pIn[i*3 + 1]; + unsigned int s2 = pIn[i*3 + 2]; + + drwav_int32 sample32 = (drwav_int32)((s0 << 8) | (s1 << 16) | (s2 << 24)); + *pOut++ = sample32; + } +} + /* f32 to s32: scales each sample by 2147483648.0 and truncates. NOTE(review): an input of exactly 1.0 scales to a value above the drwav_int32 maximum - confirm inputs stay below 1.0 or that the resulting conversion is acceptable. */+DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]); + } +} + +DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]); + } +} + /* A-law to s32: decodes each byte with drwav__alaw_to_s16() and shifts the result into the upper 16 bits. */+DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = ((drwav_int32)drwav__alaw_to_s16(pIn[i])) << 16; + } +} + /* mu-law to s32: decodes each byte with drwav__mulaw_to_s16() and shifts the result into the upper 16 bits. */+DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i= 0; i < sampleCount; ++i) { + *pOut++ = ((drwav_int32)drwav__mulaw_to_s16(pIn[i])) << 16; + } +} + + + /* Reads every PCM frame of pWav as s16 into a buffer obtained from drwav__malloc_from_callbacks(), uninitializes pWav, and reports channels, sampleRate and totalFrameCount through the optional out-pointers. Returns NULL, after uninitializing pWav, when the size check, the allocation or the read fails. */+static drwav_int16* drwav__read_pcm_frames_and_close_s16(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + drwav_int16* pSampleData; + drwav_uint64 
framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int16); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (drwav_int16*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. */ + } + + framesRead = drwav_read_pcm_frames_s16(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. */ + } + + drwav_uninit(pWav); /* NOTE(review): pWav->sampleRate, channels and totalPCMFrameCount are still read below, after drwav_uninit() - confirm drwav_uninit() leaves those fields intact */+ + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + /* Reads every PCM frame of pWav as f32 into a buffer obtained from drwav__malloc_from_callbacks(), uninitializes pWav, and reports channels, sampleRate and totalFrameCount through the optional out-pointers. Returns NULL, after uninitializing pWav, when the size check, the allocation or the read fails. */+static float* drwav__read_pcm_frames_and_close_f32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + float* pSampleData; + drwav_uint64 framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (float*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. 
*/ + } + + framesRead = drwav_read_pcm_frames_f32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. */ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + /* Reads every PCM frame of pWav as s32 into a buffer obtained from drwav__malloc_from_callbacks(), uninitializes pWav, and reports channels, sampleRate and totalFrameCount through the optional out-pointers. Returns NULL, after uninitializing pWav, when the size check, the allocation or the read fails. */+static drwav_int32* drwav__read_pcm_frames_and_close_s32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + drwav_int32* pSampleData; + drwav_uint64 framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int32); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (drwav_int32*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. */ + } + + framesRead = drwav_read_pcm_frames_s32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. 
*/ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + + + /* Opens a stream through the onRead/onSeek callbacks with drwav_init() and returns all of its frames as s16 via drwav__read_pcm_frames_and_close_s16(). The non-NULL out-parameters are zeroed first, so they hold 0 when NULL is returned because drwav_init() failed. */+DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + /* f32 counterpart of the callback-based reader: drwav_init() followed by drwav__read_pcm_frames_and_close_f32(); out-parameters are zeroed first. */+DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + /* s32 counterpart of the callback-based reader: drwav_init() followed by drwav__read_pcm_frames_and_close_s32(); out-parameters are zeroed first. */+DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if 
(!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +#ifndef DR_WAV_NO_STDIO /* File-path reader, s16: opens filename with drwav_init_file() and returns all frames via drwav__read_pcm_frames_and_close_s16(). The non-NULL out-parameters are zeroed first, so they hold 0 when NULL is returned because the open failed. */+DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + /* File-path reader, f32: drwav_init_file() followed by drwav__read_pcm_frames_and_close_f32(); out-parameters are zeroed first. */+DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + /* File-path reader, s32: drwav_init_file() followed by drwav__read_pcm_frames_and_close_s32(); out-parameters are zeroed first. */+DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); 
+} + + /* Wide-character path reader, s16: opens filename with drwav_init_file_w() and returns all frames via drwav__read_pcm_frames_and_close_s16(). The non-NULL out-parameters are zeroed first, so they hold 0 when NULL is returned because the open failed. */+DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + /* Wide-character path reader, f32: drwav_init_file_w() followed by drwav__read_pcm_frames_and_close_f32(); out-parameters are zeroed first. */+DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + /* Wide-character path reader, s32: drwav_init_file_w() followed by drwav__read_pcm_frames_and_close_s32(); out-parameters are zeroed first. */+DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} +#endif + /* In-memory reader, s16: opens the data/dataSize block with drwav_init_memory() and returns all frames via drwav__read_pcm_frames_and_close_s16(); out-parameters are zeroed first. */+DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* 
totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + /* In-memory reader, f32: opens the data/dataSize block with drwav_init_memory() and returns all frames via drwav__read_pcm_frames_and_close_f32(). The non-NULL out-parameters are zeroed first, so they hold 0 when NULL is returned because the open failed. */+DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + /* In-memory reader, s32: drwav_init_memory() followed by drwav__read_pcm_frames_and_close_s32(); out-parameters are zeroed first. */+DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} +#endif /* DR_WAV_NO_CONVERSION_API */ + + /* Frees p: through drwav__free_from_callbacks() when pAllocationCallbacks is non-NULL, otherwise through drwav__free_default() with a NULL user-data argument. */+DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks != NULL) { + drwav__free_from_callbacks(p, pAllocationCallbacks); + } else { + drwav__free_default(p, NULL); + } +} + +DRWAV_API drwav_uint16 
drwav_bytes_to_u16(const drwav_uint8* data) +{ + return drwav__bytes_to_u16(data); +} + /* Public wrapper that forwards to the internal drwav__bytes_to_s16(); the sibling drwav_bytes_to_* functions around it forward to their drwav__bytes_to_* counterparts in the same way. */+DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data) +{ + return drwav__bytes_to_s16(data); +} + +DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data) +{ + return drwav__bytes_to_u32(data); +} + +DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data) +{ + return drwav__bytes_to_s32(data); +} + +DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data) +{ + return drwav__bytes_to_u64(data); +} + +DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data) +{ + return drwav__bytes_to_s64(data); +} + + /* Public wrapper that forwards two 16-byte GUID arrays to the internal drwav__guid_equal() and returns its result. */+DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]) +{ + return drwav__guid_equal(a, b); +} + /* Public wrapper that forwards a byte buffer and a C string to the internal drwav__fourcc_equal() and returns its result. */+DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b) +{ + return drwav__fourcc_equal(a, b); +} + +#endif /* dr_wav_c */ +#endif /* DR_WAV_IMPLEMENTATION */ + +/* +RELEASE NOTES - v0.11.0 +======================= +Version 0.11.0 has breaking API changes. + +Improved Client-Defined Memory Allocation +----------------------------------------- +The main change with this release is the addition of a more flexible way of implementing custom memory allocation routines. The +existing system of DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE are still in place and will be used by default when no custom +allocation callbacks are specified. + +To use the new system, you pass in a pointer to a drwav_allocation_callbacks object to drwav_init() and family, like this: + + void* my_malloc(size_t sz, void* pUserData) + { + return malloc(sz); + } + void* my_realloc(void* p, size_t sz, void* pUserData) + { + return realloc(p, sz); + } + void my_free(void* p, void* pUserData) + { + free(p); + } + + ... 
+ + drwav_allocation_callbacks allocationCallbacks; + allocationCallbacks.pUserData = &myData; + allocationCallbacks.onMalloc = my_malloc; + allocationCallbacks.onRealloc = my_realloc; + allocationCallbacks.onFree = my_free; + drwav_init_file(&wav, "my_file.wav", &allocationCallbacks); + +The advantage of this new system is that it allows you to specify user data which will be passed in to the allocation routines. + +Passing in null for the allocation callbacks object will cause dr_wav to use defaults which is the same as DRWAV_MALLOC, +DRWAV_REALLOC and DRWAV_FREE and the equivalent of how it worked in previous versions. + +Every API that opens a drwav object now takes this extra parameter. These include the following: + + drwav_init() + drwav_init_ex() + drwav_init_file() + drwav_init_file_ex() + drwav_init_file_w() + drwav_init_file_w_ex() + drwav_init_memory() + drwav_init_memory_ex() + drwav_init_write() + drwav_init_write_sequential() + drwav_init_write_sequential_pcm_frames() + drwav_init_file_write() + drwav_init_file_write_sequential() + drwav_init_file_write_sequential_pcm_frames() + drwav_init_file_write_w() + drwav_init_file_write_sequential_w() + drwav_init_file_write_sequential_pcm_frames_w() + drwav_init_memory_write() + drwav_init_memory_write_sequential() + drwav_init_memory_write_sequential_pcm_frames() + drwav_open_and_read_pcm_frames_s16() + drwav_open_and_read_pcm_frames_f32() + drwav_open_and_read_pcm_frames_s32() + drwav_open_file_and_read_pcm_frames_s16() + drwav_open_file_and_read_pcm_frames_f32() + drwav_open_file_and_read_pcm_frames_s32() + drwav_open_file_and_read_pcm_frames_s16_w() + drwav_open_file_and_read_pcm_frames_f32_w() + drwav_open_file_and_read_pcm_frames_s32_w() + drwav_open_memory_and_read_pcm_frames_s16() + drwav_open_memory_and_read_pcm_frames_f32() + drwav_open_memory_and_read_pcm_frames_s32() + +Endian Improvements +------------------- +Previously, the following APIs returned little-endian audio data. 
These now return native-endian data. This improves compatibility +on big-endian architectures. + + drwav_read_pcm_frames() + drwav_read_pcm_frames_s16() + drwav_read_pcm_frames_s32() + drwav_read_pcm_frames_f32() + drwav_open_and_read_pcm_frames_s16() + drwav_open_and_read_pcm_frames_s32() + drwav_open_and_read_pcm_frames_f32() + drwav_open_file_and_read_pcm_frames_s16() + drwav_open_file_and_read_pcm_frames_s32() + drwav_open_file_and_read_pcm_frames_f32() + drwav_open_file_and_read_pcm_frames_s16_w() + drwav_open_file_and_read_pcm_frames_s32_w() + drwav_open_file_and_read_pcm_frames_f32_w() + drwav_open_memory_and_read_pcm_frames_s16() + drwav_open_memory_and_read_pcm_frames_s32() + drwav_open_memory_and_read_pcm_frames_f32() + +APIs have been added to give you explicit control over whether or not audio data is read or written in big- or little-endian byte +order: + + drwav_read_pcm_frames_le() + drwav_read_pcm_frames_be() + drwav_read_pcm_frames_s16le() + drwav_read_pcm_frames_s16be() + drwav_read_pcm_frames_f32le() + drwav_read_pcm_frames_f32be() + drwav_read_pcm_frames_s32le() + drwav_read_pcm_frames_s32be() + drwav_write_pcm_frames_le() + drwav_write_pcm_frames_be() + +Removed APIs +------------ +The following APIs were deprecated in version 0.10.0 and have now been removed: + + drwav_open() + drwav_open_ex() + drwav_open_write() + drwav_open_write_sequential() + drwav_open_file() + drwav_open_file_ex() + drwav_open_file_write() + drwav_open_file_write_sequential() + drwav_open_memory() + drwav_open_memory_ex() + drwav_open_memory_write() + drwav_open_memory_write_sequential() + drwav_close() + + + +RELEASE NOTES - v0.10.0 +======================= +Version 0.10.0 has breaking API changes. There are no significant bug fixes in this release, so if you are affected you do +not need to upgrade. 
+ +Removed APIs +------------ +The following APIs were deprecated in version 0.9.0 and have been completely removed in version 0.10.0: + + drwav_read() + drwav_read_s16() + drwav_read_f32() + drwav_read_s32() + drwav_seek_to_sample() + drwav_write() + drwav_open_and_read_s16() + drwav_open_and_read_f32() + drwav_open_and_read_s32() + drwav_open_file_and_read_s16() + drwav_open_file_and_read_f32() + drwav_open_file_and_read_s32() + drwav_open_memory_and_read_s16() + drwav_open_memory_and_read_f32() + drwav_open_memory_and_read_s32() + drwav::totalSampleCount + +See release notes for version 0.9.0 at the bottom of this file for replacement APIs. + +Deprecated APIs +--------------- +The following APIs have been deprecated. There is a confusing and completely arbitrary difference between drwav_init*() and +drwav_open*(), where drwav_init*() initializes a pre-allocated drwav object, whereas drwav_open*() will first allocate a +drwav object on the heap and then initialize it. drwav_open*() has been deprecated which means you must now use a pre- +allocated drwav object with drwav_init*(). If you need the previous functionality, you can just do a malloc() followed by +a call to one of the drwav_init*() APIs. + + drwav_open() + drwav_open_ex() + drwav_open_write() + drwav_open_write_sequential() + drwav_open_file() + drwav_open_file_ex() + drwav_open_file_write() + drwav_open_file_write_sequential() + drwav_open_memory() + drwav_open_memory_ex() + drwav_open_memory_write() + drwav_open_memory_write_sequential() + drwav_close() + +These APIs will be removed completely in a future version. The rationale for this change is to remove confusion between the +two different ways to initialize a drwav object. +*/ + +/* +REVISION HISTORY +================ +v0.12.16 - 2020-12-02 + - Fix a bug when trying to read more bytes than can fit in a size_t. + +v0.12.15 - 2020-11-21 + - Fix compilation with OpenWatcom. + +v0.12.14 - 2020-11-13 + - Minor code clean up.
+ +v0.12.13 - 2020-11-01 + - Improve compiler support for older versions of GCC. + +v0.12.12 - 2020-09-28 + - Add support for RF64. + - Fix a bug in writing mode where the size of the RIFF chunk incorrectly includes the header section. + +v0.12.11 - 2020-09-08 + - Fix a compilation error on older compilers. + +v0.12.10 - 2020-08-24 + - Fix a bug when seeking with ADPCM formats. + +v0.12.9 - 2020-08-02 + - Simplify sized types. + +v0.12.8 - 2020-07-25 + - Fix a compilation warning. + +v0.12.7 - 2020-07-15 + - Fix some bugs on big-endian architectures. + - Fix an error in s24 to f32 conversion. + +v0.12.6 - 2020-06-23 + - Change drwav_read_*() to allow NULL to be passed in as the output buffer which is equivalent to a forward seek. + - Fix a buffer overflow when trying to decode invalid IMA-ADPCM files. + - Add include guard for the implementation section. + +v0.12.5 - 2020-05-27 + - Minor documentation fix. + +v0.12.4 - 2020-05-16 + - Replace assert() with DRWAV_ASSERT(). + - Add compile-time and run-time version querying. + - DRWAV_VERSION_MINOR + - DRWAV_VERSION_MAJOR + - DRWAV_VERSION_REVISION + - DRWAV_VERSION_STRING + - drwav_version() + - drwav_version_string() + +v0.12.3 - 2020-04-30 + - Fix compilation errors with VC6. + +v0.12.2 - 2020-04-21 + - Fix a bug where drwav_init_file() does not close the file handle after attempting to load an erroneous file. + +v0.12.1 - 2020-04-13 + - Fix some pedantic warnings. + +v0.12.0 - 2020-04-04 + - API CHANGE: Add container and format parameters to the chunk callback. + - Minor documentation updates. + +v0.11.5 - 2020-03-07 + - Fix compilation error with Visual Studio .NET 2003. + +v0.11.4 - 2020-01-29 + - Fix some static analysis warnings. + - Fix a bug when reading f32 samples from an A-law encoded stream. + +v0.11.3 - 2020-01-12 + - Minor changes to some f32 format conversion routines. + - Minor bug fix for ADPCM conversion when end of file is reached. 
+ +v0.11.2 - 2019-12-02 + - Fix a possible crash when using custom memory allocators without a custom realloc() implementation. + - Fix an integer overflow bug. + - Fix a null pointer dereference bug. + - Add limits to sample rate, channels and bits per sample to tighten up some validation. + +v0.11.1 - 2019-10-07 + - Internal code clean up. + +v0.11.0 - 2019-10-06 + - API CHANGE: Add support for user defined memory allocation routines. This system allows the program to specify their own memory allocation + routines with a user data pointer for client-specific contextual data. This adds an extra parameter to the end of the following APIs: + - drwav_init() + - drwav_init_ex() + - drwav_init_file() + - drwav_init_file_ex() + - drwav_init_file_w() + - drwav_init_file_w_ex() + - drwav_init_memory() + - drwav_init_memory_ex() + - drwav_init_write() + - drwav_init_write_sequential() + - drwav_init_write_sequential_pcm_frames() + - drwav_init_file_write() + - drwav_init_file_write_sequential() + - drwav_init_file_write_sequential_pcm_frames() + - drwav_init_file_write_w() + - drwav_init_file_write_sequential_w() + - drwav_init_file_write_sequential_pcm_frames_w() + - drwav_init_memory_write() + - drwav_init_memory_write_sequential() + - drwav_init_memory_write_sequential_pcm_frames() + - drwav_open_and_read_pcm_frames_s16() + - drwav_open_and_read_pcm_frames_f32() + - drwav_open_and_read_pcm_frames_s32() + - drwav_open_file_and_read_pcm_frames_s16() + - drwav_open_file_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s32() + - drwav_open_file_and_read_pcm_frames_s16_w() + - drwav_open_file_and_read_pcm_frames_f32_w() + - drwav_open_file_and_read_pcm_frames_s32_w() + - drwav_open_memory_and_read_pcm_frames_s16() + - drwav_open_memory_and_read_pcm_frames_f32() + - drwav_open_memory_and_read_pcm_frames_s32() + Set this extra parameter to NULL to use defaults which is the same as the previous behaviour. 
Setting this to NULL will use + DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE. + - Add support for reading and writing PCM frames in an explicit endianness. New APIs: + - drwav_read_pcm_frames_le() + - drwav_read_pcm_frames_be() + - drwav_read_pcm_frames_s16le() + - drwav_read_pcm_frames_s16be() + - drwav_read_pcm_frames_f32le() + - drwav_read_pcm_frames_f32be() + - drwav_read_pcm_frames_s32le() + - drwav_read_pcm_frames_s32be() + - drwav_write_pcm_frames_le() + - drwav_write_pcm_frames_be() + - Remove deprecated APIs. + - API CHANGE: The following APIs now return native-endian data. Previously they returned little-endian data. + - drwav_read_pcm_frames() + - drwav_read_pcm_frames_s16() + - drwav_read_pcm_frames_s32() + - drwav_read_pcm_frames_f32() + - drwav_open_and_read_pcm_frames_s16() + - drwav_open_and_read_pcm_frames_s32() + - drwav_open_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s16() + - drwav_open_file_and_read_pcm_frames_s32() + - drwav_open_file_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s16_w() + - drwav_open_file_and_read_pcm_frames_s32_w() + - drwav_open_file_and_read_pcm_frames_f32_w() + - drwav_open_memory_and_read_pcm_frames_s16() + - drwav_open_memory_and_read_pcm_frames_s32() + - drwav_open_memory_and_read_pcm_frames_f32() + +v0.10.1 - 2019-08-31 + - Correctly handle partial trailing ADPCM blocks. + +v0.10.0 - 2019-08-04 + - Remove deprecated APIs. + - Add wchar_t variants for file loading APIs: + drwav_init_file_w() + drwav_init_file_ex_w() + drwav_init_file_write_w() + drwav_init_file_write_sequential_w() + - Add drwav_target_write_size_bytes() which calculates the total size in bytes of a WAV file given a format and sample count.
+ - Add APIs for specifying the PCM frame count instead of the sample count when opening in sequential write mode: + drwav_init_write_sequential_pcm_frames() + drwav_init_file_write_sequential_pcm_frames() + drwav_init_file_write_sequential_pcm_frames_w() + drwav_init_memory_write_sequential_pcm_frames() + - Deprecate drwav_open*() and drwav_close(): + drwav_open() + drwav_open_ex() + drwav_open_write() + drwav_open_write_sequential() + drwav_open_file() + drwav_open_file_ex() + drwav_open_file_write() + drwav_open_file_write_sequential() + drwav_open_memory() + drwav_open_memory_ex() + drwav_open_memory_write() + drwav_open_memory_write_sequential() + drwav_close() + - Minor documentation updates. + +v0.9.2 - 2019-05-21 + - Fix warnings. + +v0.9.1 - 2019-05-05 + - Add support for C89. + - Change license to choice of public domain or MIT-0. + +v0.9.0 - 2018-12-16 + - API CHANGE: Add new reading APIs for reading by PCM frames instead of samples. Old APIs have been deprecated and + will be removed in v0.10.0. 
Deprecated APIs and their replacements: + drwav_read() -> drwav_read_pcm_frames() + drwav_read_s16() -> drwav_read_pcm_frames_s16() + drwav_read_f32() -> drwav_read_pcm_frames_f32() + drwav_read_s32() -> drwav_read_pcm_frames_s32() + drwav_seek_to_sample() -> drwav_seek_to_pcm_frame() + drwav_write() -> drwav_write_pcm_frames() + drwav_open_and_read_s16() -> drwav_open_and_read_pcm_frames_s16() + drwav_open_and_read_f32() -> drwav_open_and_read_pcm_frames_f32() + drwav_open_and_read_s32() -> drwav_open_and_read_pcm_frames_s32() + drwav_open_file_and_read_s16() -> drwav_open_file_and_read_pcm_frames_s16() + drwav_open_file_and_read_f32() -> drwav_open_file_and_read_pcm_frames_f32() + drwav_open_file_and_read_s32() -> drwav_open_file_and_read_pcm_frames_s32() + drwav_open_memory_and_read_s16() -> drwav_open_memory_and_read_pcm_frames_s16() + drwav_open_memory_and_read_f32() -> drwav_open_memory_and_read_pcm_frames_f32() + drwav_open_memory_and_read_s32() -> drwav_open_memory_and_read_pcm_frames_s32() + drwav::totalSampleCount -> drwav::totalPCMFrameCount + - API CHANGE: Rename drwav_open_and_read_file_*() to drwav_open_file_and_read_*(). + - API CHANGE: Rename drwav_open_and_read_memory_*() to drwav_open_memory_and_read_*(). + - Add built-in support for smpl chunks. + - Add support for firing a callback for each chunk in the file at initialization time. + - This is enabled through the drwav_init_ex(), etc. family of APIs. + - Handle invalid FMT chunks more robustly. + +v0.8.5 - 2018-09-11 + - Const correctness. + - Fix a potential stack overflow. + +v0.8.4 - 2018-08-07 + - Improve 64-bit detection. + +v0.8.3 - 2018-08-05 + - Fix C++ build on older versions of GCC. + +v0.8.2 - 2018-08-02 + - Fix some big-endian bugs. + +v0.8.1 - 2018-06-29 + - Add support for sequential writing APIs. + - Disable seeking in write mode. + - Fix bugs with Wave64. + - Fix typos. + +v0.8 - 2018-04-27 + - Bug fix. + - Start using major.minor.revision versioning. 
+ +v0.7f - 2018-02-05 + - Restrict ADPCM formats to a maximum of 2 channels. + +v0.7e - 2018-02-02 + - Fix a crash. + +v0.7d - 2018-02-01 + - Fix a crash. + +v0.7c - 2018-02-01 + - Set drwav.bytesPerSample to 0 for all compressed formats. + - Fix a crash when reading 16-bit floating point WAV files. In this case dr_wav will output silence for + all format conversion reading APIs (*_s16, *_s32, *_f32 APIs). + - Fix some divide-by-zero errors. + +v0.7b - 2018-01-22 + - Fix errors with seeking of compressed formats. + - Fix compilation error when DR_WAV_NO_CONVERSION_API + +v0.7a - 2017-11-17 + - Fix some GCC warnings. + +v0.7 - 2017-11-04 + - Add writing APIs. + +v0.6 - 2017-08-16 + - API CHANGE: Rename dr_* types to drwav_*. + - Add support for custom implementations of malloc(), realloc(), etc. + - Add support for Microsoft ADPCM. + - Add support for IMA ADPCM (DVI, format code 0x11). + - Optimizations to drwav_read_s16(). + - Bug fixes. + +v0.5g - 2017-07-16 + - Change underlying type for booleans to unsigned. + +v0.5f - 2017-04-04 + - Fix a minor bug with drwav_open_and_read_s16() and family. + +v0.5e - 2016-12-29 + - Added support for reading samples as signed 16-bit integers. Use the _s16() family of APIs for this. + - Minor fixes to documentation. + +v0.5d - 2016-12-28 + - Use drwav_int* and drwav_uint* sized types to improve compiler support. + +v0.5c - 2016-11-11 + - Properly handle JUNK chunks that come before the FMT chunk. + +v0.5b - 2016-10-23 + - A minor change to drwav_bool8 and drwav_bool32 types. + +v0.5a - 2016-10-11 + - Fixed a bug with drwav_open_and_read() and family due to incorrect argument ordering. + - Improve A-law and mu-law efficiency. + +v0.5 - 2016-09-29 + - API CHANGE. Swap the order of "channels" and "sampleRate" parameters in drwav_open_and_read*(). Rationale for this is to + keep it consistent with dr_audio and dr_flac. + +v0.4b - 2016-09-18 + - Fixed a typo in documentation. + +v0.4a - 2016-09-18 + - Fixed a typo. 
+ - Change date format to ISO 8601 (YYYY-MM-DD) + +v0.4 - 2016-07-13 + - API CHANGE. Make onSeek consistent with dr_flac. + - API CHANGE. Rename drwav_seek() to drwav_seek_to_sample() for clarity and consistency with dr_flac. + - Added support for Sony Wave64. + +v0.3a - 2016-05-28 + - API CHANGE. Return drwav_bool32 instead of int in onSeek callback. + - Fixed a memory leak. + +v0.3 - 2016-05-22 + - Lots of API changes for consistency. + +v0.2a - 2016-05-16 + - Fixed Linux/GCC build. + +v0.2 - 2016-05-11 + - Added support for reading data as signed 32-bit PCM for consistency with dr_flac. + +v0.1a - 2016-05-07 + - Fixed a bug in drwav_open_file() where the file handle would not be closed if the loader failed to initialize. + +v0.1 - 2016-05-04 + - Initial versioned release. +*/ + +/* +This software is available as a choice of the following licenses. Choose +whichever you prefer. + +=============================================================================== +ALTERNATIVE 1 - Public Domain (www.unlicense.org) +=============================================================================== +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. + +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to <http://unlicense.org/> + +=============================================================================== +ALTERNATIVE 2 - MIT No Attribution +=============================================================================== +Copyright 2020 David Reid + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+*/ diff --git a/Examples/OldMain/main.cpp b/Examples/OldMain/main.cpp new file mode 100644 index 0000000..6e991b7 --- /dev/null +++ b/Examples/OldMain/main.cpp @@ -0,0 +1,684 @@ +#include "whisper.h" + +// third-party utilities +// use your favorite implementations +#define DR_WAV_IMPLEMENTATION +#include "dr_wav.h" + +#include <cmath> +#include <fstream> +#include <cstdio> +#include <string> +#include <thread> +#include <vector> + +// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] +// Lowest is red, middle is yellow, highest is green. +const std::vector<std::string> k_colors = { + "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m", + "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m", +}; + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +std::string to_timestamp(int64_t t, bool comma = false) { + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? 
"," : ".", (int) msec); + + return std::string(buf); +} + +int timestamp_to_sample(int64_t t, int n_samples) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); +} + +// helper function to replace substrings +void replace_all(std::string & s, const std::string & search, const std::string & replace) { + for (size_t pos = 0; ; pos += replace.length()) { + pos = s.find(search, pos); + if (pos == std::string::npos) break; + s.erase(pos, search.length()); + s.insert(pos, replace); + } +} + +// command-line parameters +struct whisper_params { + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + int32_t n_processors = 1; + int32_t offset_t_ms = 0; + int32_t offset_n = 0; + int32_t duration_ms = 0; + int32_t max_context = -1; + int32_t max_len = 0; + + float word_thold = 0.01f; + + bool speed_up = false; + bool translate = false; + bool diarize = false; + bool output_txt = false; + bool output_vtt = false; + bool output_srt = false; + bool output_wts = false; + bool print_special = false; + bool print_colors = false; + bool print_progress = false; + bool no_timestamps = false; + + std::string language = "en"; + std::string prompt; + std::string model = "models/ggml-base.en.bin"; + + std::vector<std::string> fname_inp = {}; +}; + +void whisper_print_usage(int argc, char ** argv, const whisper_params & params); + +bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg[0] != '-') { + params.fname_inp.push_back(arg); + continue; + } + + if (arg == "-h" || arg == "--help") { + whisper_print_usage(argc, argv, params); + exit(0); + } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + 
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } + else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } + else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + +void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] file0.wav 
file1.wav ...\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); + fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); + fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? 
"true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); + fprintf(stderr, "\n"); +} + +struct whisper_print_user_data { + const whisper_params * params; + + const std::vector<std::vector<float>> * pcmf32s; +}; + +void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { + const auto & params = *((whisper_print_user_data *) user_data)->params; + const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; + + const int n_segments = whisper_full_n_segments(ctx); + + // print the last n_new segments + const int s0 = n_segments - n_new; + if (s0 == 0) { + printf("\n"); + } + + for (int i = s0; i < n_segments; i++) { + if (params.no_timestamps) { + if (params.print_colors) { + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } + } + + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + } + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + 
printf("%s", text); + } + fflush(stdout); + } else { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + std::string speaker; + + if (params.diarize && pcmf32s.size() == 2) { + const int64_t n_samples = pcmf32s[0].size(); + + const int64_t is0 = timestamp_to_sample(t0, n_samples); + const int64_t is1 = timestamp_to_sample(t1, n_samples); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for (int64_t j = is0; j < is1; j++) { + energy0 += fabs(pcmf32s[0][j]); + energy1 += fabs(pcmf32s[1][j]); + } + + if (energy0 > 1.1*energy1) { + speaker = "(speaker 0)"; + } else if (energy1 > 1.1*energy0) { + speaker = "(speaker 1)"; + } else { + speaker = "(speaker ?)"; + } + + //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str()); + } + + if (params.print_colors) { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } + } + + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); + } + printf("\n"); + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + + printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text); + } + } + } +} + +bool output_txt(struct whisper_context * ctx, const char * fname) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", 
__func__, fname); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + fout << text << "\n"; + } + + return true; +} + +bool output_vtt(struct whisper_context * ctx, const char * fname) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + fout << "WEBVTT\n\n"; + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; + fout << text << "\n\n"; + } + + return true; +} + +bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + fout << i + 1 + params.offset_n << "\n"; + fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; + fout << text << "\n\n"; + } + + return true; +} + +// karaoke video generation +// outputs a bash script that uses ffmpeg to generate a video with the subtitles +// TODO: font parameter adjustments +bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) { + 
std::ofstream fout(fname); + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + // TODO: become parameter + static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + + fout << "#!/bin/bash" << "\n"; + fout << "\n"; + + fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \""; + + for (int i = 0; i < whisper_full_n_segments(ctx); i++) { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + const int n = whisper_full_n_tokens(ctx, i); + + std::vector<whisper_token_data> tokens(n); + for (int j = 0; j < n; ++j) { + tokens[j] = whisper_full_get_token_data(ctx, i, j); + } + + if (i > 0) { + fout << ","; + } + + // background text + fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'"; + + bool is_first = true; + + for (int j = 0; j < n; ++j) { + const auto & token = tokens[j]; + + if (tokens[j].id >= whisper_token_eot(ctx)) { + continue; + } + + std::string txt_bg; + std::string txt_fg; // highlight token + std::string txt_ul; // underline + + txt_bg = "> "; + txt_fg = "> "; + txt_ul = "\\ \\ "; + + { + for (int k = 0; k < n; ++k) { + const auto & token2 = tokens[k]; + + if (tokens[k].id >= whisper_token_eot(ctx)) { + continue; + } + + const std::string txt = whisper_token_to_str(ctx, token2.id); + + txt_bg += txt; + + if (k == j) { + for (int l = 0; l < (int) txt.size(); ++l) { + txt_fg += txt[l]; + txt_ul += "_"; + } + txt_fg += "|"; + } else { + for (int l = 0; l < (int) txt.size(); ++l) { + txt_fg += "\\ "; + txt_ul += "\\ "; + } + } + } + + ::replace_all(txt_bg, "'", "\u2019"); + ::replace_all(txt_bg, "\"", "\\\""); + ::replace_all(txt_fg, "'", "\u2019"); + ::replace_all(txt_fg, "\"", "\\\""); + } + + if (is_first) { + // background text + fout << ",drawtext=fontfile='" << font << 
"':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'"; + is_first = false; + } + + // foreground text + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; + + // underline + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; + } + } + + fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n"; + + fout << "\n\n"; + fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n"; + fout << "\n"; + fout << "echo \" ffplay " << fname_inp << ".mp4\"\n"; + fout << "\n"; + + fout.close(); + + fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname); + + return true; +} + +int main(int argc, char ** argv) { + whisper_params params; + + if (whisper_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.fname_inp.empty()) { + fprintf(stderr, "error: no input files specified\n"); + whisper_print_usage(argc, argv, params); + return 2; + } + + if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) { + fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } + + // whisper init + + struct whisper_context * ctx = whisper_init(params.model.c_str()); + + if (ctx == nullptr) { + fprintf(stderr, "error: failed to initialize whisper context\n"); + return 3; + } + + // initial prompt + std::vector<whisper_token> prompt_tokens; + + if (!params.prompt.empty()) { + prompt_tokens.resize(1024); + prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size())); + + fprintf(stderr, "\n"); + 
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str()); + fprintf(stderr, "initial tokens: [ "); + for (int i = 0; i < (int) prompt_tokens.size(); ++i) { + fprintf(stderr, "%d ", prompt_tokens[i]); + } + fprintf(stderr, "]\n"); + } + + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { + const auto fname_inp = params.fname_inp[f]; + + std::vector<float> pcmf32; // mono-channel F32 PCM + std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM + + // WAV input + { + drwav wav; + std::vector<uint8_t> wav_data; // used for pipe input from stdin + + if (fname_inp == "-") { + { + uint8_t buf[1024]; + while (true) + { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + wav_data.insert(wav_data.end(), buf, buf + n); + } + } + + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { + fprintf(stderr, "error: failed to open WAV file from stdin\n"); + return 4; + } + + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); + } + else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) { + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); + return 5; + } + + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str()); + return 6; + } + + if (params.diarize && wav.channels != 2 && params.no_timestamps == false) { + fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str()); + return 6; + } + + if (wav.sampleRate != WHISPER_SAMPLE_RATE) { + fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str()); + return 8; + } + + if (wav.bitsPerSample != 16) { + fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str()); + return 9; + } + + const uint64_t n = wav_data.empty() ? 
wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); + + std::vector<int16_t> pcm16; + pcm16.resize(n*wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i])/32768.0f; + } + } else { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + } + } + + if (params.diarize) { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (uint64_t i = 0; i < n; i++) { + pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; + pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; + } + } + } + + // print system information + { + fprintf(stderr, "\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info()); + } + + // print some info about the processing + { + fprintf(stderr, "\n"); + if (!whisper_is_multilingual(ctx)) { + if (params.language != "en" || params.translate) { + params.language = "en"; + params.translate = false; + fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); + } + } + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", + __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, + params.n_threads, params.n_processors, + params.language.c_str(), + params.translate ? "translate" : "transcribe", + params.no_timestamps ? 
0 : 1); + + fprintf(stderr, "\n"); + } + + // run the inference + { + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + + wparams.print_realtime = false; + wparams.print_progress = params.print_progress; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special = params.print_special; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; + wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.token_timestamps = params.output_wts || params.max_len > 0; + wparams.thold_pt = params.word_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + + wparams.speed_up = params.speed_up; + + wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); + wparams.prompt_n_tokens = prompt_tokens.empty() ? 
0 : prompt_tokens.size(); + + whisper_print_user_data user_data = { ¶ms, &pcmf32s }; + + // this callback is called on each new segment + if (!wparams.print_realtime) { + wparams.new_segment_callback = whisper_print_segment_callback; + wparams.new_segment_callback_user_data = &user_data; + } + + // example for abort mechanism + // in this example, we do not abort the processing, but we could if the flag is set to true + // the callback is called before every encoder run - if it returns false, the processing is aborted + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) { + bool is_aborted = *(bool*)user_data; + return !is_aborted; + }; + wparams.encoder_begin_callback_user_data = &is_aborted; + } + + if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { + fprintf(stderr, "%s: failed to process audio\n", argv[0]); + return 10; + } + } + + // output stuff + { + printf("\n"); + + // output to text file + if (params.output_txt) { + const auto fname_txt = fname_inp + ".txt"; + output_txt(ctx, fname_txt.c_str()); + } + + // output to VTT file + if (params.output_vtt) { + const auto fname_vtt = fname_inp + ".vtt"; + output_vtt(ctx, fname_vtt.c_str()); + } + + // output to SRT file + if (params.output_srt) { + const auto fname_srt = fname_inp + ".srt"; + output_srt(ctx, fname_srt.c_str(), params); + } + + // output to WTS file + if (params.output_wts) { + const auto fname_wts = fname_inp + ".wts"; + output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); + } + } + } + + whisper_print_timings(ctx); + whisper_free(ctx); + + return 0; +} diff --git a/Examples/TranscribeCS/AnsiCodes.cs b/Examples/TranscribeCS/AnsiCodes.cs new file mode 100644 index 0000000..be04ce3 --- /dev/null +++ b/Examples/TranscribeCS/AnsiCodes.cs @@ -0,0 +1,68 @@ +using 
/// <summary>Switches the console of the current process into a mode where ANSI escape codes are interpreted.</summary>
/// <remarks>The feature requires Windows 10 or newer</remarks>
static class AnsiCodes
{
	const string dll = "kernel32.dll";

	const int STD_OUTPUT_HANDLE = -11;

	/// <summary>Console mode bits; input and output handles reuse the same numeric values.</summary>
	[Flags]
	enum ConsoleModes: uint
	{
		// Input flags
		ENABLE_PROCESSED_INPUT = 0x0001,
		ENABLE_LINE_INPUT = 0x0002,
		ENABLE_ECHO_INPUT = 0x0004,
		ENABLE_WINDOW_INPUT = 0x0008,
		ENABLE_MOUSE_INPUT = 0x0010,
		ENABLE_INSERT_MODE = 0x0020,
		ENABLE_QUICK_EDIT_MODE = 0x0040,
		ENABLE_EXTENDED_FLAGS = 0x0080,
		ENABLE_AUTO_POSITION = 0x0100,
		ENABLE_VIRTUAL_TERMINAL_INPUT = 0x0200,

		// Output flags
		ENABLE_PROCESSED_OUTPUT = 0x0001,
		ENABLE_WRAP_AT_EOL_OUTPUT = 0x0002,
		ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004,
		DISABLE_NEWLINE_AUTO_RETURN = 0x0008,
		ENABLE_LVB_GRID_WORLDWIDE = 0x0010
	}

	[DllImport( dll, SetLastError = true )]
	static extern IntPtr GetStdHandle( int nStdHandle );

	[DllImport( dll, SetLastError = true )]
	static extern bool GetConsoleMode( IntPtr hConsoleHandle, out ConsoleModes mode );

	[DllImport( dll, SetLastError = true )]
	static extern bool SetConsoleMode( IntPtr hConsoleHandle, ConsoleModes mode );

	/// <summary>True when the standard output console interprets ANSI escape codes.</summary>
	public static readonly bool enabled = false;

	static AnsiCodes()
	{
		enabled = tryEnable();
	}

	/// <summary>Query the mode of the standard output console, and turn on virtual terminal processing unless it is already on.</summary>
	/// <returns>False when any of the console API calls failed</returns>
	static bool tryEnable()
	{
		IntPtr invalidHandle = (IntPtr)( -1 );
		IntPtr stdOut = GetStdHandle( STD_OUTPUT_HANDLE );
		if( stdOut == invalidHandle )
			return false;

		ConsoleModes current;
		if( !GetConsoleMode( stdOut, out current ) )
			return false;

		// Already enabled, nothing to change
		if( current.HasFlag( ConsoleModes.ENABLE_VIRTUAL_TERMINAL_PROCESSING ) )
			return true;

		return SetConsoleMode( stdOut, current | ConsoleModes.ENABLE_VIRTUAL_TERMINAL_PROCESSING );
	}
}
namespace TranscribeCS
{
	/// <summary>Command-line options of the example, parsed from the argument array.</summary>
	/// <remarks>Any argument which doesn't start with '-' is treated as an input audio file.</remarks>
	sealed record class CommandLineArgs
	{
		public int n_threads = Environment.ProcessorCount;
		public int offset_t_ms = 0;
		public int offset_n = 0;
		public int duration_ms = 0;
		public int max_context = -1;
		public int max_len = 0;

		public float word_thold = 0.01f;

		public bool speed_up = false;
		public bool translate = false;
		public bool diarize = false;
		public bool output_txt = false;
		public bool output_vtt = false;
		public bool output_srt = false;
		public bool print_special = false;
		public bool print_progress = false;
		public bool print_colors = true;
		public bool no_timestamps = false;
		public int[]? prompt = null;

		public eLanguage language = eLanguage.English;
		public string model = string.Empty;
		public readonly List<string> fileNames = new List<string>();

		// This example has no switch for the WTS output, the value is always false
		const bool output_wts = false;

		/// <summary>Copy the parsed options into the parameters structure of the context</summary>
		public void apply( ref Parameters p )
		{
			p.setFlag( eFullParamsFlags.PrintRealtime, false );
			p.setFlag( eFullParamsFlags.PrintProgress, print_progress );
			p.setFlag( eFullParamsFlags.PrintTimestamps, !no_timestamps );
			p.setFlag( eFullParamsFlags.PrintSpecial, print_special );
			p.setFlag( eFullParamsFlags.Translate, translate );
			p.language = language;
			p.cpuThreads = n_threads;
			// A negative value keeps whatever the structure already has
			if( max_context >= 0 )
				p.n_max_text_ctx = max_context;
			p.offset_ms = offset_t_ms;
			p.duration_ms = duration_ms;
			p.setFlag( eFullParamsFlags.TokenTimestamps, output_wts || max_len > 0 );
			p.thold_pt = word_thold;
			p.max_len = output_wts && max_len == 0 ? 60 : max_len;
			p.setFlag( eFullParamsFlags.SpeedupAudio, speed_up );
		}

		/// <summary>Flags to request from the results, based on what is going to be printed</summary>
		public eResultFlags resultFlags()
		{
			eResultFlags flags = eResultFlags.None;
			bool wts = output_wts || max_len > 0;
			if( !no_timestamps || wts )
				flags |= eResultFlags.Timestamps;
			if( wts || print_colors )
				flags |= eResultFlags.Tokens;
			return flags;
		}

		static eLanguage parseLanguage( string lang ) =>
			Library.languageFromCode( lang ) ?? throw new ArgumentException( $"Unknown language code \"{lang}\"" );

		// Culture-independent integer parsing, consistent with how the float argument is parsed
		static int parseInt( string str ) =>
			int.Parse( str, CultureInfo.InvariantCulture );

		/// <summary>Parse the command line</summary>
		/// <exception cref="OperationCanceledException">The help message was requested and printed</exception>
		/// <exception cref="ArgumentException">Unknown, empty or incomplete argument, or a required argument is missing</exception>
		/// <exception cref="FileNotFoundException">The model file doesn't exist</exception>
		public CommandLineArgs( string[] argv )
		{
			int i = 0;

			// Consume the value which follows the option at argv[ i ];
			// a missing value is reported by name instead of an IndexOutOfRangeException
			string takeValue()
			{
				string option = argv[ i ];
				if( i + 1 >= argv.Length )
					throw new ArgumentException( $"Missing value for argument \"{option}\"" );
				return argv[ ++i ];
			}

			for( ; i < argv.Length; i++ )
			{
				string arg = argv[ i ];
				// An empty string has no first character to test
				if( arg.Length == 0 )
					throw new ArgumentException( "Empty command-line argument" );
				if( arg[ 0 ] != '-' )
				{
					fileNames.Add( arg );
					continue;
				}

				switch( arg )
				{
					case "-h":
					case "--help":
						printUsage();
						throw new OperationCanceledException();
					case "-t":
					case "--threads":
						n_threads = parseInt( takeValue() );
						break;
					case "-ot":
					case "--offset-t":
						offset_t_ms = parseInt( takeValue() );
						break;
					case "-on":
					case "--offset-n":
						offset_n = parseInt( takeValue() );
						break;
					case "-d":
					case "--duration":
						duration_ms = parseInt( takeValue() );
						break;
					case "-mc":
					case "--max-context":
						max_context = parseInt( takeValue() );
						break;
					case "-ml":
					case "--max-len":
						max_len = parseInt( takeValue() );
						break;
					case "-wt":
					case "--word-thold":
						word_thold = float.Parse( takeValue(), CultureInfo.InvariantCulture );
						break;
					case "-su":
					case "--speed-up":
						speed_up = true;
						break;
					case "-tr":
					case "--translate":
						translate = true;
						break;
					case "-di":
					case "--diarize":
						diarize = true;
						break;
					case "-otxt":
					case "--output-txt":
						output_txt = true;
						break;
					case "-ovtt":
					case "--output-vtt":
						output_vtt = true;
						break;
					case "-osrt":
					case "--output-srt":
						output_srt = true;
						break;
					case "-ps":
					case "--print-special":
						print_special = true;
						break;
					case "-nc":
					case "--no-colors":
						print_colors = false;
						break;
					case "-pp":
					case "--print-progress":
						print_progress = true;
						break;
					case "-nt":
					case "--no-timestamps":
						no_timestamps = true;
						break;
					case "-l":
					case "--language":
						language = parseLanguage( takeValue() );
						break;
					case "--prompt":
						prompt = parsePrompt( takeValue() );
						break;
					case "-m":
					case "--model":
						model = takeValue();
						break;
					case "-f":
					case "--file":
						fileNames.Add( takeValue() );
						break;
					default:
						throw new ArgumentException( $"Unknown argument: \"{arg}\"" );
				}
			}

			if( string.IsNullOrWhiteSpace( model ) )
				throw new ArgumentException( "The model file is not provided in the arguments" );
			if( !File.Exists( model ) )
				throw new FileNotFoundException( "Model not found", model );
			if( fileNames.Count <= 0 )
				throw new ArgumentException( "Please supply at least 1 input audio file to process" );
		}

		static string cstr( bool b ) => b.ToString();

		static int[]? parsePrompt( string str )
		{
			if( string.IsNullOrWhiteSpace( str ) )
				return null;
			// TODO: expose whisper_tokenize function, as a method of iModel COM interface
			throw new NotImplementedException();
		}

		/// <summary>Print the list of supported options; the bracketed values are the current field values</summary>
		void printUsage()
		{
			Console.WriteLine();

			Console.WriteLine( "usage: {0} [options] file0.mp3 file1.wma ...", Path.GetFileName( Assembly.GetExecutingAssembly().Location ) );
			Console.WriteLine();
			Console.WriteLine( "options:" );
			Console.WriteLine( " -h, --help [default] show this help message and exit" );
			Console.WriteLine( " -t N, --threads N [{0,-7:D}] number of threads to use during computation", n_threads );
			Console.WriteLine( " -ot N, --offset-t N [{0,-7:D}] time offset in milliseconds", offset_t_ms );
			Console.WriteLine( " -on N, --offset-n N [{0,-7:D}] segment index offset", offset_n );
			Console.WriteLine( " -d N, --duration N [{0,-7:D}] duration of audio to process in milliseconds", duration_ms );
			Console.WriteLine( " -mc N, --max-context N [{0,-7:D}] maximum number of text context tokens to store", max_context );
			Console.WriteLine( " -ml N, --max-len N [{0,-7:D}] maximum segment length in characters", max_len );
			Console.WriteLine( " -wt N, --word-thold N [{0,-7:F2}] word timestamp probability threshold", word_thold );
			Console.WriteLine( " -su, --speed-up [{0,-7}] speed up audio by x2 (reduced accuracy)", cstr( speed_up ) );
			Console.WriteLine( " -tr, --translate [{0,-7}] translate from source language to english", cstr( translate ) );
			Console.WriteLine( " -di, --diarize [{0,-7}] stereo audio diarization", cstr( diarize ) );
			Console.WriteLine( " -otxt, --output-txt [{0,-7}] output result in a text file", cstr( output_txt ) );
			Console.WriteLine( " -ovtt, --output-vtt [{0,-7}] output result in a vtt file", cstr( output_vtt ) );
			Console.WriteLine( " -osrt, --output-srt [{0,-7}] output result in a srt file", cstr( output_srt ) );
			Console.WriteLine( " -ps, --print-special [{0,-7}] print special tokens", cstr( print_special ) );
			Console.WriteLine( " -nc, --no-colors [{0,-7}] do not print colors", cstr( !print_colors ) );
			// This switch is accepted by the parser, so it belongs in the help as well
			Console.WriteLine( " -pp, --print-progress [{0,-7}] print progress", cstr( print_progress ) );
			Console.WriteLine( " -nt, --no-timestamps [{0,-7}] do not print timestamps", cstr( no_timestamps ) );
			Console.WriteLine( " -l LANG, --language LANG [{0,-7}] spoken language", language.getCode() );
			Console.WriteLine( " --prompt PROMPT [ ] initial prompt" );
			Console.WriteLine( " -m FNAME, --model FNAME [{0,-7}] model path", model );
			Console.WriteLine( " -f FNAME, --file FNAME [{0,-7}] path of the input audio file", "" );
		}
	}
}
namespace TranscribeCS
{
	/// <summary>Implementation of Callbacks abstract class, to print these segments as soon as they’re produced by the library.</summary>
	sealed class Transcribe: Callbacks
	{
		readonly CommandLineArgs args;
		readonly eResultFlags resultFlags;

		/// <summary>Remember the options, and switch the console output to UTF-8</summary>
		public Transcribe( CommandLineArgs args )
		{
			this.args = args;
			resultFlags = args.resultFlags();
			Console.OutputEncoding = System.Text.Encoding.UTF8;
		}

		// Escape sequence written after every colored token
		const string colorReset = "\x1B[0m";

		// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
		// Lowest is red, middle is yellow, highest is green.
		// The table never changes, a single shared copy is enough.
		static readonly string[] k_colors = new string[]
		{
			"\x1B[38;5;196m", "\x1B[38;5;202m", "\x1B[38;5;208m", "\x1B[38;5;214m", "\x1B[38;5;220m",
			"\x1B[38;5;226m", "\x1B[38;5;190m", "\x1B[38;5;154m", "\x1B[38;5;118m", "\x1B[38;5;82m"
		};

		/// <summary>Index into the color table for a token: cube of the probability scaled to the table size, clamped to the valid range</summary>
		static int colorIndex( in sToken tok )
		{
			float p = tok.probability;
			float p3 = p * p * p;
			int col = (int)( p3 * k_colors.Length );
			col = Math.Clamp( col, 0, k_colors.Length - 1 );
			return col;
		}

		/// <summary>Format the time as hh:mm:ss.fff</summary>
		public static string printTime( TimeSpan ts ) =>
			ts.ToString( "hh':'mm':'ss'.'fff", CultureInfo.InvariantCulture );

		/// <summary>Format the time as hh:mm:ss,fff</summary>
		public static string printTimeWithComma( TimeSpan ts ) =>
			ts.ToString( "hh':'mm':'ss','fff", CultureInfo.InvariantCulture );

		/// <summary>Write the tokens of the segment, each one colored by its probability</summary>
		/// <remarks>Special tokens are skipped unless the user asked to print them</remarks>
		void writeColoredTokens( TranscribeResult res, sSegment seg )
		{
			foreach( sToken tok in res.getTokens( seg ) )
			{
				if( !args.print_special && tok.hasFlag( eTokenFlags.Special ) )
					continue;
				Console.Write( "{0}{1}{2}", k_colors[ colorIndex( tok ) ], tok.text, colorReset );
			}
		}

		/// <summary>Print the last <paramref name="countNew" /> segments of the results</summary>
		protected override void onNewSegment( Context sender, int countNew )
		{
			TranscribeResult res = sender.results( resultFlags );

			int s0 = res.segments.Length - countNew;
			// Empty line before the very first segment
			if( s0 == 0 )
				Console.WriteLine();

			for( int i = s0; i < res.segments.Length; i++ )
			{
				sSegment seg = res.segments[ i ];

				if( args.no_timestamps )
				{
					// Without timestamps the text is written as a continuous stream, no line breaks
					if( args.print_colors && AnsiCodes.enabled )
						writeColoredTokens( res, seg );
					else
						Console.Write( seg.text );
					Console.Out.Flush();
					continue;
				}

				string speaker = "";
#if false
				if( args.diarize && pcmf32s.size() == 2 )
				{
					const size_t n_samples = pcmf32s[ 0 ].size();
					const int64_t is0 = SourceAudio::sampleFromTimestamp( seg.time.begin, n_samples );
					const int64_t is1 = SourceAudio::sampleFromTimestamp( seg.time.end, n_samples );

					double energy0 = 0.0f;
					double energy1 = 0.0f;

					for( int64_t j = is0; j < is1; j++ )
					{
						energy0 += fabs( pcmf32s[ 0 ][ j ] );
						energy1 += fabs( pcmf32s[ 1 ][ j ] );
					}

					if( energy0 > 1.1 * energy1 )
						speaker = "(speaker 0)";
					else if( energy1 > 1.1 * energy0 )
						speaker = "(speaker 1)";
					else
						speaker = "(speaker ?)";

					//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
				}
#endif
				if( args.print_colors && AnsiCodes.enabled )
				{
					// The speaker label belongs to the segment: write it once after the time range,
					// same position as in the plain-text branch below, not in front of every token
					Console.Write( "[{0} --> {1}] {2}", printTime( seg.time.begin ), printTime( seg.time.end ), speaker );
					writeColoredTokens( res, seg );
					Console.WriteLine();
				}
				else
					Console.WriteLine( "[{0} --> {1}] {2}{3}", printTime( seg.time.begin ), printTime( seg.time.end ), speaker, seg.text );
			}
		}
	}
}
namespace TranscribeCS
{
	/// <summary>Entry point of the example: load the model, transcribe every input file, optionally save the text next to the audio.</summary>
	static class Program
	{
		// True to decode the audio while transcribing, false to load the complete file first
		static readonly bool streamAudio = true;

		/// <summary>Parse the command line; null means the help was printed and there is nothing else to do</summary>
		static CommandLineArgs? parseArguments( string[] args )
		{
			try
			{
				return new CommandLineArgs( args );
			}
			catch( OperationCanceledException )
			{
				return null;
			}
		}

		/// <summary>Transcribe one file, then write the output files requested on the command line</summary>
		static void processFile( Context context, iMediaFoundation mf, Transcribe transcribe, CommandLineArgs cla, string audioFile )
		{
			if( streamAudio )
			{
				using iAudioReader reader = mf.openAudioFile( audioFile );
				context.runFull( reader, transcribe, null, cla.prompt );
			}
			else
			{
				using iAudioBuffer buffer = mf.loadAudioFile( audioFile );
				context.runFull( buffer, transcribe, cla.prompt );
			}

			// When asked to, produce these text files
			if( cla.output_txt )
				writeTextFile( context, audioFile );
			if( cla.output_srt )
				writeSubRip( context, audioFile, cla );
			if( cla.output_vtt )
				writeWebVTT( context, audioFile );
		}

		/// <returns>0 on success, 1 when only the help was printed, HRESULT of the exception on failure</returns>
		static int Main( string[] args )
		{
			try
			{
				CommandLineArgs? cla = parseArguments( args );
				if( cla is null )
					return 1;

				const eLoggerFlags loggerFlags = eLoggerFlags.UseStandardError | eLoggerFlags.SkipFormatMessage;
				Library.setLogSink( eLogLevel.Debug, loggerFlags );

				using iModel model = Library.loadModel( cla.model );
				using Context context = model.createContext();
				cla.apply( ref context.parameters );
				using iMediaFoundation mf = Library.initMediaFoundation();
				Transcribe transcribe = new Transcribe( cla );

				foreach( string audioFile in cla.fileNames )
					processFile( context, mf, transcribe, cla, audioFile );

				context.timingsPrint();
				return 0;
			}
			catch( Exception ex )
			{
				Console.WriteLine( ex.Message );
				return ex.HResult;
			}
		}

		/// <summary>Write the text of every segment, one per line, into a *.txt file next to the audio</summary>
		static void writeTextFile( Context context, string audioPath )
		{
			string path = Path.ChangeExtension( audioPath, ".txt" );
			using( var writer = File.CreateText( path ) )
			{
				foreach( sSegment segment in context.results().segments )
					writer.WriteLine( segment.text );
			}
		}

		/// <summary>Write the segments into a *.srt file next to the audio; entry numbers start at 1 + offset_n</summary>
		static void writeSubRip( Context context, string audioPath, CommandLineArgs cliArgs )
		{
			string path = Path.ChangeExtension( audioPath, ".srt" );
			using( var writer = File.CreateText( path ) )
			{
				int index = 0;
				foreach( sSegment segment in context.results( eResultFlags.Timestamps ).segments )
				{
					writer.WriteLine( index + 1 + cliArgs.offset_n );
					string from = Transcribe.printTimeWithComma( segment.time.begin );
					string to = Transcribe.printTimeWithComma( segment.time.end );
					writer.WriteLine( "{0} --> {1}", from, to );
					writer.WriteLine( segment.text );
					writer.WriteLine();
					index++;
				}
			}
		}

		/// <summary>Write the segments into a *.vtt file next to the audio</summary>
		static void writeWebVTT( Context context, string audioPath )
		{
			string path = Path.ChangeExtension( audioPath, ".vtt" );
			using( var writer = File.CreateText( path ) )
			{
				writer.WriteLine( "WEBVTT" );
				writer.WriteLine();

				foreach( sSegment segment in context.results( eResultFlags.Timestamps ).segments )
				{
					string from = Transcribe.printTime( segment.time.begin );
					string to = Transcribe.printTime( segment.time.end );
					writer.WriteLine( "{0} --> {1}", from, to );
					writer.WriteLine( segment.text );
					writer.WriteLine();
				}
			}
		}
	}
}
<Project Sdk="Microsoft.NET.Sdk">

  <!-- What is being built: a 64-bit console executable for .NET 6 on Windows -->
  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net6.0-windows</TargetFramework>
    <Platforms>x64</Platforms>
    <AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath>
  </PropertyGroup>

  <!-- Language and compiler options -->
  <PropertyGroup>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
    <CheckForOverflowUnderflow>true</CheckForOverflowUnderflow>
  </PropertyGroup>

  <!-- Whisper.dll of the matching configuration is copied next to the executable -->
  <ItemGroup>
    <Content Include="..\..\x64\$(Configuration)\Whisper.dll" Link="Whisper.dll">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </Content>
  </ItemGroup>

  <!-- The WhisperNet project in this repository -->
  <ItemGroup>
    <ProjectReference Include="..\..\WhisperNet\WhisperNet.csproj" />
  </ItemGroup>

</Project>
\ No newline at end of file diff --git a/Examples/WhisperDesktop/AppState.cpp b/Examples/WhisperDesktop/AppState.cpp new file mode 100644 index 0000000..6697e6a --- /dev/null +++ b/Examples/WhisperDesktop/AppState.cpp @@ -0,0 +1,192 @@ +#include "stdafx.h" +#include "AppState.h" +#include "Utils/miscUtils.h" +#include <commctrl.h> +#pragma comment(lib, "Comctl32.lib") +// #pragma comment(linker,"/manifestdependency:\"type='win32' name='Microsoft.Windows.Common-Controls' version='6.0.0.0' processorArchitecture='amd64' publicKeyToken='6595b64144ccf1df' language='*'\"") +#include "CircleIndicator.h" + +namespace +{ + static const HKEY regKeyRoot = HKEY_CURRENT_USER; + const LPCTSTR regKey = LR"(SOFTWARE\const.me\WhisperDesktop)"; + const LPCTSTR regValPath = L"modelPath"; + const LPCTSTR regValImpl = L"modelImpl"; + const LPCTSTR regValLang = L"language"; + const LPCTSTR regValLastScreen = L"screen"; + + static HRESULT readString( CRegKey& k, LPCTSTR name, CString& rdi ) + { + ULONG nChars = 0; + LSTATUS lss = k.QueryStringValue( name, nullptr, &nChars ); + if( lss != ERROR_SUCCESS ) + return HRESULT_FROM_WIN32( lss ); + if( nChars == 0 ) + { + rdi = L""; + return S_FALSE; + } + + lss = k.QueryStringValue( name, rdi.GetBufferSetLength( nChars ), &nChars ); + rdi.ReleaseBuffer(); + if( lss != ERROR_SUCCESS ) + return HRESULT_FROM_WIN32( lss ); + + return S_OK; + } + + using Whisper::eModelImplementation; +} + +HRESULT AppState::startup() +{ + HRESULT hr = CoInitializeEx( nullptr, COINIT_MULTITHREADED ); + if( FAILED( hr ) ) + { + reportFatalError( "CoInitializeEx failed", hr ); + return hr; + } + coInit = true; + + LSTATUS lss = registryKey.Create( regKeyRoot, regKey ); + if( lss != ERROR_SUCCESS ) + { + hr = HRESULT_FROM_WIN32( lss ); + reportFatalError( "Unable to open the registry key", hr ); + return hr; + } + + INITCOMMONCONTROLSEX init; + init.dwSize = sizeof( init ); + init.dwICC = ICC_LINK_CLASS | ICC_PROGRESS_CLASS | ICC_STANDARD_CLASSES | ICC_TAB_CLASSES; + 
const BOOL icc = InitCommonControlsEx( &init ); + if( !icc ) + { + reportFatalError( "InitCommonControlsEx failed", HRESULT_FROM_WIN32( GetLastError() ) ); + return E_FAIL; + } + + hr = initMediaFoundation( &mediaFoundation ); + if( FAILED( hr ) ) + { + reportFatalError( "Unable to initialize Media Foundation runtime", hr ); + return hr; + } + + hr = console.initialize(); + if( FAILED( hr ) ) + { + reportFatalError( "Unable to initialize logging", hr ); + return hr; + } + + hr = CircleIndicator::registerClass(); + if( FAILED( hr ) ) + { + reportFatalError( "Unable to register custom controls", hr ); + return hr; + } + appIcon.LoadIcon( IDI_WHISPERDESKTOP ); + return S_OK; +} + +AppState::~AppState() +{ + if( coInit ) + { + CoUninitialize(); + coInit = false; + } +} + +HRESULT AppState::findModelSource() +{ + CHECK( readString( registryKey, regValPath, source.path ) ); + + { + CAtlFile file; + CHECK( file.Create( source.path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING ) ); + ULONGLONG len; + CHECK( file.GetSize( len ) ); + source.sizeInBytes = len; + } + + CString impl; + CHECK( readString( registryKey, regValImpl, impl ) ); + CHECK( implParse( impl, source.impl ) ); + source.found = true; + return S_OK; +} + +HRESULT AppState::saveModelSource() +{ + LSTATUS lss = registryKey.SetStringValue( regValPath, source.path ); + if( lss != ERROR_SUCCESS ) + return HRESULT_FROM_WIN32( lss ); + + LPCTSTR impl = implString( source.impl ); + if( nullptr == impl ) + return E_INVALIDARG; + lss = registryKey.SetStringValue( regValImpl, impl ); + if( lss != ERROR_SUCCESS ) + return HRESULT_FROM_WIN32( lss ); + + return S_OK; +} + +uint32_t AppState::languageRead() +{ + DWORD dw; + LSTATUS lss = registryKey.QueryDWORDValue( regValLang, dw ); + if( lss == ERROR_SUCCESS ) + return dw; + return UINT_MAX; +} + +void AppState::languageWrite( uint32_t key ) +{ + registryKey.SetDWORDValue( regValLang, key ); +} + +CString AppState::stringLoad( LPCTSTR name ) +{ + CString res; + 
readString( registryKey, name, res ); + return res; +} +void AppState::stringStore( LPCTSTR name, LPCTSTR value ) +{ + registryKey.SetStringValue( name, value ); +} +uint32_t AppState::dwordLoad( LPCTSTR name, uint32_t fallback ) +{ + DWORD dw; + LSTATUS lss = registryKey.QueryDWORDValue( name, dw ); + if( lss == ERROR_SUCCESS ) + return dw; + return fallback; +} +void AppState::dwordStore( LPCTSTR name, uint32_t value ) +{ + registryKey.SetDWORDValue( name, value ); +} + +void AppState::lastScreenSave( HRESULT code ) +{ + dwordStore( regValLastScreen, (uint32_t)code ); +} + +HRESULT AppState::lastScreenLoad() +{ + return (HRESULT)dwordLoad( regValLastScreen, SCREEN_TRANSCRIBE ); +} + +void AppState::setupIcon( CWindow* wnd ) +{ + HICON ic = appIcon; + if( nullptr != ic ) + { + wnd->SendMessage( WM_SETICON, ICON_SMALL, (LPARAM)ic ); + wnd->SendMessage( WM_SETICON, ICON_BIG, (LPARAM)ic ); + } +}
// Examples/WhisperDesktop/AppState.h
#pragma once
#include "Utils/DebugConsole.h"

// Application-wide state of the desktop example: COM initialization flag, the registry key with
// persisted preferences, the application icon, Media Foundation and the loaded model.
class AppState
{
	// True after startup() has successfully called CoInitializeEx; the destructor then calls CoUninitialize
	bool coInit = false;
	// Registry key created by startup(); all xxxLoad / xxxStore methods use it
	CRegKey registryKey;
	// Icon loaded by startup(), applied to windows with setupIcon()
	CIcon appIcon;
public:

	// Where the model comes from
	struct ModelSource
	{
		// Path of the model file
		CString path;
		// Set by findModelSource() once path and implementation were read, and the file was opened
		bool found = false;
		Whisper::eModelImplementation impl = (Whisper::eModelImplementation)0;
		// Size of the model file
		uint64_t sizeInBytes = 0;
	};
	ModelSource source;

	DebugConsole console;
	CComPtr<Whisper::iMediaFoundation> mediaFoundation;
	CComPtr<Whisper::iModel> model;

	~AppState();

	// Setup the initial things
	HRESULT startup();

	// Load model path and implementation from registry into `source`
	HRESULT findModelSource();

	// Save model path and implementation from `source` into registry
	HRESULT saveModelSource();

	// Saved language; languageRead() returns UINT_MAX when the value is not there
	uint32_t languageRead();
	void languageWrite( uint32_t key );

	// Generic access to the values under the registry key.
	// Loads return an empty string / the fallback on errors, stores ignore errors.
	CString stringLoad( LPCTSTR name );
	void stringStore( LPCTSTR name, LPCTSTR value );
	uint32_t dwordLoad( LPCTSTR name, uint32_t fallback );
	void dwordStore( LPCTSTR name, uint32_t value );

	// When true and the model source was found, the load model dialog starts loading without waiting for the user
	bool automaticallyLoadModel = true;

	// Persist / retrieve the last screen, one of the SCREEN_xxx constants below
	void lastScreenSave( HRESULT code );
	HRESULT lastScreenLoad();

	// Apply the application icon to the window
	void setupIcon( CWindow* wnd );
};

// Screen identifiers, returned by the show() methods of the dialogs.
// These are positive, i.e. successful, HRESULT values.
constexpr HRESULT SCREEN_MODEL = 1;
constexpr HRESULT SCREEN_TRANSCRIBE = 2;
constexpr HRESULT SCREEN_CAPTURE = 3;
\ No newline at end of file diff --git a/Examples/WhisperDesktop/CaptureDlg.cpp b/Examples/WhisperDesktop/CaptureDlg.cpp new file mode 100644 index 0000000..f1030dd --- /dev/null +++ b/Examples/WhisperDesktop/CaptureDlg.cpp @@ -0,0 +1,505 @@ +#include "stdafx.h" +#include "CaptureDlg.h" + +HRESULT CaptureDlg::show() +{ + auto res = DoModal( nullptr ); + if( res == -1 ) + return HRESULT_FROM_WIN32( GetLastError() ); + switch( res ) + { + case IDC_BACK: + return SCREEN_MODEL; + case IDC_TRANSCRIBE: + return SCREEN_TRANSCRIBE; + } + return S_OK; +} + +static const LPCTSTR regValDevice = L"captureDevice"; +static const LPCTSTR regValOutPath = L"captureTextFile"; +static const LPCTSTR regValOutFormat = L"captureTextFlags"; + +enum struct CaptureDlg::eTextFlags : uint32_t +{ + Save = 1, + Append = 2, + Timestamps = 4, +}; + +LRESULT CaptureDlg::OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled ) +{ + // First DDX call, hooks up variables to controls. + DoDataExchange( false ); + + languageSelector.initialize( m_hWnd, IDC_LANGUAGE, appState ); + cbTranslate.initialize( m_hWnd, IDC_TRANSLATE, appState ); + cbConsole.initialize( m_hWnd, IDC_CONSOLE, appState ); + + pendingState.initialize( + // Controls to disable while pending, re-enable afterwards + { + languageSelector, + cbCaptureDevice, + checkSave, checkAppend, checkTimestamps, transcribeOutputPath, transcribeOutputBrowse, + GetDlgItem( IDC_DEV_REFRESH ), + GetDlgItem( IDC_BACK ), + GetDlgItem( IDC_TRANSCRIBE ), + GetDlgItem( IDCANCEL ), + }, + // Controls to show while pending, hide afterwards + { + voiceActivity, GetDlgItem( IDC_VOICE_ACTIVITY_LBL ), + transcribeActivity, GetDlgItem( IDC_TRANS_LBL ), + stalled, GetDlgItem( IDC_STALL_LBL ), + progressBar, + } ); + + stalled.setActiveColor( flipRgb( 0xffcc33 ) ); + + HRESULT hr = work.create( this ); + if( FAILED( hr ) ) + { + reportError( m_hWnd, L"CreateThreadpoolWork failed", nullptr, hr ); + EndDialog( IDCANCEL ); + } + + listDevices(); + 
selectDevice( appState.stringLoad( regValDevice ) ); + + constexpr uint32_t defaultFlags = (uint32_t)eTextFlags::Append; + uint32_t flags = appState.dwordLoad( regValOutFormat, defaultFlags ); + if( flags & (uint32_t)eTextFlags::Save ) + checkSave.SetCheck( BST_CHECKED ); + if( flags & (uint32_t)eTextFlags::Append ) + checkAppend.SetCheck( BST_CHECKED ); + if( flags & (uint32_t)eTextFlags::Timestamps ) + checkTimestamps.SetCheck( BST_CHECKED ); + + transcribeOutputPath.SetWindowText( appState.stringLoad( regValOutPath ) ); + onSaveTextCheckbox(); + + appState.lastScreenSave( SCREEN_CAPTURE ); + appState.setupIcon( this ); + ATLVERIFY( CenterWindow() ); + return 0; +} + +HRESULT __stdcall CaptureDlg::listDevicesCallback( int len, const Whisper::sCaptureDevice* buffer, void* pv ) noexcept +{ + std::vector<sCaptureDevice>& devices = *( std::vector<sCaptureDevice> * )pv; + devices.resize( len ); + for( int i = 0; i < len; i++ ) + { + devices[ i ].displayName = buffer[ i ].displayName; + devices[ i ].endpoint = buffer[ i ].endpoint; + } + return S_OK; +} + +bool CaptureDlg::listDevices() +{ + appState.mediaFoundation->listCaptureDevices( &listDevicesCallback, &devices ); + cbCaptureDevice.ResetContent(); + for( const auto& dev : devices ) + cbCaptureDevice.AddString( dev.displayName ); + return !devices.empty(); +} + +void CaptureDlg::onDeviceRefresh() +{ + // Save the current selection + const int curSel = cbCaptureDevice.GetCurSel(); + CString str; + if( curSel >= 0 && curSel < (int)devices.size() ) + str = std::move( devices[ curSel ].endpoint ); + + // Refresh + listDevices(); + + // Restore the selection + selectDevice( str ); + + const size_t len = devices.size(); + if( len == 0 ) + { + MessageBox( L"No capture devices found on this computer.\nIf you have a USB microphone, connect it to this PC,\nand press “refresh” button.", + L"Capture Devices", MB_OK | MB_ICONWARNING ); + } + else + { + const char* suffix = ( len != 1 ) ? 
"s" : ""; + str.Format( L"Detected %zu audio capture device%S.", len, suffix ); + MessageBox( str, L"Capture Devices", MB_OK | MB_ICONINFORMATION ); + } +} + +bool CaptureDlg::selectDevice( LPCTSTR endpoint ) +{ + if( nullptr != endpoint && 0 != *endpoint ) + { + for( size_t i = 0; i < devices.size(); i++ ) + { + if( devices[ i ].endpoint == endpoint ) + { + cbCaptureDevice.SetCurSel( (int)i ); + return true; + } + } + } + + if( !devices.empty() ) + cbCaptureDevice.SetCurSel( 0 ); + return false; +} + +void CaptureDlg::onSaveTextCheckbox() +{ + const BOOL enabled = ( checkSave.GetCheck() == BST_CHECKED ); + std::array<HWND, 4> controls = { checkAppend, checkTimestamps, transcribeOutputPath, transcribeOutputBrowse }; + for( HWND w : controls ) + ::EnableWindow( w, enabled ); +} + +void CaptureDlg::onBrowseResult() +{ + LPCTSTR title = L"Output Text File"; + LPCTSTR outputFilters = L"Text files (*.txt)\0*.txt\0\0"; + CString path; + transcribeOutputPath.GetWindowText( path ); + if( !getSaveFileName( m_hWnd, title, outputFilters, path ) ) + return; + + LPCTSTR ext = PathFindExtension( path ); + if( 0 == *ext ) + { + wchar_t* const buffer = path.GetBufferSetLength( path.GetLength() + 5 ); + PathRenameExtension( buffer, L".txt" ); + path.ReleaseBuffer(); + } + + transcribeOutputPath.SetWindowText( path ); +} + +CaptureDlg::eTextFlags CaptureDlg::getOutputFlags() +{ + uint32_t flags = 0; + if( checkSave.GetCheck() == BST_CHECKED ) + flags |= (uint32_t)eTextFlags::Save; + if( checkAppend.GetCheck() == BST_CHECKED ) + flags |= (uint32_t)eTextFlags::Append; + if( checkTimestamps.GetCheck() == BST_CHECKED ) + flags |= (uint32_t)eTextFlags::Timestamps; + return (eTextFlags)flags; +} + +void CaptureDlg::setPending( bool nowPending ) +{ + pendingState.setPending( nowPending ); + if( nowPending ) + { + progressBar.SetMarquee( TRUE, 0 ); + btnRunCapture.SetWindowText( L"Stop" ); + } + else + { + progressBar.SetMarquee( FALSE, 0 ); + btnRunCapture.SetWindowText( L"Capture" ); + 
btnRunCapture.EnableWindow( TRUE ); + captureRunning = false; + } +} + +void CaptureDlg::onRunCapture() +{ + if( captureRunning ) + { + threadState.stopRequested = true; + btnRunCapture.EnableWindow( FALSE ); + return; + } + + int dev = cbCaptureDevice.GetCurSel(); + if( dev < 0 || dev >= (int)devices.size() ) + { + showError( L"Please select a capture device", S_FALSE ); + return; + } + threadState.endpoint = devices[ dev ].endpoint; + threadState.language = languageSelector.selectedLanguage(); + threadState.translate = cbTranslate.checked(); + if( isInvalidTranslate( m_hWnd, threadState.language, threadState.translate ) ) + return; + + threadState.flags = getOutputFlags(); + if( (uint32_t)threadState.flags & (uint32_t)eTextFlags::Save ) + { + transcribeOutputPath.GetWindowText( threadState.textOutputPath ); + if( threadState.textOutputPath.GetLength() <= 0 ) + { + showError( L"Please specify the output text file", S_FALSE ); + return; + } + appState.stringStore( regValOutPath, threadState.textOutputPath ); + } + else + cbConsole.ensureChecked(); + + languageSelector.saveSelection( appState ); + cbTranslate.saveSelection( appState ); + appState.stringStore( regValDevice, threadState.endpoint ); + appState.dwordStore( regValOutFormat, (uint32_t)threadState.flags ); + + captureRunning = true; + threadState.errorMessage = L""; + threadState.stopRequested = false; + threadState.captureParams.minDuration = 7; + threadState.captureParams.maxDuration = 11; + setPending( true ); + work.post(); +} + +void __declspec( noinline ) CaptureDlg::getThreadError() +{ + getLastError( threadState.errorMessage ); +} + +#define CHECK_EX( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) { getThreadError(); return __hr; } } + +static HRESULT appendDate( CString& str, const SYSTEMTIME& time ) +{ + constexpr DWORD dateFlags = DATE_LONGDATE; + int cc = GetDateFormatEx( LOCALE_NAME_USER_DEFAULT, dateFlags, &time, nullptr, nullptr, 0, nullptr ); + if( 0 == cc ) + return getLastHr(); 
+ + const int oldLength = str.GetLength(); + wchar_t* const buffer = str.GetBufferSetLength( oldLength + cc ); + cc = GetDateFormatEx( LOCALE_NAME_USER_DEFAULT, dateFlags, &time, nullptr, buffer + oldLength, cc, nullptr ); + if( 0 != cc ) + { + str.ReleaseBuffer(); + return S_OK; + } + HRESULT hr = getLastHr(); + str.ReleaseBuffer(); + return hr; +} + +static HRESULT appendTime( CString& str, const SYSTEMTIME& time ) +{ + constexpr DWORD timeFlags = 0; + int cc = GetTimeFormatEx( LOCALE_NAME_USER_DEFAULT, timeFlags, &time, nullptr, nullptr, 0 ); + if( 0 == cc ) + return getLastHr(); + + const int oldLength = str.GetLength(); + wchar_t* const buffer = str.GetBufferSetLength( oldLength + cc ); + cc = GetTimeFormatEx( LOCALE_NAME_USER_DEFAULT, timeFlags, &time, nullptr, buffer + oldLength, cc ); + if( 0 != cc ) + { + str.ReleaseBuffer(); + return S_OK; + } + HRESULT hr = getLastHr(); + str.ReleaseBuffer(); + return hr; +} + +static HRESULT printDateTime( CAtlFile& file ) +{ + SYSTEMTIME time; + GetLocalTime( &time ); + + CString str; + str = L"==== Captured on "; + CHECK( appendDate( str, time ) ); + str += L", "; + CHECK( appendTime( str, time ) ); + str += L" ====\r\n"; + + CStringA u8; + makeUtf8( u8, str ); + return file.Write( cstr( u8 ), (DWORD)u8.GetLength() ); +} + +inline HRESULT CaptureDlg::runCapture() +{ + clearLastError(); + using namespace Whisper; + CComPtr<iAudioCapture> capture; + CHECK_EX( appState.mediaFoundation->openCaptureDevice( threadState.endpoint, threadState.captureParams, &capture ) ); + + HRESULT hr; + CAtlFile file; + const uint32_t flags = (uint32_t)threadState.flags; + if( flags & (uint32_t)eTextFlags::Save ) + { + const bool append = 0 != ( flags & (uint32_t)eTextFlags::Append ); + const DWORD creation = append ? 
OPEN_ALWAYS : CREATE_ALWAYS; + hr = file.Create( threadState.textOutputPath, GENERIC_WRITE, FILE_SHARE_READ, creation ); + if( FAILED( hr ) ) + { + threadState.errorMessage = L"Unable to create the output text file"; + return hr; + } + if( append ) + { + ULONGLONG size; + CHECK( file.GetSize( size ) ); + if( size == 0 ) + CHECK( writeUtf8Bom( file ) ) + else + CHECK( file.Seek( 0, SEEK_END ) ); + } + else + { + CHECK( writeUtf8Bom( file ) ); + } + + if( flags & (uint32_t)eTextFlags::Timestamps ) + CHECK( printDateTime( file ) ); + + threadState.file = &file; + } + else + threadState.file = nullptr; + + CComPtr<iContext> context; + CHECK_EX( appState.model->createContext( &context ) ); + + sFullParams fullParams; + CHECK_EX( context->fullDefaultParams( eSamplingStrategy::Greedy, &fullParams ) ); + fullParams.language = threadState.language; + fullParams.setFlag( eFullParamsFlags::Translate, threadState.translate ); + fullParams.resetFlag( eFullParamsFlags::PrintRealtime ); + fullParams.new_segment_callback = &newSegmentCallback; + fullParams.new_segment_callback_user_data = this; + + sCaptureCallbacks callbacks; + callbacks.shouldCancel = &cbCancel; + callbacks.captureStatus = &cbStatus; + callbacks.pv = this; + + CHECK_EX( context->runCapture( fullParams, callbacks, capture ) ); + threadState.file = nullptr; + + context->timingsPrint(); + return S_OK; +} + +void __stdcall CaptureDlg::poolCallback() noexcept +{ + const HRESULT hr = runCapture(); + PostMessage( WM_CALLBACK_COMPLETION, hr ); +} + +void CaptureDlg::showError( LPCTSTR text, HRESULT hr ) +{ + reportError( m_hWnd, text, L"Capture failed", hr ); +} + +LRESULT CaptureDlg::onThreadQuit( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled ) +{ + setPending( false ); + + const HRESULT hr = (HRESULT)wParam; + if( FAILED( hr ) ) + { + LPCTSTR failMessage = L"Capture failed"; + + if( threadState.errorMessage.GetLength() > 0 ) + { + CString tmp = failMessage; + tmp += L"\n"; + tmp += 
threadState.errorMessage; + showError( tmp, hr ); + } + else + showError( failMessage, hr ); + + return 0; + } + else + { + if( (uint32_t)threadState.flags & (uint32_t)eTextFlags::Save ) + ShellExecute( NULL, L"open", threadState.textOutputPath, NULL, NULL, SW_SHOW ); + } + + return 0; +} + +LRESULT CaptureDlg::onThreadStatus( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled ) +{ + using namespace Whisper; + const uint8_t newStatus = (uint8_t)wParam; + // Update the GUI + voiceActivity.setActive( 0 != ( newStatus & (uint8_t)eCaptureStatus::Voice ) ); + transcribeActivity.setActive( 0 != ( newStatus & (uint8_t)eCaptureStatus::Transcribing ) ); + stalled.setActive( 0 != ( newStatus & (uint8_t)eCaptureStatus::Stalled ) ); + return 0; +} + +HRESULT __stdcall CaptureDlg::cbCancel( void* pv ) noexcept +{ + CaptureDlg& dialog = *(CaptureDlg*)pv; + return dialog.threadState.stopRequested ? S_OK : S_FALSE; +} + +HRESULT __stdcall CaptureDlg::cbStatus( void* pv, Whisper::eCaptureStatus status ) noexcept +{ + CaptureDlg& dialog = *(CaptureDlg*)pv; + if( dialog.PostMessage( WM_CALLBACK_STATUS, (uint8_t)status ) ) + return S_OK; + return getLastHr(); +} + +HRESULT __cdecl CaptureDlg::newSegmentCallback( Whisper::iContext* ctx, uint32_t n_new, void* user_data ) noexcept +{ + using namespace Whisper; + CComPtr<iTranscribeResult> result; + const eResultFlags flags = eResultFlags::Timestamps | eResultFlags::Tokens; + CHECK( ctx->getResults( flags, &result ) ); + CHECK( logNewSegments( result, n_new ) ); + + CaptureDlg& dialog = *(CaptureDlg*)user_data; + return dialog.appendTextFile( result, n_new ); +} + +HRESULT CaptureDlg::appendTextFile( Whisper::iTranscribeResult* results, uint32_t newSegments ) +{ + if( nullptr == threadState.file || 0 == newSegments ) + return S_OK; + + using namespace Whisper; + sTranscribeLength length; + CHECK( results->getSize( length ) ); + + const size_t len = length.countSegments; + size_t i = len - newSegments; + + const sSegment* const 
segments = results->getSegments(); + CStringA str; + for( ; i < len; i++ ) + { + const sSegment& seg = segments[ i ]; + if( 0 != ( (uint32_t)threadState.flags & (uint32_t)eTextFlags::Timestamps ) ) + { + str = "["; + printTimeStamp( str, seg.time.begin ); + str += " --> "; + printTimeStamp( str, seg.time.end ); + str += "] "; + } + else + str = ""; + + str += seg.text; + str += "\r\n"; + + CHECK( threadState.file->Write( cstr( str ), (DWORD)str.GetLength() ) ); + } + + CHECK( threadState.file->Flush() ); + return S_OK; +}
// Examples/WhisperDesktop/CaptureDlg.h
#pragma once
#include "AppState.h"
#include "Utils/WTL/atlddx.h"
#include "Utils/miscUtils.h"
#include "Utils/LanguageDropdown.h"
#include "Utils/TranslateCheckbox.h"
#include "Utils/PendingState.h"
#include "CircleIndicator.h"

// The dialog which captures audio from a device and transcribes it.
// The capture runs in a thread pool work item, which reports back to the dialog with posted window messages.
class CaptureDlg :
	public CDialogImpl<CaptureDlg>,
	public CWinDataExchange<CaptureDlg>,
	public iThreadPoolCallback
{
	AppState& appState;

public:
	static constexpr UINT IDD = IDD_CAPTURE_DIALOG;
	// Posted by poolCallback() when the capture has ended; wParam is the HRESULT
	static constexpr UINT WM_CALLBACK_COMPLETION = WM_APP + 1;
	// Posted by cbStatus() while capturing; wParam has the Whisper::eCaptureStatus bits
	static constexpr UINT WM_CALLBACK_STATUS = WM_APP + 2;

	CaptureDlg( AppState& app ) : appState( app ) { }

	// Show the modal dialog; returns the SCREEN_xxx constant of the next screen, S_OK when closed, or a failed HRESULT
	HRESULT show();

	BEGIN_MSG_MAP( CaptureDlg )
		MESSAGE_HANDLER( WM_INITDIALOG, OnInitDialog )
		ON_BUTTON_CLICK( IDC_CONSOLE, cbConsole.click )
		ON_BUTTON_CLICK( IDC_DEV_REFRESH, onDeviceRefresh );
		ON_BUTTON_CLICK( IDC_BROWSE_RESULT, onBrowseResult );
		ON_BUTTON_CLICK( IDC_SAVE_TEXT, onSaveTextCheckbox );
		ON_BUTTON_CLICK( IDC_RUN_CAPTURE, onRunCapture );

		ON_BUTTON_CLICK( IDCANCEL, onClose )
		ON_BUTTON_CLICK( IDC_BACK, onBack )
		ON_BUTTON_CLICK( IDC_TRANSCRIBE, onTranscribe );

		MESSAGE_HANDLER( WM_CALLBACK_COMPLETION, onThreadQuit );
		MESSAGE_HANDLER( WM_CALLBACK_STATUS, onThreadStatus );
	END_MSG_MAP()

	BEGIN_DDX_MAP( CaptureDlg )
		DDX_CONTROL_HANDLE( IDC_DEVICE, cbCaptureDevice )
		DDX_CONTROL_HANDLE( IDC_RUN_CAPTURE, btnRunCapture );
		DDX_CONTROL_HANDLE( IDC_TRANSCRIBE_PROGRESS, progressBar );
		DDX_CONTROL_HANDLE( IDC_SAVE_TEXT, checkSave )
		DDX_CONTROL_HANDLE( IDC_SAVE_APPEND, checkAppend )
		DDX_CONTROL_HANDLE( IDC_SAVE_TIMESTAMPS, checkTimestamps )
		DDX_CONTROL_HANDLE( IDC_PATH_RESULT, transcribeOutputPath )
		DDX_CONTROL_HANDLE( IDC_BROWSE_RESULT, transcribeOutputBrowse );

		DDX_CONTROL( IDC_VOICE_ACTIVITY, voiceActivity );
		DDX_CONTROL( IDC_TRANS_STATUS, transcribeActivity );
		DDX_CONTROL( IDC_STALL_STATUS, stalled );

	END_DDX_MAP()

private:
	// Controls to disable / show while the capture is running
	PendingState pendingState;
	// Switch the GUI between the idle and the capturing states
	void setPending( bool nowPending );

	LRESULT OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled );

	// These three handlers close the dialog; the code passed to EndDialog is what show() translates into the next screen
	void onClose()
	{
		ATLVERIFY( EndDialog( IDCANCEL ) );
	}
	void onBack()
	{
		ATLVERIFY( EndDialog( IDC_BACK ) );
	}
	void onTranscribe()
	{
		ATLVERIFY( EndDialog( IDC_TRANSCRIBE ) );
	}

	// List capture devices, and populate the combobox
	bool listDevices();
	void onDeviceRefresh();
	// Select a device by the endpoint string; when not found, selects the first one and returns false
	bool selectDevice( LPCTSTR endpoint );

	// Callback for listCaptureDevices; `pv` is the pointer to the `devices` vector
	static HRESULT __stdcall listDevicesCallback( int len, const Whisper::sCaptureDevice* buffer, void* pv ) noexcept;
	ConsoleCheckbox cbConsole;
	LanguageDropdown languageSelector;
	TranslateCheckbox cbTranslate;
	CComboBox cbCaptureDevice;

	void onBrowseResult();

	// Bit flags for the text file output; defined in the .cpp file
	enum struct eTextFlags : uint32_t;
	CButton checkSave, checkAppend, checkTimestamps;
	CEdit transcribeOutputPath;
	CButton transcribeOutputBrowse;
	// Enable or disable the output-related controls, according to the "save" checkbox
	void onSaveTextCheckbox();
	// Collect the flags from the checkboxes
	eTextFlags getOutputFlags();

	CButton btnRunCapture;
	CProgressBarCtrl progressBar;
	ThreadPoolWork work;

	// A capture device, with the strings copied out of Whisper::sCaptureDevice
	struct sCaptureDevice
	{
		CString displayName;
		CString endpoint;
	};
	// Same order as the items of the combobox
	std::vector<sCaptureDevice> devices;

	void showError( LPCTSTR text, HRESULT hr );

	CircleIndicator voiceActivity;
	CircleIndicator transcribeActivity;
	CircleIndicator stalled;

	// Data shared between the GUI and the background work.
	// onRunCapture() fills it before posting the work; runCapture() and the callbacks consume it.
	struct sThreadState
	{
		// Set by the "Stop" button, polled by cbCancel()
		volatile bool stopRequested;
		bool translate;
		eTextFlags flags;
		// The output text file while runCapture() has one open, otherwise nullptr
		CAtlFile* file;
		uint32_t language;
		Whisper::sCaptureParams captureParams;
		CString endpoint;
		CString textOutputPath;
		// Error message for the user, produced by the background work
		CString errorMessage;
	};
	sThreadState threadState;
	// True between the start of the capture and the completion message
	bool captureRunning = false;

	void getThreadError();
	// Handler of the "Capture" / "Stop" button
	void onRunCapture();
	// The background work itself
	HRESULT runCapture();
	void __stdcall poolCallback() noexcept override final;

	LRESULT onThreadQuit( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled );
	LRESULT onThreadStatus( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled );

	// Callbacks passed to the capture; `pv` / `user_data` is the pointer to this dialog
	static HRESULT __stdcall cbCancel( void* pv ) noexcept;
	static HRESULT __stdcall cbStatus( void* pv, Whisper::eCaptureStatus status ) noexcept;

	static HRESULT __cdecl newSegmentCallback( Whisper::iContext* ctx, uint32_t n_new, void* user_data ) noexcept;

	// Append the last `newSegments` segments to the output text file, when the file is open
	HRESULT appendTextFile( Whisper::iTranscribeResult* results, uint32_t newSegments );
};
\ No newline at end of file diff --git a/Examples/WhisperDesktop/CircleIndicator.cpp b/Examples/WhisperDesktop/CircleIndicator.cpp new file mode 100644 index 0000000..593bd4a --- /dev/null +++ b/Examples/WhisperDesktop/CircleIndicator.cpp @@ -0,0 +1,118 @@ +#include "stdafx.h" +#include "CircleIndicator.h" +#include <atltypes.h> +#include "AppState.h" + +namespace +{ + static const LPCTSTR windowClassName = L"CircleIndicator"; + + // Font with these symbols, shipped with Windows since forever: + // https://learn.microsoft.com/en-us/typography/font-list/segoe-ui-symbol + static const LPCTSTR fontName = L"Segoe UI Symbol"; + + // Outlined circle + static const LPCTSTR circleOutline = L"⚪"; + // Filled circle + static const LPCTSTR circleFilled = L"⚫"; + + // Font size for that symbol font, in points + constexpr int fontSizePoints = 14; + + // Default active color for the indicator + constexpr uint32_t defaultActiveColor = 0x7FFF7F; +} + +CircleIndicator::CircleIndicator() : + activeColor( defaultActiveColor ) +{ } + +ATL::CWndClassInfo& CircleIndicator::GetWndClassInfo() +{ + // Use custom class style with CS_PARENTDC, and COLOR_3DFACE for the background + static ATL::CWndClassInfo wc = + { + { sizeof( WNDCLASSEX ), + CS_HREDRAW | CS_VREDRAW | CS_PARENTDC, + StartWindowProc, + 0, 0, NULL, NULL, NULL, (HBRUSH)( COLOR_3DFACE + 1 ), NULL, windowClassName, NULL }, + NULL, NULL, IDC_ARROW, TRUE, 0, _T( "" ) + }; + return wc; +} + +// Class registration +HRESULT CircleIndicator::registerClass() +{ + WNDPROC pUnusedWndSuperProc = nullptr; + ATOM a = GetWndClassInfo().Register( &pUnusedWndSuperProc ); + if( 0 != a ) + return S_OK; + return getLastHr(); +} + +HRESULT CircleIndicator::createFont( int height ) +{ + LOGFONT logFont; + memset( &logFont, 0, sizeof( logFont ) ); + logFont.lfHeight = height; + logFont.lfCharSet = ANSI_CHARSET; + logFont.lfOutPrecision = OUT_TT_PRECIS; + logFont.lfClipPrecision = CLIP_DEFAULT_PRECIS; + wcsncpy_s( logFont.lfFaceName, fontName, 
_TRUNCATE ); + font.CreateFontIndirect( &logFont ); + if( font ) + return S_OK; + return E_FAIL; +} + +void CircleIndicator::onDestroy() +{ + if( font ) + font.DeleteObject(); +} + +void CircleIndicator::onPaint( CDCHandle dc ) +{ + CRect rectInt32; + GetClientRect( &rectInt32 ); + + CPaintDC pdc( m_hWnd ); + + const int logPixels = pdc.GetDeviceCaps( LOGPIXELSY ); + int fontSize = -MulDiv( fontSizePoints, logPixels, 72 ); + if( !font || fontHeight != fontSize ) + { + if( font ) + font.DeleteObject(); + HRESULT hr = createFont( fontSize ); + if( FAILED( hr ) ) + return; + fontHeight = fontSize; + } + + pdc.SetBkColor( GetSysColor( COLOR_3DFACE ) ); + pdc.SelectFont( font ); + pdc.SetBkMode( OPAQUE ); + constexpr UINT textFormat = DT_CENTER | DT_VCENTER; + + if( isActive ) + { + pdc.SetTextColor( activeColor ); + pdc.DrawText( circleFilled, 1, rectInt32, textFormat ); + pdc.SetBkMode( TRANSPARENT ); + } + + pdc.SetTextColor( 0 ); + pdc.DrawText( circleOutline, 1, rectInt32, textFormat ); +} + +void CircleIndicator::setActive( bool nowActive ) +{ + if( nowActive == isActive ) + return; + + // Repaint the control + isActive = nowActive; + InvalidateRect( nullptr ); +}
// Examples/WhisperDesktop/CircleIndicator.h
#pragma once
#include "Utils/miscUtils.h"
#include "Utils/WTL/atlcrack.h"

// This control renders a black circle, and in the active state, the circle is filled with a bright color.
class CircleIndicator: public CWindowImpl<CircleIndicator>
{
public:
	// Custom window class information, instead of the default one of CWindowImpl
	static ATL::CWndClassInfo& GetWndClassInfo();

	BEGIN_MSG_MAP( CircleIndicator )
		MSG_WM_PAINT( onPaint )
		MSG_WM_DESTROY( onDestroy )
	END_MSG_MAP()

	// Class registration
	static HRESULT registerClass();

	// Change the state, and repaint when it has changed
	void setActive( bool nowActive );

	// Change the fill color of the active state; the value is passed to SetTextColor as is.
	// Doesn't repaint the control.
	void setActiveColor( uint32_t col )
	{
		activeColor = col;
	}
	CircleIndicator();

private:
	bool isActive = false;
	uint32_t activeColor;
	// Height the current font was created with, to detect when it needs to be re-created
	int fontHeight = 0;
	CFont font;
	HRESULT createFont( int height );

	void onDestroy();
	void onPaint( CDCHandle dc );
};
\ No newline at end of file diff --git a/Examples/WhisperDesktop/LoadModelDlg.cpp b/Examples/WhisperDesktop/LoadModelDlg.cpp new file mode 100644 index 0000000..1b2bf03 --- /dev/null +++ b/Examples/WhisperDesktop/LoadModelDlg.cpp @@ -0,0 +1,206 @@ +#include "stdafx.h" +#include "LoadModelDlg.h" +#include "Utils/miscUtils.h" +#include "Utils/logger.h" + +constexpr int progressMaxInteger = 1024 * 8; + +HRESULT LoadModelDlg::show() +{ + auto res = DoModal( nullptr ); + if( res == -1 ) + return HRESULT_FROM_WIN32( GetLastError() ); + if( res == IDOK ) + { + HRESULT hr = appState.lastScreenLoad(); + switch( hr ) + { + case SCREEN_TRANSCRIBE: + case SCREEN_CAPTURE: + return hr; + default: + return SCREEN_TRANSCRIBE; + } + } + return S_OK; +} + +LRESULT LoadModelDlg::OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled ) +{ + // First DDX call, hooks up variables to controls. + DoDataExchange( false ); + + cbConsole.initialize( m_hWnd, IDC_CONSOLE, appState ); + implPopulateCombobox( cbModelType, appState.source.impl ); + modelPath.SetWindowTextW( appState.source.path ); + + HRESULT hr = work.create( this ); + if( FAILED( hr ) ) + { + CString text = L"CreateThreadpoolWork failed\n"; + text += formatErrorMessage( hr ); + ::MessageBox( m_hWnd, text, L"Unable to load the model", MB_OK | MB_ICONWARNING ); + return TRUE; + } + + editorsWindows.reserve( 5 ); + editorsWindows = { modelPath, cbModelType, GetDlgItem( IDC_BROWSE ), GetDlgItem( IDOK ), GetDlgItem( IDCANCEL ) }; + pendingWindows.reserve( 2 ); + pendingWindows = { GetDlgItem( IDC_PENDING_TEXT ), progressBar }; + + progressBar.SetRange32( 0, progressMaxInteger ); + progressBar.SetStep( 1 ); + + appState.setupIcon( this ); + ATLVERIFY( CenterWindow() ); + if( !appState.source.found || !appState.automaticallyLoadModel ) + return 0; + + // AppState.findModelSource() method has located model parameters in registry; + // Post a notification identical to the "OK" button click event. 
+ PostMessage( WM_COMMAND, IDOK, (LPARAM)( GetDlgItem( IDOK ).m_hWnd ) ); + + return 0; +} + +LRESULT LoadModelDlg::OnBrowse( UINT, INT, HWND, BOOL& bHandled ) +{ + bHandled = TRUE; + + CString path; + modelPath.GetWindowText( path ); + if( !getOpenFileName( m_hWnd, L"Select a GGML Model File", L"Binary files (*.bin)\0*.bin\0\0", path ) ) + return 0; + + modelPath.SetWindowText( path ); + appState.source.path = path; + return 0; +} + +LRESULT LoadModelDlg::validationError( LPCTSTR message ) +{ + reportError( m_hWnd, message, L"Unable to load the model" ); + return 0; +} + +LRESULT LoadModelDlg::validationError( LPCTSTR message, HRESULT hr ) +{ + reportError( m_hWnd, message, L"Unable to load the model", hr ); + return 0; +} + +void LoadModelDlg::setPending( bool nowPending ) +{ + const BOOL enable = nowPending ? FALSE : TRUE; + for( HWND w : editorsWindows ) + ::EnableWindow( w, enable ); + + const int show = nowPending ? SW_NORMAL : SW_HIDE; + for( HWND w : pendingWindows ) + ::ShowWindow( w, show ); + + if( nowPending ) + progressBar.SetMarquee( TRUE, 0 ); + else + progressBar.SetMarquee( FALSE, 0 ); +} + +LRESULT LoadModelDlg::OnOk( UINT, INT, HWND, BOOL& bHandled ) +{ + modelPath.GetWindowText( path ); + if( path.GetLength() <= 0 ) + return validationError( L"Please select a model GGML file" ); + + { + CAtlFile file; + HRESULT hr = file.Create( path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING ); + if( FAILED( hr ) ) + return validationError( L"Unable to open the model file", hr ); + + ULONGLONG cb = 0; + file.GetSize( cb ); + appState.source.sizeInBytes = cb; + } + + impl = implGetValue( cbModelType ); + if( impl == (Whisper::eModelImplementation)0 ) + return validationError( L"Please select a model type" ); + + setPending( true ); + work.post(); + return 0; +} + +void __stdcall LoadModelDlg::poolCallback() noexcept +{ + CComPtr<Whisper::iModel> model; + clearLastError(); + loadError = L""; + Whisper::sLoadModelCallbacks lmcb; + lmcb.cancel = nullptr; + 
lmcb.progress = &LoadModelDlg::progressCallback; + lmcb.pv = this; + HRESULT hr = Whisper::loadModel( path, impl, &lmcb, &model ); + if( SUCCEEDED( hr ) ) + appState.model = model; + else + getLastError( loadError ); + + this->PostMessage( WM_CALLBACK_STATUS, (WPARAM)hr ); +} + +HRESULT __stdcall LoadModelDlg::progressCallback( double val, void* pv ) noexcept +{ + LoadModelDlg& dialog = *(LoadModelDlg*)pv; + constexpr double mul = progressMaxInteger; + int pos = lround( mul * val ); + dialog.progressBar.PostMessage( PBM_SETPOS, pos, 0 ); + return S_OK; +} + +LRESULT LoadModelDlg::OnCallbackStatus( UINT, WPARAM wParam, LPARAM, BOOL& bHandled ) +{ + setPending( false ); + + bHandled = TRUE; + const HRESULT hr = (HRESULT)wParam; + if( FAILED( hr ) ) + { + LPCTSTR failMessage = L"Error loading the model"; + if( loadError.GetLength() > 0 ) + { + CString tmp = failMessage; + tmp += L"\n"; + tmp += loadError; + return validationError( tmp, hr ); + } + else + return validationError( failMessage, hr ); + } + + appState.source.path = path; + appState.source.impl = impl; + appState.saveModelSource(); + + EndDialog( IDOK ); + return 0; +} + +LRESULT LoadModelDlg::OnHyperlink( int idCtrl, LPNMHDR pnmh, BOOL& bHandled ) +{ + const UINT code = pnmh->code; + switch( code ) + { + case NM_CLICK: + case NM_RETURN: + break; + default: + return 0; + } + + PNMLINK pNMLink = (PNMLINK)pnmh; + LPCTSTR url = pNMLink->item.szUrl; + ShellExecute( NULL, L"open", url, NULL, NULL, SW_SHOW ); + bHandled = TRUE; + return 0; +}
\ No newline at end of file diff --git a/Examples/WhisperDesktop/LoadModelDlg.h b/Examples/WhisperDesktop/LoadModelDlg.h new file mode 100644 index 0000000..a8d7aea --- /dev/null +++ b/Examples/WhisperDesktop/LoadModelDlg.h @@ -0,0 +1,69 @@ +#pragma once +#include "AppState.h" +#include "Utils/WTL/atlddx.h" +#include "Utils/miscUtils.h" + +class LoadModelDlg: + public CDialogImpl<LoadModelDlg>, + public CWinDataExchange<LoadModelDlg>, + public iThreadPoolCallback +{ + AppState& appState; +public: + static constexpr UINT IDD = IDD_OPEN_MODEL; + static constexpr UINT WM_CALLBACK_STATUS = WM_APP + 1; + + LoadModelDlg( AppState& app ) : appState( app ) { } + + HRESULT show(); + + BEGIN_MSG_MAP( LoadModelDlg ) + MESSAGE_HANDLER( WM_INITDIALOG, OnInitDialog ) + ON_BUTTON_CLICK( IDC_CONSOLE, cbConsole.click ) + COMMAND_ID_HANDLER( IDCANCEL, OnCommand ) + COMMAND_ID_HANDLER( IDOK, OnOk ) + COMMAND_ID_HANDLER( IDC_BROWSE, OnBrowse ) + MESSAGE_HANDLER( WM_CALLBACK_STATUS, OnCallbackStatus ) + NOTIFY_ID_HANDLER( IDC_HYPERLINKS, OnHyperlink ) + END_MSG_MAP() + + BEGIN_DDX_MAP( LoadModelDlg ) + DDX_CONTROL_HANDLE( IDC_PATH, modelPath ) + DDX_CONTROL_HANDLE( IDC_MODEL_TYPE, cbModelType ) + DDX_CONTROL_HANDLE( IDC_PROGRESS, progressBar ); + END_DDX_MAP() + +private: + std::vector<HWND> editorsWindows; + std::vector<HWND> pendingWindows; + void setPending( bool nowPending ); + + LRESULT OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled ); + LRESULT OnCallbackStatus( UINT, WPARAM wParam, LPARAM, BOOL& bHandled ); + + LRESULT OnCommand( UINT, INT nIdentifier, HWND, BOOL& bHandled ) + { + ATLVERIFY( EndDialog( nIdentifier ) ); + return 0; + } + LRESULT OnBrowse( UINT, INT, HWND, BOOL& bHandled ); + LRESULT OnOk( UINT, INT, HWND, BOOL& bHandled ); + + ConsoleCheckbox cbConsole; + CComboBox cbModelType; + CEdit modelPath; + CProgressBarCtrl progressBar; + + LRESULT validationError( LPCTSTR message ); + LRESULT validationError( LPCTSTR message, HRESULT hr ); + + 
ThreadPoolWork work; + CString path; + Whisper::eModelImplementation impl; + CString loadError; + void __stdcall poolCallback() noexcept override final; + + LRESULT OnHyperlink( int idCtrl, LPNMHDR pnmh, BOOL& bHandled ); + + static HRESULT __stdcall progressCallback( double val, void* pv ) noexcept; +};
\ No newline at end of file diff --git a/Examples/WhisperDesktop/Resource.h b/Examples/WhisperDesktop/Resource.h new file mode 100644 index 0000000..f4a9af5 --- /dev/null +++ b/Examples/WhisperDesktop/Resource.h @@ -0,0 +1,61 @@ +//{{NO_DEPENDENCIES}} +// Microsoft Visual C++ generated include file. +// Used by WhisperDesktop.rc +// +#define IDC_MYICON 2 +#define IDD_WHISPERDESKTOP_DIALOG 102 +#define IDD_ABOUTBOX 103 +#define IDM_ABOUT 104 +#define IDI_WHISPERDESKTOP 107 +#define IDI_SMALL 108 +#define IDR_MAINFRAME 128 +#define IDD_OPEN_MODEL 129 +#define IDD_MAIN_DIALOG 130 +#define IDD_TRANSCRIBE_DIALOG 130 +#define IDD_CAPTURE_DIALOG 131 +#define IDC_PATH 1000 +#define IDC_BROWSE 1001 +#define IDC_MODEL_TYPE 1002 +#define IDC_PATH_RESULT 1002 +#define IDC_PROGRESS 1003 +#define IDC_BROWSE_RESULT 1003 +#define IDC_SYSLINK1 1004 +#define IDC_HYPERLINKS 1004 +#define IDC_TRANSCRIBE_PROGRESS 1004 +#define IDC_PENDING_TEXT 1005 +#define IDC_MODEL_DESC 1006 +#define IDC_LANGUAGE 1007 +#define IDC_OUTPUT_FORMAT 1008 +#define IDC_PATH_MEDIA 1009 +#define IDC_DEVICE 1009 +#define IDC_BROWSE_MEDIA 1010 +#define IDC_TRANSCRIBE 1011 +#define IDC_BACK 1012 +#define IDC_CHECK1 1013 +#define IDC_CONSOLE 1013 +#define IDC_CAPTURE 1014 +#define IDC_DEV_REFRESH 1015 +#define IDC_SAVE_TEXT 1016 +#define IDC_SAVE_APPEND 1017 +#define IDC_SAVE_TIMESTAMPS 1018 +#define IDC_RUN_CAPTURE 1019 +#define IDC_VOICE_ACTIVITY 1020 +#define IDC_VOICE_ACTIVITY_LBL 1021 +#define IDC_TRANS_STATUS 1022 +#define IDC_TRANS_LBL 1023 +#define IDC_STALL_STATUS 1024 +#define IDC_STALL_LBL 1025 +#define IDC_TRANSLATE 1026 +#define IDC_STATIC -1 + +// Next default values for new objects +// +#ifdef APSTUDIO_INVOKED +#ifndef APSTUDIO_READONLY_SYMBOLS +#define _APS_NO_MFC 1 +#define _APS_NEXT_RESOURCE_VALUE 131 +#define _APS_NEXT_COMMAND_VALUE 32771 +#define _APS_NEXT_CONTROL_VALUE 1027 +#define _APS_NEXT_SYMED_VALUE 110 +#endif +#endif diff --git a/Examples/WhisperDesktop/TranscribeDlg.cpp 
b/Examples/WhisperDesktop/TranscribeDlg.cpp new file mode 100644 index 0000000..14bec05 --- /dev/null +++ b/Examples/WhisperDesktop/TranscribeDlg.cpp @@ -0,0 +1,493 @@ +#include "stdafx.h" +#include "TranscribeDlg.h" +#include "Utils/logger.h" + +HRESULT TranscribeDlg::show() +{ + auto res = DoModal( nullptr ); + if( res == -1 ) + return HRESULT_FROM_WIN32( GetLastError() ); + switch( res ) + { + case IDC_BACK: + return SCREEN_MODEL; + case IDC_CAPTURE: + return SCREEN_CAPTURE; + } + return S_OK; +} + +constexpr int progressMaxInteger = 1024 * 8; + +static const LPCTSTR regValInput = L"sourceMedia"; +static const LPCTSTR regValOutFormat = L"resultFormat"; +static const LPCTSTR regValOutPath = L"resultPath"; + +LRESULT TranscribeDlg::OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled ) +{ + // First DDX call, hooks up variables to controls. + DoDataExchange( false ); + printModelDescription(); + languageSelector.initialize( m_hWnd, IDC_LANGUAGE, appState ); + cbConsole.initialize( m_hWnd, IDC_CONSOLE, appState ); + cbTranslate.initialize( m_hWnd, IDC_TRANSLATE, appState ); + populateOutputFormats(); + + pendingState.initialize( + { + languageSelector, + sourceMediaPath, GetDlgItem( IDC_BROWSE_MEDIA ), + transcribeOutFormat, + transcribeOutputPath, GetDlgItem( IDC_BROWSE_RESULT ), + GetDlgItem( IDC_TRANSCRIBE ), + GetDlgItem( IDCANCEL ), + GetDlgItem( IDC_BACK ), + GetDlgItem( IDC_CAPTURE ) + }, + { + progressBar, GetDlgItem( IDC_PENDING_TEXT ) + } ); + + HRESULT hr = work.create( this ); + if( FAILED( hr ) ) + { + reportError( m_hWnd, L"CreateThreadpoolWork failed", nullptr, hr ); + EndDialog( IDCANCEL ); + } + + progressBar.SetRange32( 0, progressMaxInteger ); + progressBar.SetStep( 1 ); + + sourceMediaPath.SetWindowText( appState.stringLoad( regValInput ) ); + transcribeOutFormat.SetCurSel( (int)appState.dwordLoad( regValOutFormat, 0 ) ); + transcribeOutputPath.SetWindowText( appState.stringLoad( regValOutPath ) ); + BOOL unused; + 
OnOutFormatChange( 0, 0, nullptr, unused ); + + appState.lastScreenSave( SCREEN_TRANSCRIBE ); + appState.setupIcon( this ); + ATLVERIFY( CenterWindow() ); + return 0; +} + +void TranscribeDlg::printModelDescription() +{ + CString text; + if( S_OK == appState.model->isMultilingual() ) + text = L"Multilingual"; + else + text = L"Single-language"; + text += L" model \""; + LPCTSTR path = appState.source.path; + path = ::PathFindFileName( path ); + text += path; + text += L"\", "; + const int64_t cb = appState.source.sizeInBytes; + if( cb < 1 << 30 ) + { + constexpr double mul = 1.0 / ( 1 << 20 ); + double mb = (double)cb * mul; + text.AppendFormat( L"%.1f MB", mb ); + } + else + { + constexpr double mul = 1.0 / ( 1 << 30 ); + double gb = (double)cb * mul; + text.AppendFormat( L"%.2f GB", gb ); + } + text += L" on disk, "; + text += implString( appState.source.impl ); + text += L" implementation"; + + modelDesc.SetWindowText( text ); +} + +void TranscribeDlg::populateOutputFormats() +{ + transcribeOutFormat.AddString( L"None" ); + transcribeOutFormat.AddString( L"Text File" ); + transcribeOutFormat.AddString( L"SubRip subtitles" ); + transcribeOutFormat.AddString( L"WebVTT subtitles" ); +} + +enum struct TranscribeDlg::eOutputFormat : uint8_t +{ + None = 0, + Text = 1, + SubRip = 2, + WebVTT = 3 +}; + +LRESULT TranscribeDlg::OnOutFormatChange( UINT, INT, HWND, BOOL& bHandled ) +{ + const BOOL enabled = transcribeOutFormat.GetCurSel() != 0; + transcribeOutputPath.EnableWindow( enabled ); + transcribeOutputBrowse.EnableWindow( enabled ); + return 0; +} + +void TranscribeDlg::onBrowseMedia() +{ + LPCTSTR title = L"Input audio file to transcribe"; + LPCTSTR filters = L"Multimedia Files\0*.wav;*.wave;*.mp3;*.wma;*.mp4;*.mpeg4;*.mkv\0\0"; + + CString path; + sourceMediaPath.GetWindowText( path ); + if( getOpenFileName( m_hWnd, title, filters, path ) ) + sourceMediaPath.SetWindowText( path ); +} + +static const LPCTSTR outputFilters = L"Text files (*.txt)\0*.txt\0SubRip 
subtitles (*.srt)\0*.srt\0WebVTT subtitles (*.vtt)\0*.vtt\0\0"; +static const std::array<LPCTSTR, 3> outputExtensions = +{ + L".txt", L".srt", L".vtt" +}; + +void TranscribeDlg::onBrowseOutput() +{ + const DWORD origFilterIndex = (DWORD)transcribeOutFormat.GetCurSel() - 1; + + LPCTSTR title = L"Output Text File"; + CString path; + transcribeOutputPath.GetWindowText( path ); + DWORD filterIndex = origFilterIndex; + if( !getSaveFileName( m_hWnd, title, outputFilters, path, &filterIndex ) ) + return; + + LPCTSTR ext = PathFindExtension( path ); + if( 0 == *ext && filterIndex < outputExtensions.size() ) + { + wchar_t* const buffer = path.GetBufferSetLength( path.GetLength() + 5 ); + PathRenameExtension( buffer, outputExtensions[ filterIndex ] ); + path.ReleaseBuffer(); + } + + transcribeOutputPath.SetWindowText( path ); + if( filterIndex != origFilterIndex ) + transcribeOutFormat.SetCurSel( filterIndex + 1 ); +} + +void TranscribeDlg::setPending( bool nowPending ) +{ + pendingState.setPending( nowPending ); +} + +void TranscribeDlg::transcribeError( LPCTSTR text, HRESULT hr ) +{ + reportError( m_hWnd, text, L"Unable to transcribe audio", hr ); +} + +void TranscribeDlg::onTranscribe() +{ + // Validate input + sourceMediaPath.GetWindowText( transcribeArgs.pathMedia ); + if( transcribeArgs.pathMedia.GetLength() <= 0 ) + { + transcribeError( L"Please select an input audio file" ); + return; + } + + if( !PathFileExists( transcribeArgs.pathMedia ) ) + { + transcribeError( L"Input audio file does not exist", HRESULT_FROM_WIN32( ERROR_FILE_NOT_FOUND ) ); + return; + } + + transcribeArgs.language = languageSelector.selectedLanguage(); + transcribeArgs.translate = cbTranslate.checked(); + if( isInvalidTranslate( m_hWnd, transcribeArgs.language, transcribeArgs.translate ) ) + return; + + transcribeArgs.format = (eOutputFormat)(uint8_t)transcribeOutFormat.GetCurSel(); + if( transcribeArgs.format != eOutputFormat::None ) + { + transcribeOutputPath.GetWindowText( 
transcribeArgs.pathOutput ); + if( transcribeArgs.pathOutput.GetLength() <= 0 ) + { + transcribeError( L"Please select an output text file" ); + return; + } + appState.stringStore( regValOutPath, transcribeArgs.pathOutput ); + } + else + cbConsole.ensureChecked(); + + appState.dwordStore( regValOutFormat, (uint32_t)(int)transcribeArgs.format ); + languageSelector.saveSelection( appState ); + cbTranslate.saveSelection( appState ); + appState.stringStore( regValInput, transcribeArgs.pathMedia ); + + setPending( true ); + + work.post(); +} + +void __stdcall TranscribeDlg::poolCallback() noexcept +{ + HRESULT hr = transcribe(); + PostMessage( WM_CALLBACK_STATUS, (WPARAM)hr ); +} + +static void printTime( CString& rdi, int64_t ticks ) +{ + const Whisper::sTimeSpan ts{ (uint64_t)ticks }; + const Whisper::sTimeSpanFields fields = ts; + + if( fields.days != 0 ) + { + rdi.AppendFormat( L"%i days, %i hours", fields.days, (int)fields.hours ); + return; + } + if( ( fields.hours | fields.minutes ) != 0 ) + { + rdi.AppendFormat( L"%02d:%02d:%02d", (int)fields.hours, (int)fields.minutes, (int)fields.seconds ); + return; + } + rdi.AppendFormat( L"%.3f seconds", (double)ticks / 1E7 ); +} + +LRESULT TranscribeDlg::onCallbackStatus( UINT, WPARAM wParam, LPARAM, BOOL& bHandled ) +{ + setPending( false ); + const HRESULT hr = (HRESULT)wParam; + if( FAILED( hr ) ) + { + LPCTSTR failMessage = L"Transcribe failed"; + + if( transcribeArgs.errorMessage.GetLength() > 0 ) + { + CString tmp = failMessage; + tmp += L"\n"; + tmp += transcribeArgs.errorMessage; + transcribeError( tmp, hr ); + } + else + transcribeError( failMessage, hr ); + + return 0; + } + + const int64_t elapsed = ( GetTickCount64() - transcribeArgs.startTime ) * 10'000; + const int64_t media = transcribeArgs.mediaDuration; + CString message = L"Transcribed the audio\nMedia duration: "; + printTime( message, media ); + message += L"\nProcessing time: "; + printTime( message, elapsed ); + message += L"\nRelative processing 
speed: "; + double mul = (double)media / (double)elapsed; + message.AppendFormat( L"%g", mul ); + + MessageBox( message, L"Transcribe Completed", MB_OK | MB_ICONINFORMATION ); + return 0; +} + +void TranscribeDlg::getThreadError() +{ + getLastError( transcribeArgs.errorMessage ); +} + +#define CHECK_EX( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) { getThreadError(); return __hr; } } + +HRESULT TranscribeDlg::transcribe() +{ + transcribeArgs.startTime = GetTickCount64(); + clearLastError(); + transcribeArgs.errorMessage = L""; + + using namespace Whisper; + CComPtr<iAudioReader> reader; + + CHECK_EX( appState.mediaFoundation->openAudioFile( transcribeArgs.pathMedia, false, &reader ) ); + CHECK_EX( reader->getDuration( transcribeArgs.mediaDuration ) ); + + const eOutputFormat format = transcribeArgs.format; + CAtlFile outputFile; + if( format != eOutputFormat::None ) + CHECK( outputFile.Create( transcribeArgs.pathOutput, GENERIC_WRITE, 0, CREATE_ALWAYS ) ); + + transcribeArgs.resultFlags = eResultFlags::Timestamps | eResultFlags::Tokens; + + CComPtr<iContext> context; + CHECK_EX( appState.model->createContext( &context ) ); + + sFullParams fullParams; + CHECK_EX( context->fullDefaultParams( eSamplingStrategy::Greedy, &fullParams ) ); + fullParams.language = transcribeArgs.language; + fullParams.setFlag( eFullParamsFlags::Translate, transcribeArgs.translate ); + fullParams.resetFlag( eFullParamsFlags::PrintRealtime ); + + fullParams.new_segment_callback_user_data = this; + fullParams.new_segment_callback = &newSegmentCallbackStatic; + + // Setup the progress indication sink + sProgressSink progressSink{ &progressCallbackStatic, this }; + // Run the transcribe + CHECK_EX( context->runStreamed( fullParams, progressSink, reader ) ); + + context->timingsPrint(); + + if( format == eOutputFormat::None ) + return S_OK; + + CComPtr<iTranscribeResult> result; + CHECK_EX( context->getResults( transcribeArgs.resultFlags, &result ) ); + + sTranscribeLength len; + 
CHECK_EX( result->getSize( len ) ); + const sSegment* const segments = result->getSegments(); + + switch( format ) + { + case eOutputFormat::Text: + return writeTextFile( segments, len.countSegments, outputFile ); + case eOutputFormat::SubRip: + return writeSubRip( segments, len.countSegments, outputFile ); + case eOutputFormat::WebVTT: + return writeWebVTT( segments, len.countSegments, outputFile ); + default: + return E_FAIL; + } +} + +#undef CHECK_EX + +inline HRESULT TranscribeDlg::progressCallback( double p ) noexcept +{ + constexpr double mul = progressMaxInteger; + int pos = lround( mul * p ); + progressBar.PostMessage( PBM_SETPOS, pos, 0 ); + return S_OK; +} + +HRESULT __cdecl TranscribeDlg::progressCallbackStatic( double p, Whisper::iContext* ctx, void* pv ) noexcept +{ + TranscribeDlg* dlg = (TranscribeDlg*)pv; + return dlg->progressCallback( p ); +} + +namespace +{ + HRESULT write( CAtlFile& file, const CStringA& line ) + { + if( line.GetLength() > 0 ) + CHECK( file.Write( cstr( line ), (DWORD)line.GetLength() ) ); + return S_OK; + } + + void printTime( CStringA& rdi, Whisper::sTimeSpan time, bool comma ) + { + Whisper::sTimeSpanFields fields = time; + const char separator = comma ? 
',' : '.'; + rdi.AppendFormat( "%02d:%02d:%02d%c%03d", + (int)fields.hours, + (int)fields.minutes, + (int)fields.seconds, + separator, + fields.ticks / 10'000 ); + } + + const char* skipBlank( const char* rsi ) + { + while( true ) + { + const char c = *rsi; + if( c == ' ' || c == '\t' ) + { + rsi++; + continue; + } + return rsi; + } + } +} + +using Whisper::sSegment; + + +HRESULT TranscribeDlg::writeTextFile( const sSegment* const segments, const size_t length, CAtlFile& file ) +{ + using namespace Whisper; + CHECK( writeUtf8Bom( file ) ); + CStringA line; + for( size_t i = 0; i < length; i++ ) + { + line = skipBlank( segments[ i ].text ); + line += "\r\n"; + CHECK( write( file, line ) ); + } + return S_OK; +} + +HRESULT TranscribeDlg::writeSubRip( const sSegment* const segments, const size_t length, CAtlFile& file ) +{ + CHECK( writeUtf8Bom( file ) ); + CStringA line; + for( size_t i = 0; i < length; i++ ) + { + const sSegment& seg = segments[ i ]; + + line.Format( "%zu\r\n", i + 1 ); + printTime( line, seg.time.begin, true ); + line += " --> "; + printTime( line, seg.time.end, true ); + line += "\r\n"; + line += skipBlank( seg.text ); + line += "\r\n\r\n"; + CHECK( write( file, line ) ); + } + return S_OK; +} + +HRESULT TranscribeDlg::writeWebVTT( const sSegment* const segments, const size_t length, CAtlFile& file ) +{ + CHECK( writeUtf8Bom( file ) ); + CStringA line; + line = "WEBVTT\r\n\r\n"; + CHECK( write( file, line ) ); + + for( size_t i = 0; i < length; i++ ) + { + const sSegment& seg = segments[ i ]; + line = ""; + + printTime( line, seg.time.begin, false ); + line += " --> "; + printTime( line, seg.time.end, false ); + line += "\r\n"; + line += skipBlank( seg.text ); + line += "\r\n\r\n"; + CHECK( write( file, line ) ); + } + return S_OK; +} + +inline HRESULT TranscribeDlg::newSegmentCallback( Whisper::iContext* ctx, uint32_t n_new ) +{ + using namespace Whisper; + CComPtr<iTranscribeResult> result; + CHECK( ctx->getResults( transcribeArgs.resultFlags, 
&result ) ); + return logNewSegments( result, n_new ); +} + +HRESULT __cdecl TranscribeDlg::newSegmentCallbackStatic( Whisper::iContext* ctx, uint32_t n_new, void* user_data ) noexcept +{ + TranscribeDlg* dlg = (TranscribeDlg*)user_data; + return dlg->newSegmentCallback( ctx, n_new ); +} + +void TranscribeDlg::onWmClose() +{ + if( GetDlgItem( IDCANCEL ).IsWindowEnabled() ) + { + EndDialog( IDCANCEL ); + return; + } + + constexpr UINT flags = MB_YESNO | MB_ICONQUESTION | MB_DEFBUTTON2; + const int res = this->MessageBox( L"Transcribe is in progress.\nDo you want to quit anyway?", L"Confirm exit", flags ); + if( res != IDYES ) + return; + + // TODO: instead of ExitProcess(), implement another callback in the DLL API, for proper cancellation of the background task + ExitProcess( 1 ); +}
\ No newline at end of file diff --git a/Examples/WhisperDesktop/TranscribeDlg.h b/Examples/WhisperDesktop/TranscribeDlg.h new file mode 100644 index 0000000..354e002 --- /dev/null +++ b/Examples/WhisperDesktop/TranscribeDlg.h @@ -0,0 +1,124 @@ +#pragma once +#include "AppState.h" +#include "Utils/WTL/atlddx.h" +#include "Utils/WTL/atlcrack.h" +#include "Utils/miscUtils.h" +#include "Utils/LanguageDropdown.h" +#include "Utils/TranslateCheckbox.h" +#include "Utils/PendingState.h" + +class TranscribeDlg : + public CDialogImpl<TranscribeDlg>, + public CWinDataExchange<TranscribeDlg>, + public iThreadPoolCallback +{ + AppState& appState; + +public: + static constexpr UINT IDD = IDD_TRANSCRIBE_DIALOG; + static constexpr UINT WM_CALLBACK_STATUS = WM_APP + 1; + + TranscribeDlg( AppState& app ) : appState( app ) { } + + // Show this dialog modally, without parent. + HRESULT show(); + + BEGIN_MSG_MAP( LoadModelDlg ) + MESSAGE_HANDLER( WM_INITDIALOG, OnInitDialog ) + ON_BUTTON_CLICK( IDC_CONSOLE, cbConsole.click ) + ON_BUTTON_CLICK( IDCANCEL, onClose ) + ON_BUTTON_CLICK( IDC_BACK, onBack ) + ON_BUTTON_CLICK( IDC_BROWSE_MEDIA, onBrowseMedia ) + ON_BUTTON_CLICK( IDC_BROWSE_RESULT, onBrowseOutput ) + ON_BUTTON_CLICK( IDC_TRANSCRIBE, onTranscribe ) + ON_BUTTON_CLICK( IDC_CAPTURE, onCapture ); + COMMAND_HANDLER( IDC_OUTPUT_FORMAT, CBN_SELCHANGE, OnOutFormatChange ) + MESSAGE_HANDLER( WM_CALLBACK_STATUS, onCallbackStatus ) + MSG_WM_CLOSE( onWmClose ) + END_MSG_MAP() + + BEGIN_DDX_MAP( LoadModelDlg ) + DDX_CONTROL_HANDLE( IDC_MODEL_DESC, modelDesc ) + DDX_CONTROL_HANDLE( IDC_PATH_MEDIA, sourceMediaPath ) + DDX_CONTROL_HANDLE( IDC_OUTPUT_FORMAT, transcribeOutFormat ) + DDX_CONTROL_HANDLE( IDC_PATH_RESULT, transcribeOutputPath ) + DDX_CONTROL_HANDLE( IDC_BROWSE_RESULT, transcribeOutputBrowse ); + DDX_CONTROL_HANDLE( IDC_TRANSCRIBE_PROGRESS, progressBar ); + END_DDX_MAP() + +private: + PendingState pendingState; + void setPending( bool nowPending ); + void transcribeError( LPCTSTR 
text, HRESULT hr = S_FALSE ); + + LRESULT OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled ); + + void onClose() + { + ATLVERIFY( EndDialog( IDCANCEL ) ); + } + void onBack() + { + ATLVERIFY( EndDialog( IDC_BACK ) ); + } + + void printModelDescription(); + CStatic modelDesc; + ConsoleCheckbox cbConsole; + + LanguageDropdown languageSelector; + TranslateCheckbox cbTranslate; + + CEdit sourceMediaPath; + CEdit transcribeOutputPath; + CButton transcribeOutputBrowse; + CComboBox transcribeOutFormat; + CProgressBarCtrl progressBar; + void populateOutputFormats(); + + LRESULT OnOutFormatChange( UINT, INT, HWND, BOOL& bHandled ); + void onBrowseMedia(); + void onBrowseOutput(); + void onTranscribe(); + void onCapture() + { + EndDialog( IDC_CAPTURE ); + } + + ThreadPoolWork work; + + enum struct eOutputFormat : uint8_t; + + struct TranscribeArgs + { + CString pathMedia; + CString pathOutput; + uint32_t language; + bool translate; + eOutputFormat format; + Whisper::eResultFlags resultFlags; + uint64_t startTime; + int64_t mediaDuration; + CString errorMessage; + }; + TranscribeArgs transcribeArgs; + + void __stdcall poolCallback() noexcept override final; + + LRESULT onCallbackStatus( UINT, WPARAM wParam, LPARAM, BOOL& bHandled ); + + HRESULT transcribe(); + void getThreadError(); + + static HRESULT writeTextFile( const Whisper::sSegment* const segments, const size_t length, CAtlFile& file ); + static HRESULT writeSubRip( const Whisper::sSegment* const segments, const size_t length, CAtlFile& file ); + static HRESULT writeWebVTT( const Whisper::sSegment* const segments, const size_t length, CAtlFile& file ); + + static HRESULT __cdecl newSegmentCallbackStatic( Whisper::iContext* ctx, uint32_t n_new, void* user_data ) noexcept; + HRESULT newSegmentCallback( Whisper::iContext* ctx, uint32_t n_new ); + + static HRESULT __cdecl progressCallbackStatic( double p, Whisper::iContext* ctx, void* pv ) noexcept; + HRESULT progressCallback( double p ) noexcept; 
+ + void onWmClose(); +};
\ No newline at end of file diff --git a/Examples/WhisperDesktop/Utils/DebugConsole.cpp b/Examples/WhisperDesktop/Utils/DebugConsole.cpp new file mode 100644 index 0000000..640efb0 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/DebugConsole.cpp @@ -0,0 +1,289 @@ +// https://github.com/Const-me/vis_avs_dx/blob/master/avs_dx/DxVisuals/Interop/ConsoleLogger.cpp +#include "stdafx.h" +#include "DebugConsole.h" +#include "miscUtils.h" +#include "../AppState.h" +#include "logger.h" + +namespace +{ + using Whisper::eLogLevel; + + constexpr uint16_t defaultAttributes = FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE; + + inline uint16_t textAttributes( eLogLevel lvl ) + { + switch( lvl ) + { + case eLogLevel::Error: + return FOREGROUND_RED | FOREGROUND_INTENSITY; + case eLogLevel::Warning: + return FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_INTENSITY; + case eLogLevel::Info: + return FOREGROUND_GREEN | FOREGROUND_INTENSITY; + case eLogLevel::Debug: + return FOREGROUND_BLUE | FOREGROUND_INTENSITY; + } + return defaultAttributes; + } + + // Background stuff: accumulate messages in a small buffer, in case user will want to see them in the console. + // Ideally, we should accumulate them in a more efficient data structure, maybe a circular buffer. + // However, we don't have that many messages per second, this simple solution that uses std::deque is probably good enough for the job. 
+ static constexpr uint16_t bufferSize = 64; + + using Lock = CComCritSecLock<CComAutoCriticalSection>; +#define LOCK() Lock __lock{ critSec } + + thread_local CStringA threadError; +} + +HRESULT DebugConsole::Entry::print( HANDLE hConsole, CString& tempString ) const +{ + if( !SetConsoleTextAttribute( hConsole, textAttributes( level ) ) ) + return getLastHr(); + + makeUtf16( tempString, message ); + tempString += L"\r\n"; + if( !WriteConsoleW( hConsole, tempString, (DWORD)tempString.GetLength(), nullptr, nullptr ) ) + return getLastHr(); + return S_OK; +} + +void clearLastError() +{ + threadError = ""; +} + +bool getLastError( CString& rdi ) +{ + if( threadError.GetLength() <= 0 ) + { + rdi = L""; + return false; + } + else + { + makeUtf16( rdi, threadError ); + threadError = ""; + return true; + } +} + +inline void DebugConsole::logSink( eLogLevel lvl, const char* message ) +{ + LOCK(); + + // Add to the buffer + while( buffer.size() >= bufferSize ) + buffer.pop_front(); + buffer.emplace_back( Entry{ lvl, message } ); + + // If the console window is shown, print there, too. 
+ if( output ) + buffer.rbegin()->print( output, tempString ); +} + +void __stdcall DebugConsole::logSinkStatic( void* context, eLogLevel lvl, const char* message ) +{ + if( lvl == eLogLevel::Error ) + threadError = message; + + DebugConsole* con = (DebugConsole*)context; + con->logSink( lvl, message ); +} + +HRESULT DebugConsole::initialize( Whisper::eLogLevel level ) +{ + if( nullptr != pGlobalInstance ) + return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); + pGlobalInstance = this; + + Whisper::sLoggerSetup setup; + setup.sink = &logSinkStatic; + setup.context = this; + setup.level = level; + setup.flags = Whisper::eLoggerFlags::SkipFormatMessage; + return Whisper::setupLogger( setup ); +} + +DebugConsole::~DebugConsole() +{ + hide(); + + Whisper::sLoggerSetup setup; + memset( &setup, 0, sizeof( setup ) ); + Whisper::setupLogger( setup ); + + pGlobalInstance = nullptr; +} + +DebugConsole* DebugConsole::pGlobalInstance = nullptr; + +void DebugConsole::windowClosed() +{ + LOCK(); + if( FreeConsole() ) + { + // Apparently, FreeConsole already closes that handle: https://stackoverflow.com/q/12676312/126995 + output.Detach(); + } + output.Close(); + + for( CButton* b : checkboxes ) + { + if( !*b ) + continue; + if( !b->IsWindow() ) + continue; + PostMessage( *b, BM_SETCHECK, BST_UNCHECKED, 0 ); + } +} + +BOOL __stdcall DebugConsole::consoleHandlerRoutine( DWORD dwCtrlType ) +{ + switch( dwCtrlType ) + { + case CTRL_CLOSE_EVENT: + case CTRL_C_EVENT: + case CTRL_BREAK_EVENT: + pGlobalInstance->windowClosed(); + return TRUE; + } + return TRUE; +} + +HRESULT DebugConsole::show() +{ + HWND hWnd = GetConsoleWindow(); + if( nullptr != hWnd ) + { + ShowWindow( hWnd, SW_RESTORE ); + SetForegroundWindow( hWnd ); + return S_FALSE; + } + + if( !AllocConsole() ) + return getLastHr(); + + output.Close(); + output.Attach( GetStdHandle( STD_OUTPUT_HANDLE ) ); + if( !output ) + return getLastHr(); + + constexpr UINT cp = CP_UTF8; + if( IsValidCodePage( cp ) ) + 
SetConsoleOutputCP( cp ); + + // Enable ANSI color coding + DWORD mode = 0; + if( !GetConsoleMode( output, &mode ) ) + return getLastHr(); + if( 0 == ( mode & ENABLE_VIRTUAL_TERMINAL_PROCESSING ) ) + { + mode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + if( !SetConsoleMode( output, mode ) ) + return getLastHr(); + } + + SetConsoleTitle( L"Whisper Desktop Debug Console" ); + + SetConsoleCtrlHandler( &consoleHandlerRoutine, TRUE ); + + // Disable close command in the sys.menu of the new console, otherwise the whole process will quit: https://stackoverflow.com/a/12015131/126995 + HWND hwnd = ::GetConsoleWindow(); + if( hwnd != nullptr ) + { + HMENU hMenu = ::GetSystemMenu( hwnd, FALSE ); + if( hMenu != NULL ) + DeleteMenu( hMenu, SC_CLOSE, MF_BYCOMMAND ); + } + + // Print old log entries + for( const auto& e : buffer ) + CHECK( e.print( output, tempString ) ); + + const CStringA msg = "Press Control+C or Control+Break to close this window\r\n"; + if( !SetConsoleTextAttribute( output, FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE | FOREGROUND_INTENSITY ) ) + return getLastHr(); + if( !WriteConsoleA( output, cstr( msg ), msg.GetLength(), nullptr, nullptr ) ) + return getLastHr(); + + return S_OK; +} + +HRESULT DebugConsole::hide() +{ + LOCK(); + if( !output ) + return S_FALSE; + windowClosed(); + return S_OK; +} + +void DebugConsole::addCheckbox( CButton& cb ) +{ + checkboxes.emplace( &cb ); +} +void DebugConsole::removeCheckbox( CButton& cb ) +{ + checkboxes.erase( &cb ); +} + +HRESULT ConsoleCheckbox::initialize( HWND dialog, int idc, AppState& state ) +{ + control = GetDlgItem( dialog, idc ); + assert( control ); + + console = &state.console; + if( state.console.isVisible() ) + control.SetCheck( BST_CHECKED ); + + state.console.addCheckbox( control ); + return S_OK; +} + +void ConsoleCheckbox::click() +{ + const int state = control.GetCheck(); + if( state == BST_CHECKED ) + console->show(); + else + console->hide(); +} + +void ConsoleCheckbox::ensureChecked() +{ + 
const int state = control.GetCheck(); + if( state == BST_CHECKED ) + return; + control.SetCheck( BST_CHECKED ); + console->show(); +} + +void DebugConsole::log( eLogLevel lvl, const char* pszFormat, va_list args ) +{ + LOCK(); + // Add to the buffer + while( buffer.size() >= bufferSize ) + buffer.pop_front(); + + tempStringA.FormatV( pszFormat, args ); + buffer.emplace_back( Entry{ lvl, tempStringA } ); + + // If the console window is shown, print there, too. + if( output ) + buffer.rbegin()->print( output, tempString ); +} + +void DebugConsole::logMessage( eLogLevel lvl, const char* pszFormat, va_list args ) +{ + if( nullptr == pGlobalInstance ) + return; + pGlobalInstance->log( lvl, pszFormat, args ); +} + +void logMessage( Whisper::eLogLevel lvl, const char8_t* pczFormat, va_list args ) +{ + DebugConsole::logMessage( lvl, (const char*)pczFormat, args ); +}
// Examples/WhisperDesktop/Utils/DebugConsole.h (new file in this commit)
#pragma once
#include <whisperWindows.h>
#include <deque>
#include <unordered_set>

class AppState;

// Collects log messages in a bounded in-memory buffer and, while the console window is shown, mirrors them there.
class DebugConsole
{
	using eLogLevel = Whisper::eLogLevel;

	// One buffered log message together with its severity
	struct Entry
	{
		eLogLevel level;
		CStringA message;
		// Write this entry to the console handle; tempString is caller-provided scratch storage
		HRESULT print( HANDLE hConsole, CString& tempString ) const;
	};

	// NOTE(review): presumably the lock taken by the LOCK() macro used in the .cpp - confirm against the macro definition
	CComAutoCriticalSection critSec;
	// Buffered entries: log() appends here and pops the oldest ones once the size limit is reached; show() replays them
	std::deque<Entry> buffer;
	// Scratch string handed to Entry::print
	CString tempString;
	// Console output handle; hide() and log() treat an empty handle as "console not shown"
	CHandle output;

	inline void logSink( eLogLevel lvl, const char* message );
	static void __stdcall logSinkStatic( void* context, eLogLevel lvl, const char* message );

	// Registered with SetConsoleCtrlHandler by show()
	static BOOL __stdcall consoleHandlerRoutine( DWORD dwCtrlType );

	// Instance which receives messages from the static logMessage(); messages are dropped while this is nullptr
	static DebugConsole* pGlobalInstance;
	// Called by hide() when the console is shown; the implementation is outside of this header
	void windowClosed();

	// Checkbox controls registered with addCheckbox() and not yet removed
	std::unordered_set<CButton*> checkboxes;

	// Scratch string used by log() to format the message
	CStringA tempStringA;
	// Format the message, append it to the buffer, and print it when the console is shown
	void log( eLogLevel lvl, const char* pszFormat, va_list args );

public:
	HRESULT initialize( eLogLevel level = eLogLevel::Debug );
	~DebugConsole();

	// Show the console window; buffered entries are printed into it
	HRESULT show();
	// Hide the console window; returns S_FALSE when it was not shown
	HRESULT hide();
	// True when the console output handle is set
	bool isVisible() const { return output; }

	// Register / unregister a checkbox control with this console
	void addCheckbox( CButton& cb );
	void removeCheckbox( CButton& cb );

	// Forward the message to the global instance, if there is one
	static void logMessage( eLogLevel lvl, const char* pszFormat, va_list args );
};

// Dialog checkbox which toggles visibility of the debug console
class ConsoleCheckbox
{
	CButton control;
	DebugConsole* console = nullptr;

public:
	// Bind to the dialog control, check it when the console is already visible, and register it with the console
	HRESULT initialize( HWND dialog, int idc, AppState& state );
	// Show the console when the box is checked, hide it otherwise
	void click();
	// Unregister the control from the console it was registered with
	~ConsoleCheckbox()
	{
		if( nullptr != console )
			console->removeCheckbox( control );
	}
	// When the box is not checked yet, check it and show the console
	void ensureChecked();
};
\ No newline at end of file diff --git a/Examples/WhisperDesktop/Utils/LanguageDropdown.cpp b/Examples/WhisperDesktop/Utils/LanguageDropdown.cpp new file mode 100644 index 0000000..36cd1a8 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/LanguageDropdown.cpp @@ -0,0 +1,87 @@ +#include "stdafx.h" +#include "LanguageDropdown.h" +#include "miscUtils.h" + +namespace +{ + inline wchar_t toUpper( wchar_t c ) + { + size_t st = (uint16_t)c; + st = reinterpret_cast<size_t>( CharUpperW( reinterpret_cast<LPWSTR>( st ) ) ); + return (wchar_t)(uint16_t)st; + } + + void makeTitleCase( CString& s ) + { + bool cap = true; + for( int i = 0; i < s.GetLength(); i++ ) + { + wchar_t c = s[ i ]; + if( cap ) + { + c = toUpper( c ); + s.SetAt( i, c ); + } + cap = false; + if( c == ' ' ) + cap = true; + } + } +} + +int LanguageDropdown::getInitialSelection( AppState& state ) const +{ + constexpr uint32_t english = 0x6E65; + + // Load preference from the registry + uint32_t id = state.languageRead(); + if( id == UINT_MAX ) + id = english; + + auto it = std::find( keys.begin(), keys.end(), id ); + if( it == keys.end() ) + { + id = english; + it = std::find( keys.begin(), keys.end(), id ); + assert( it != keys.end() ); + } + + ptrdiff_t idx = it - keys.begin(); + return (int)idx; +} + +void LanguageDropdown::initialize( HWND owner, int idc, AppState& state ) +{ + m_hWnd = GetDlgItem( owner, idc ); + assert( nullptr != m_hWnd ); + + Whisper::sLanguageList list; + Whisper::getSupportedLanguages( list ); + + const size_t length = list.length; + keys.resize( length ); + CString utf16; + for( size_t i = 0; i < length; i++ ) + { + keys[ i ] = list.pointer[ i ].key; + makeUtf16( utf16, list.pointer[ i ].name ); + makeTitleCase( utf16 ); + SendMessage( m_hWnd, CB_ADDSTRING, 0, (LPARAM)cstr( utf16 ) ); + } + + const int curSel = getInitialSelection( state ); + SendMessage( m_hWnd, CB_SETCURSEL, curSel, 0 ); +} + +uint32_t LanguageDropdown::selectedLanguage() +{ + const int cs = (int)SendMessage( 
m_hWnd, CB_GETCURSEL, 0, 0 ); + if( cs < 0 || cs >= keys.size() ) + return UINT_MAX; + return keys[ cs ]; +} + +void LanguageDropdown::saveSelection( AppState& state ) +{ + state.languageWrite( selectedLanguage() ); +}
// Examples/WhisperDesktop/Utils/LanguageDropdown.h (new file in this commit)
#pragma once
#include "../AppState.h"

// Dropdown list which implements language selector control
class LanguageDropdown
{
	// Handle of the combo box control
	HWND m_hWnd = nullptr;
	// Language keys, filled by initialize() in the same order the items are added to the combo box
	std::vector<uint32_t> keys;
	// Index of the item to preselect: the saved language, with English as the fallback
	int getInitialSelection( AppState& state ) const;

public:
	operator HWND() const
	{
		return m_hWnd;
	}

	// Query language list from the native library, populate the combo box
	// Then load the last saved language selection from registry, and preselect an item.
	void initialize( HWND owner, int idc, AppState& state );

	// Get the ID of the currently selected language, or UINT_MAX if nothing's selected
	uint32_t selectedLanguage();

	// Get the ID of the currently selected language, and store in registry
	void saveSelection( AppState& state );
};
\ No newline at end of file diff --git a/Examples/WhisperDesktop/Utils/PendingState.cpp b/Examples/WhisperDesktop/Utils/PendingState.cpp new file mode 100644 index 0000000..404ae4e --- /dev/null +++ b/Examples/WhisperDesktop/Utils/PendingState.cpp @@ -0,0 +1,40 @@ +#include "stdafx.h" +#include "PendingState.h" + +void PendingState::initialize( std::initializer_list<HWND> editors, std::initializer_list<HWND> pending ) +{ + editorsWindows = editors; + wasEnabled.resize( editorsWindows.size() ); + pendingWindows = pending; +} + +void PendingState::setPending( bool nowPending ) +{ + if( nowPending ) + { + for( size_t i = 0; i < editorsWindows.size(); i++ ) + { + BOOL e = IsWindowEnabled( editorsWindows[ i ] ); + if( e ) + { + wasEnabled[ i ] = true; + EnableWindow( editorsWindows[ i ], FALSE ); + } + else + wasEnabled[ i ] = false; + } + } + else + { + for( size_t i = 0; i < editorsWindows.size(); i++ ) + { + if( !wasEnabled[ i ] ) + continue; + EnableWindow( editorsWindows[ i ], TRUE ); + } + } + + const int show = nowPending ? SW_NORMAL : SW_HIDE; + for( HWND w : pendingWindows ) + ::ShowWindow( w, show ); +}
// Examples/WhisperDesktop/Utils/PendingState.h (new file in this commit)
#pragma once

// Utility class to switch visual state of dialog controls between idle and pending
class PendingState
{
	// Windows which are disabled while in the pending state
	std::vector<HWND> editorsWindows;
	// One flag per editor window: true when setPending( true ) found the window enabled and disabled it
	std::vector<bool> wasEnabled;
	// Windows which are shown while in the pending state, and hidden otherwise
	std::vector<HWND> pendingWindows;
public:
	// Set both lists of windows; call before setPending
	void initialize( std::initializer_list<HWND> editors, std::initializer_list<HWND> pending );
	// true: disable the enabled editors and show the pending windows
	// false: re-enable the editors disabled earlier and hide the pending windows
	void setPending( bool nowPending );
};
\ No newline at end of file diff --git a/Examples/WhisperDesktop/Utils/TranslateCheckbox.cpp b/Examples/WhisperDesktop/Utils/TranslateCheckbox.cpp new file mode 100644 index 0000000..c5e6ac0 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/TranslateCheckbox.cpp @@ -0,0 +1,25 @@ +#include "stdafx.h" +#include "TranslateCheckbox.h" + +static const LPCTSTR regValTranslate = L"translate"; + +void TranslateCheckbox::initialize( HWND owner, int idc, AppState& state ) +{ + m_hWnd = GetDlgItem( owner, idc ); + assert( nullptr != m_hWnd ); + + if( state.dwordLoad( regValTranslate, 0 ) != 0 ) + ::SendMessage( m_hWnd, BM_SETCHECK, BST_CHECKED, 0L ); +} + +bool TranslateCheckbox::checked() +{ + assert( nullptr != m_hWnd ); + const int state = ( int )::SendMessage( m_hWnd, BM_GETCHECK, 0, 0 ); + return state == BST_CHECKED; +} + +void TranslateCheckbox::saveSelection( AppState& state ) +{ + state.dwordStore( regValTranslate, checked() ? TRUE : FALSE ); +}
// Examples/WhisperDesktop/Utils/TranslateCheckbox.h (new file in this commit)
#pragma once
#include "../AppState.h"

// Checkbox control whose state is loaded from and saved to the application state
class TranslateCheckbox
{
	// Handle of the checkbox control
	HWND m_hWnd = nullptr;
public:
	operator HWND() const
	{
		return m_hWnd;
	}

	// Resolve the control in the owner dialog, and check it when the stored value is non-zero
	void initialize( HWND owner, int idc, AppState& state );

	// True when the checkbox is currently checked
	bool checked();

	// Store the current state of the checkbox in the application state
	void saveSelection( AppState& state );
};
\ No newline at end of file diff --git a/Examples/WhisperDesktop/Utils/WTL/atlapp.h b/Examples/WhisperDesktop/Utils/WTL/atlapp.h new file mode 100644 index 0000000..8be6edc --- /dev/null +++ b/Examples/WhisperDesktop/Utils/WTL/atlapp.h @@ -0,0 +1,1225 @@ +// Windows Template Library - WTL version 10.0 +// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved. +// +// This file is a part of the Windows Template Library. +// The use and distribution terms for this software are covered by the +// Microsoft Public License (http://opensource.org/licenses/MS-PL) +// which can be found in the file MS-PL.txt at the root folder. + +#ifndef __ATLAPP_H__ +#define __ATLAPP_H__ + +#pragma once + +#ifndef __cplusplus + #error WTL requires C++ compilation (use a .cpp suffix) +#endif + +#ifndef __ATLBASE_H__ + #error atlapp.h requires atlbase.h to be included first +#endif + +#ifdef _WIN32_WCE + #error WTL10 doesn't support Windows CE +#endif + +#ifdef _ATL_NO_COMMODULE + #error WTL doesn't support _ATL_NO_COMMODULE +#endif + +#ifdef _ATL_NO_WIN_SUPPORT + #error WTL doesn't support _ATL_NO_WIN_SUPPORT +#endif + +#if (_MSC_VER < 1400) + #error WTL10 requires C++ compiler version 14 (Visual C++ 2005) or higher +#endif + +#if (WINVER < 0x0501) + #error WTL requires WINVER >= 0x0501 +#endif + +#if (_WIN32_WINNT < 0x0501) + #error WTL requires _WIN32_WINNT >= 0x0501 +#endif + +#if (_WIN32_IE < 0x0600) + #error WTL requires _WIN32_IE >= 0x0600 +#endif + +#if (_ATL_VER < 0x0800) + #error WTL10 requires ATL version 8 or higher +#endif + +#ifdef _ATL_MIN_CRT + #error WTL10 doesn't support _ATL_MIN_CRT +#endif + +#ifdef _ATL_NO_MSIMG + #error WTL10 doesn't support _ATL_NO_MSIMG +#endif + +#include <limits.h> +#ifdef _MT + #include <process.h> // for _beginthreadex +#endif + +#include <commctrl.h> +#pragma comment(lib, "comctl32.lib") + +#include <commdlg.h> +#include <shellapi.h> + +// Check for VS2005 without newer WinSDK +#if (_MSC_VER == 1400) && 
!defined(RB_GETEXTENDEDSTYLE) + #error WTL10 requires WinSDK 6.0 ot higher +#endif + +#include <uxtheme.h> +#pragma comment(lib, "uxtheme.lib") + +#if defined(_SYSINFOAPI_H_) && defined(NOT_BUILD_WINDOWS_DEPRECATE) + #include <VersionHelpers.h> +#endif + +#include "atlres.h" + + +/////////////////////////////////////////////////////////////////////////////// +// WTL version number + +#define _WTL_VER 0x1000 // version 10.0 + + +/////////////////////////////////////////////////////////////////////////////// +// Classes in this file: +// +// CMessageFilter +// CIdleHandler +// CMessageLoop +// +// CAppModule +// CServerAppModule +// +// Global functions: +// AtlInitCommonControls() +// AtlGetDefaultGuiFont() +// AtlCreateControlFont() +// AtlCreateBoldFont() +// AtlGetStringPtr() + + +/////////////////////////////////////////////////////////////////////////////// +// Miscellaneous global support + +// define useful macros from winuser.h +#ifndef IS_INTRESOURCE + #define IS_INTRESOURCE(_r) (((ULONG_PTR)(_r) >> 16) == 0) +#endif // IS_INTRESOURCE + +// protect template members from windowsx.h macros +#ifdef _INC_WINDOWSX + #undef SubclassWindow +#endif // _INC_WINDOWSX + +// define useful macros from windowsx.h +#ifndef GET_X_LPARAM + #define GET_X_LPARAM(lParam) ((int)(short)LOWORD(lParam)) +#endif +#ifndef GET_Y_LPARAM + #define GET_Y_LPARAM(lParam) ((int)(short)HIWORD(lParam)) +#endif + +// Dummy structs for compiling with /CLR +#ifdef _MANAGED + __if_not_exists(_IMAGELIST::_IMAGELIST) { struct _IMAGELIST { }; } + __if_not_exists(_TREEITEM::_TREEITEM) { struct _TREEITEM { }; } + __if_not_exists(_PSP::_PSP) { struct _PSP { }; } +#endif + +// Forward declaration for ATL11 fix +#if (_ATL_VER >= 0x0B00) + namespace ATL { HRESULT AtlGetCommCtrlVersion(LPDWORD pdwMajor, LPDWORD pdwMinor); } +#endif + +#ifndef WM_MOUSEHWHEEL + #define WM_MOUSEHWHEEL 0x020E +#endif + +// Used for stack allocations with ATL::CTempBuffer +#ifndef _WTL_STACK_ALLOC_THRESHOLD + #define 
_WTL_STACK_ALLOC_THRESHOLD 512 +#endif + + +namespace WTL +{ + +DECLARE_TRACE_CATEGORY(atlTraceUI) +#ifdef _DEBUG + __declspec(selectany) ATL::CTraceCategory atlTraceUI(_T("atlTraceUI")); +#endif // _DEBUG + +// Common Controls initialization helper +inline BOOL AtlInitCommonControls(DWORD dwFlags) +{ + INITCOMMONCONTROLSEX iccx = { sizeof(INITCOMMONCONTROLSEX), dwFlags }; + BOOL bRet = ::InitCommonControlsEx(&iccx); + ATLASSERT(bRet); + return bRet; +} + +// Default GUI font helper - "MS Shell Dlg" stock font +inline HFONT AtlGetDefaultGuiFont() +{ + return (HFONT)::GetStockObject(DEFAULT_GUI_FONT); +} + +// Control font helper - default font for controls not in a dialog +// (NOTE: Caller owns the font, and should destroy it when it's no longer needed) +inline HFONT AtlCreateControlFont() +{ + LOGFONT lf = {}; + ATLVERIFY(::SystemParametersInfo(SPI_GETICONTITLELOGFONT, sizeof(LOGFONT), &lf, 0) != FALSE); + HFONT hFont = ::CreateFontIndirect(&lf); + ATLASSERT(hFont != NULL); + return hFont; +} + +// Bold font helper +// (NOTE: Caller owns the font, and should destroy it when it's no longer needed) +inline HFONT AtlCreateBoldFont(HFONT hFont = NULL) +{ + LOGFONT lf = {}; + if(hFont == NULL) + ATLVERIFY(::SystemParametersInfo(SPI_GETICONTITLELOGFONT, sizeof(LOGFONT), &lf, 0) != FALSE); + else + ATLVERIFY(::GetObject(hFont, sizeof(LOGFONT), &lf) == sizeof(LOGFONT)); + lf.lfWeight = FW_BOLD; + HFONT hFontBold = ::CreateFontIndirect(&lf); + ATLASSERT(hFontBold != NULL); + return hFontBold; +} + +// Resource string pointer +inline LPCWSTR AtlGetStringPtr(UINT uID, int* pch = NULL) +{ + LPCWSTR lpstr = NULL; + int nRet = ::LoadStringW(ATL::_AtlBaseModule.GetResourceInstance(), uID, (LPWSTR)&lpstr, 0); + if(pch != NULL) + *pch = nRet; + return lpstr; +} + + +/////////////////////////////////////////////////////////////////////////////// +// RunTimeHelper - helper functions for Windows version and structure sizes + +#ifndef _WTL_NO_RUNTIME_STRUCT_SIZE + +#ifndef 
_SIZEOF_STRUCT + #define _SIZEOF_STRUCT(structname, member) (((int)((LPBYTE)(&((structname*)0)->member) - ((LPBYTE)((structname*)0)))) + sizeof(((structname*)0)->member)) +#endif + +#if (_WIN32_WINNT >= 0x0600) && !defined(REBARBANDINFO_V6_SIZE) + #define REBARBANDINFO_V6_SIZE _SIZEOF_STRUCT(REBARBANDINFO, cxHeader) +#endif // (_WIN32_WINNT >= 0x0600) && !defined(REBARBANDINFO_V6_SIZE) + +#if (_WIN32_WINNT >= 0x0600) && !defined(LVGROUP_V5_SIZE) + #define LVGROUP_V5_SIZE _SIZEOF_STRUCT(LVGROUP, uAlign) +#endif // (_WIN32_WINNT >= 0x0600) && !defined(LVGROUP_V5_SIZE) + +#if (_WIN32_WINNT >= 0x0600) && !defined(LVTILEINFO_V5_SIZE) + #define LVTILEINFO_V5_SIZE _SIZEOF_STRUCT(LVTILEINFO, puColumns) +#endif // (_WIN32_WINNT >= 0x0600) && !defined(LVTILEINFO_V5_SIZE) + +#if defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) && !defined(MCHITTESTINFO_V1_SIZE) + #define MCHITTESTINFO_V1_SIZE _SIZEOF_STRUCT(MCHITTESTINFO, st) +#endif // defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) && !defined(MCHITTESTINFO_V1_SIZE) + +#if (WINVER >= 0x0600) && !defined(NONCLIENTMETRICS_V1_SIZE) + #define NONCLIENTMETRICS_V1_SIZE _SIZEOF_STRUCT(NONCLIENTMETRICS, lfMessageFont) +#endif // (WINVER >= 0x0600) && !defined(NONCLIENTMETRICS_V1_SIZE) + +#ifndef TTTOOLINFO_V2_SIZE + #define TTTOOLINFO_V2_SIZE _SIZEOF_STRUCT(TTTOOLINFO, lParam) +#endif + +#endif // !_WTL_NO_RUNTIME_STRUCT_SIZE + +namespace RunTimeHelper +{ + inline bool IsCommCtrl6() + { + DWORD dwMajor = 0, dwMinor = 0; + HRESULT hRet = ATL::AtlGetCommCtrlVersion(&dwMajor, &dwMinor); + return (SUCCEEDED(hRet) && (dwMajor >= 6)); + } + + inline bool IsVista() + { +#ifdef _versionhelpers_H_INCLUDED_ + return ::IsWindowsVistaOrGreater(); +#else // !_versionhelpers_H_INCLUDED_ + OSVERSIONINFO ovi = { sizeof(OSVERSIONINFO) }; + BOOL bRet = ::GetVersionEx(&ovi); + return ((bRet != FALSE) && (ovi.dwMajorVersion >= 6)); +#endif // _versionhelpers_H_INCLUDED_ + } + + inline bool IsThemeAvailable() + { + return 
IsCommCtrl6() && (::IsThemeActive() != FALSE) && (::IsAppThemed() != FALSE); + } + + inline bool IsWin7() + { +#ifdef _versionhelpers_H_INCLUDED_ + return ::IsWindows7OrGreater(); +#else // !_versionhelpers_H_INCLUDED_ + OSVERSIONINFO ovi = { sizeof(OSVERSIONINFO) }; + BOOL bRet = ::GetVersionEx(&ovi); + return ((bRet != FALSE) && ((ovi.dwMajorVersion > 6) || ((ovi.dwMajorVersion == 6) && (ovi.dwMinorVersion >= 1)))); +#endif // _versionhelpers_H_INCLUDED_ + } + + inline bool IsRibbonUIAvailable() + { + static INT iRibbonUI = -1; + +#if defined(NTDDI_WIN7) && (NTDDI_VERSION >= NTDDI_WIN7) + if (iRibbonUI == -1) + { + HMODULE hRibbonDLL = ::LoadLibrary(_T("propsys.dll")); + if (hRibbonDLL != NULL) + { + const GUID CLSID_UIRibbonFramework = { 0x926749fa, 0x2615, 0x4987, { 0x88, 0x45, 0xc3, 0x3e, 0x65, 0xf2, 0xb9, 0x57 } }; + // block - create instance + { + ATL::CComPtr<IUnknown> pIUIFramework; + iRibbonUI = SUCCEEDED(pIUIFramework.CoCreateInstance(CLSID_UIRibbonFramework)) ? 1 : 0; + } + ::FreeLibrary(hRibbonDLL); + } + else + { + iRibbonUI = 0; + } + } +#endif // defined(NTDDI_WIN7) && (NTDDI_VERSION >= NTDDI_WIN7) + + return (iRibbonUI == 1); + } + + inline UINT SizeOf_REBARBANDINFO() + { + UINT uSize = sizeof(REBARBANDINFO); +#if !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600) + if(!(IsVista() && IsCommCtrl6())) + uSize = REBARBANDINFO_V6_SIZE; +#endif // !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600) + return uSize; + } + + inline UINT SizeOf_LVGROUP() + { + UINT uSize = sizeof(LVGROUP); +#if !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600) + if(!IsVista()) + uSize = LVGROUP_V5_SIZE; +#endif // !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600) + return uSize; + } + + inline UINT SizeOf_LVTILEINFO() + { + UINT uSize = sizeof(LVTILEINFO); +#if !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600) + if(!IsVista()) + uSize = LVTILEINFO_V5_SIZE; +#endif // 
!defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600) + return uSize; + } + + inline UINT SizeOf_MCHITTESTINFO() + { + UINT uSize = sizeof(MCHITTESTINFO); +#if !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) + if(!(IsVista() && IsCommCtrl6())) + uSize = MCHITTESTINFO_V1_SIZE; +#endif // !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) + return uSize; + } + + inline UINT SizeOf_NONCLIENTMETRICS() + { + UINT uSize = sizeof(NONCLIENTMETRICS); +#if !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (WINVER >= 0x0600) + if(!IsVista()) + uSize = NONCLIENTMETRICS_V1_SIZE; +#endif // !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (WINVER >= 0x0600) + return uSize; + } + + inline UINT SizeOf_TOOLINFO() + { + UINT uSize = sizeof(TOOLINFO); +#ifndef _WTL_NO_RUNTIME_STRUCT_SIZE + if(!IsVista()) + uSize = TTTOOLINFO_V2_SIZE; +#endif + return uSize; + } +} // namespace RunTimeHelper + + +/////////////////////////////////////////////////////////////////////////////// +// ModuleHelper - helper functions for ATL (deprecated) + +namespace ModuleHelper +{ + inline HINSTANCE GetModuleInstance() + { + return ATL::_AtlBaseModule.GetModuleInstance(); + } + + inline HINSTANCE GetResourceInstance() + { + return ATL::_AtlBaseModule.GetResourceInstance(); + } + + inline void AddCreateWndData(ATL::_AtlCreateWndData* pData, void* pObject) + { + ATL::_AtlWinModule.AddCreateWndData(pData, pObject); + } + + inline void* ExtractCreateWndData() + { + return ATL::_AtlWinModule.ExtractCreateWndData(); + } +} // namespace ModuleHelper + + +/////////////////////////////////////////////////////////////////////////////// +// SecureHelper - WTL10 requires use of secure functions +// these are here only for compatibility with existing projects + +namespace SecureHelper +{ + inline void strcpyA_x(char* lpstrDest, size_t cchDest, const char* lpstrSrc) + { + ATL::Checked::strcpy_s(lpstrDest, cchDest, 
lpstrSrc); + } + + inline void strcpyW_x(wchar_t* lpstrDest, size_t cchDest, const wchar_t* lpstrSrc) + { + ATL::Checked::wcscpy_s(lpstrDest, cchDest, lpstrSrc); + } + + inline void strcpy_x(LPTSTR lpstrDest, size_t cchDest, LPCTSTR lpstrSrc) + { +#ifdef _UNICODE + strcpyW_x(lpstrDest, cchDest, lpstrSrc); +#else + strcpyA_x(lpstrDest, cchDest, lpstrSrc); +#endif + } + + inline errno_t strncpyA_x(char* lpstrDest, size_t cchDest, const char* lpstrSrc, size_t cchCount) + { + return ATL::Checked::strncpy_s(lpstrDest, cchDest, lpstrSrc, cchCount); + } + + inline errno_t strncpyW_x(wchar_t* lpstrDest, size_t cchDest, const wchar_t* lpstrSrc, size_t cchCount) + { + return ATL::Checked::wcsncpy_s(lpstrDest, cchDest, lpstrSrc, cchCount); + } + + inline errno_t strncpy_x(LPTSTR lpstrDest, size_t cchDest, LPCTSTR lpstrSrc, size_t cchCount) + { +#ifdef _UNICODE + return strncpyW_x(lpstrDest, cchDest, lpstrSrc, cchCount); +#else + return strncpyA_x(lpstrDest, cchDest, lpstrSrc, cchCount); +#endif + } + + inline void strcatA_x(char* lpstrDest, size_t cchDest, const char* lpstrSrc) + { + ATL::Checked::strcat_s(lpstrDest, cchDest, lpstrSrc); + } + + inline void strcatW_x(wchar_t* lpstrDest, size_t cchDest, const wchar_t* lpstrSrc) + { + ATL::Checked::wcscat_s(lpstrDest, cchDest, lpstrSrc); + } + + inline void strcat_x(LPTSTR lpstrDest, size_t cchDest, LPCTSTR lpstrSrc) + { +#ifdef _UNICODE + strcatW_x(lpstrDest, cchDest, lpstrSrc); +#else + strcatA_x(lpstrDest, cchDest, lpstrSrc); +#endif + } + + inline void memcpy_x(void* pDest, size_t cbDest, const void* pSrc, size_t cbSrc) + { + ATL::Checked::memcpy_s(pDest, cbDest, pSrc, cbSrc); + } + + inline void memmove_x(void* pDest, size_t cbDest, const void* pSrc, size_t cbSrc) + { + ATL::Checked::memmove_s(pDest, cbDest, pSrc, cbSrc); + } + + inline int vsprintf_x(LPTSTR lpstrBuff, size_t cchBuff, LPCTSTR lpstrFormat, va_list args) + { + return _vstprintf_s(lpstrBuff, cchBuff, lpstrFormat, args); + } + + inline int wvsprintf_x(LPTSTR 
lpstrBuff, size_t cchBuff, LPCTSTR lpstrFormat, va_list args) + { + return _vstprintf_s(lpstrBuff, cchBuff, lpstrFormat, args); + } + + inline int sprintf_x(LPTSTR lpstrBuff, size_t cchBuff, LPCTSTR lpstrFormat, ...) + { + va_list args; + va_start(args, lpstrFormat); + int nRes = vsprintf_x(lpstrBuff, cchBuff, lpstrFormat, args); + va_end(args); + return nRes; + } + + inline int wsprintf_x(LPTSTR lpstrBuff, size_t cchBuff, LPCTSTR lpstrFormat, ...) + { + va_list args; + va_start(args, lpstrFormat); + int nRes = wvsprintf_x(lpstrBuff, cchBuff, lpstrFormat, args); + va_end(args); + return nRes; + } +} // namespace SecureHelper + + +/////////////////////////////////////////////////////////////////////////////// +// MinCrtHelper - WTL10 doesn't support _ATL_MIN_CRT, +// these are here only for compatibility with existing projects + +namespace MinCrtHelper +{ + inline int _isspace(TCHAR ch) + { + return _istspace(ch); + } + + inline int _isdigit(TCHAR ch) + { + return _istdigit(ch); + } + + inline int _atoi(LPCTSTR str) + { + return _ttoi(str); + } + + inline LPCTSTR _strrchr(LPCTSTR str, TCHAR ch) + { + return _tcsrchr(str, ch); + } + + inline LPTSTR _strrchr(LPTSTR str, TCHAR ch) + { + return _tcsrchr(str, ch); + } +} // namespace MinCrtHelper + + +/////////////////////////////////////////////////////////////////////////////// +// GenericWndClass - generic window class usable for subclassing + +// Use in dialog templates to specify a placeholder to be subclassed +// Specify as a custom control with class name WTL_GenericWindow +// Call Rregister() before creating dialog (for example, in WinMain) +namespace GenericWndClass +{ + inline LPCTSTR GetName() + { + return _T("WTL_GenericWindow"); + } + + inline ATOM Register() + { + WNDCLASSEX wc = { sizeof(WNDCLASSEX) }; + wc.lpfnWndProc = ::DefWindowProc; + wc.hInstance = ModuleHelper::GetModuleInstance(); + wc.hCursor = ::LoadCursor(NULL, IDC_ARROW); + wc.hbrBackground = (HBRUSH)(COLOR_WINDOW + 1); + wc.lpszClassName = 
GetName(); + ATOM atom = ::RegisterClassEx(&wc); + ATLASSERT(atom != 0); + return atom; + } + + inline BOOL Unregister() // only needed for DLLs or tmp use + { + return ::UnregisterClass(GetName(), ModuleHelper::GetModuleInstance()); + } +} // namespace GenericWndClass + + +/////////////////////////////////////////////////////////////////////////////// +// CMessageFilter - Interface for message filter support + +class ATL_NO_VTABLE CMessageFilter +{ +public: + virtual BOOL PreTranslateMessage(MSG* pMsg) = 0; +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CIdleHandler - Interface for idle processing + +class ATL_NO_VTABLE CIdleHandler +{ +public: + virtual BOOL OnIdle() = 0; +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CMessageLoop - message loop implementation + +class CMessageLoop +{ +public: + ATL::CSimpleArray<CMessageFilter*> m_aMsgFilter; + ATL::CSimpleArray<CIdleHandler*> m_aIdleHandler; + MSG m_msg; + + CMessageLoop() + { + memset(&m_msg, 0, sizeof(m_msg)); + } + + virtual ~CMessageLoop() + { } + +// Message filter operations + BOOL AddMessageFilter(CMessageFilter* pMessageFilter) + { + return m_aMsgFilter.Add(pMessageFilter); + } + + BOOL RemoveMessageFilter(CMessageFilter* pMessageFilter) + { + return m_aMsgFilter.Remove(pMessageFilter); + } + +// Idle handler operations + BOOL AddIdleHandler(CIdleHandler* pIdleHandler) + { + return m_aIdleHandler.Add(pIdleHandler); + } + + BOOL RemoveIdleHandler(CIdleHandler* pIdleHandler) + { + return m_aIdleHandler.Remove(pIdleHandler); + } + +// message loop + int Run() + { + BOOL bDoIdle = TRUE; + int nIdleCount = 0; + BOOL bRet = FALSE; + + for(;;) + { + while(bDoIdle && !::PeekMessage(&m_msg, NULL, 0, 0, PM_NOREMOVE)) + { + if(!OnIdle(nIdleCount++)) + bDoIdle = FALSE; + } + + bRet = ::GetMessage(&m_msg, NULL, 0, 0); + + if(bRet == -1) + { + ATLTRACE2(atlTraceUI, 0, _T("::GetMessage returned -1 (error)\n")); + continue; // 
error, don't process + } + else if(!bRet) + { + ATLTRACE2(atlTraceUI, 0, _T("CMessageLoop::Run - exiting\n")); + break; // WM_QUIT, exit message loop + } + + if(!PreTranslateMessage(&m_msg)) + { + ::TranslateMessage(&m_msg); + ::DispatchMessage(&m_msg); + } + + if(IsIdleMessage(&m_msg)) + { + bDoIdle = TRUE; + nIdleCount = 0; + } + } + + return (int)m_msg.wParam; + } + +// Overrideables + // Override to change message filtering + virtual BOOL PreTranslateMessage(MSG* pMsg) + { + // loop backwards + for(int i = m_aMsgFilter.GetSize() - 1; i >= 0; i--) + { + CMessageFilter* pMessageFilter = m_aMsgFilter[i]; + if((pMessageFilter != NULL) && pMessageFilter->PreTranslateMessage(pMsg)) + return TRUE; + } + return FALSE; // not translated + } + + // override to change idle processing + virtual BOOL OnIdle(int /*nIdleCount*/) + { + for(int i = 0; i < m_aIdleHandler.GetSize(); i++) + { + CIdleHandler* pIdleHandler = m_aIdleHandler[i]; + if(pIdleHandler != NULL) + pIdleHandler->OnIdle(); + } + return FALSE; // don't continue + } + + // override to change non-idle messages + virtual BOOL IsIdleMessage(MSG* pMsg) const + { + // These messages should NOT cause idle processing + switch(pMsg->message) + { + case WM_MOUSEMOVE: + case WM_NCMOUSEMOVE: + case WM_PAINT: + case 0x0118: // WM_SYSTIMER (caret blink) + return FALSE; + } + + return TRUE; + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CStaticDataInitCriticalSectionLock and CWindowCreateCriticalSectionLock +// internal classes to manage critical sections for ATL (deprecated) + +class CStaticDataInitCriticalSectionLock +{ +public: + ATL::CComCritSecLock<ATL::CComCriticalSection> m_cslock; + + CStaticDataInitCriticalSectionLock() : m_cslock(ATL::_pAtlModule->m_csStaticDataInitAndTypeInfo, false) + { } + + HRESULT Lock() + { + return m_cslock.Lock(); + } + + void Unlock() + { + m_cslock.Unlock(); + } +}; + + +class CWindowCreateCriticalSectionLock +{ +public: + 
ATL::CComCritSecLock<ATL::CComCriticalSection> m_cslock; + + CWindowCreateCriticalSectionLock() : m_cslock(ATL::_AtlWinModule.m_csWindowCreate, false) + { } + + HRESULT Lock() + { + return m_cslock.Lock(); + } + + void Unlock() + { + m_cslock.Unlock(); + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CAppModule - module class for an application + +#if (_MSC_VER == 1400) // VS2005 + #pragma warning(push) + #pragma warning(disable : 4244) + #pragma warning(disable : 4312) +#endif + +class CAppModule : public ATL::CComModule +{ +public: + DWORD m_dwMainThreadID; + ATL::CSimpleMap<DWORD, CMessageLoop*>* m_pMsgLoopMap; + ATL::CSimpleArray<HWND>* m_pSettingChangeNotify; + + CAppModule() : m_dwMainThreadID(0), m_pMsgLoopMap(NULL), m_pSettingChangeNotify(NULL) + { } + +// Overrides of CComModule::Init and Term + HRESULT Init(ATL::_ATL_OBJMAP_ENTRY* pObjMap, HINSTANCE hInstance, const GUID* pLibID = NULL) + { + HRESULT hRet = CComModule::Init(pObjMap, hInstance, pLibID); + if(FAILED(hRet)) + return hRet; + + m_dwMainThreadID = ::GetCurrentThreadId(); + typedef ATL::CSimpleMap<DWORD, CMessageLoop*> _mapClass; + m_pMsgLoopMap = NULL; + ATLTRY(m_pMsgLoopMap = new _mapClass); + if(m_pMsgLoopMap == NULL) + return E_OUTOFMEMORY; + m_pSettingChangeNotify = NULL; + + return hRet; + } + + void Term() + { + TermSettingChangeNotify(); + delete m_pMsgLoopMap; + CComModule::Term(); + } + +// Message loop map methods + BOOL AddMessageLoop(CMessageLoop* pMsgLoop) + { + CStaticDataInitCriticalSectionLock lock; + if(FAILED(lock.Lock())) + { + ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::AddMessageLoop.\n")); + ATLASSERT(FALSE); + return FALSE; + } + + ATLASSERT(pMsgLoop != NULL); + ATLASSERT(m_pMsgLoopMap->Lookup(::GetCurrentThreadId()) == NULL); // not in map yet + + BOOL bRet = m_pMsgLoopMap->Add(::GetCurrentThreadId(), pMsgLoop); + + lock.Unlock(); + + return bRet; + } + + BOOL RemoveMessageLoop() + { 
+ CStaticDataInitCriticalSectionLock lock; + if(FAILED(lock.Lock())) + { + ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::RemoveMessageLoop.\n")); + ATLASSERT(FALSE); + return FALSE; + } + + BOOL bRet = m_pMsgLoopMap->Remove(::GetCurrentThreadId()); + + lock.Unlock(); + + return bRet; + } + + CMessageLoop* GetMessageLoop(DWORD dwThreadID = ::GetCurrentThreadId()) const + { + CStaticDataInitCriticalSectionLock lock; + if(FAILED(lock.Lock())) + { + ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::GetMessageLoop.\n")); + ATLASSERT(FALSE); + return NULL; + } + + CMessageLoop* pLoop = m_pMsgLoopMap->Lookup(dwThreadID); + + lock.Unlock(); + + return pLoop; + } + +// Setting change notify methods + // Note: Call this from the main thread for MSDI apps + BOOL InitSettingChangeNotify(DLGPROC pfnDlgProc = _SettingChangeDlgProc) + { + CStaticDataInitCriticalSectionLock lock; + if(FAILED(lock.Lock())) + { + ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::InitSettingChangeNotify.\n")); + ATLASSERT(FALSE); + return FALSE; + } + + if(m_pSettingChangeNotify == NULL) + { + typedef ATL::CSimpleArray<HWND> _notifyClass; + ATLTRY(m_pSettingChangeNotify = new _notifyClass); + ATLASSERT(m_pSettingChangeNotify != NULL); + } + + BOOL bRet = (m_pSettingChangeNotify != NULL); + if(bRet && (m_pSettingChangeNotify->GetSize() == 0)) + { + // init everything + _ATL_EMPTY_DLGTEMPLATE templ; + HWND hNtfWnd = ::CreateDialogIndirect(GetModuleInstance(), &templ, NULL, pfnDlgProc); + ATLASSERT(::IsWindow(hNtfWnd)); + if(::IsWindow(hNtfWnd)) + { + ::SetWindowLongPtr(hNtfWnd, GWLP_USERDATA, (LONG_PTR)this); + bRet = m_pSettingChangeNotify->Add(hNtfWnd); + } + else + { + bRet = FALSE; + } + } + + lock.Unlock(); + + return bRet; + } + + void TermSettingChangeNotify() + { + CStaticDataInitCriticalSectionLock lock; + if(FAILED(lock.Lock())) + { + ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock 
critical section in CAppModule::TermSettingChangeNotify.\n")); + ATLASSERT(FALSE); + return; + } + + if((m_pSettingChangeNotify != NULL) && (m_pSettingChangeNotify->GetSize() > 0)) + ::DestroyWindow((*m_pSettingChangeNotify)[0]); + delete m_pSettingChangeNotify; + m_pSettingChangeNotify = NULL; + + lock.Unlock(); + } + + BOOL AddSettingChangeNotify(HWND hWnd) + { + CStaticDataInitCriticalSectionLock lock; + if(FAILED(lock.Lock())) + { + ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::AddSettingChangeNotify.\n")); + ATLASSERT(FALSE); + return FALSE; + } + + ATLASSERT(::IsWindow(hWnd)); + BOOL bRet = FALSE; + if(InitSettingChangeNotify() != FALSE) + bRet = m_pSettingChangeNotify->Add(hWnd); + + lock.Unlock(); + + return bRet; + } + + BOOL RemoveSettingChangeNotify(HWND hWnd) + { + CStaticDataInitCriticalSectionLock lock; + if(FAILED(lock.Lock())) + { + ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::RemoveSettingChangeNotify.\n")); + ATLASSERT(FALSE); + return FALSE; + } + + BOOL bRet = FALSE; + if(m_pSettingChangeNotify != NULL) + bRet = m_pSettingChangeNotify->Remove(hWnd); + + lock.Unlock(); + + return bRet; + } + +// Implementation - setting change notify dialog template and dialog procedure + struct _ATL_EMPTY_DLGTEMPLATE : DLGTEMPLATE + { + _ATL_EMPTY_DLGTEMPLATE() + { + memset(this, 0, sizeof(_ATL_EMPTY_DLGTEMPLATE)); + style = WS_POPUP; + } + WORD wMenu, wClass, wTitle; + }; + + static INT_PTR CALLBACK _SettingChangeDlgProc(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam) + { + if(uMsg == WM_SETTINGCHANGE) + { + CAppModule* pModule = (CAppModule*)::GetWindowLongPtr(hWnd, GWLP_USERDATA); + ATLASSERT(pModule != NULL); + ATLASSERT(pModule->m_pSettingChangeNotify != NULL); + const UINT uTimeout = 1500; // ms + for(int i = 1; i < pModule->m_pSettingChangeNotify->GetSize(); i++) + ::SendMessageTimeout((*pModule->m_pSettingChangeNotify)[i], uMsg, wParam, lParam, SMTO_ABORTIFHUNG, uTimeout, 
NULL); + + return TRUE; + } + + return FALSE; + } +}; + +#if (_MSC_VER == 1400) // VS2005 + #pragma warning(pop) +#endif + + +/////////////////////////////////////////////////////////////////////////////// +// CServerAppModule - module class for a COM server application + +class CServerAppModule : public CAppModule +{ +public: + HANDLE m_hEventShutdown; + bool m_bActivity; + DWORD m_dwTimeOut; + DWORD m_dwPause; + + CServerAppModule() : m_hEventShutdown(NULL), m_bActivity(false), m_dwTimeOut(5000), m_dwPause(1000) + { } + +// Override of CAppModule::Init + HRESULT Init(ATL::_ATL_OBJMAP_ENTRY* pObjMap, HINSTANCE hInstance, const GUID* pLibID = NULL) + { + m_dwTimeOut = 5000; + m_dwPause = 1000; + return CAppModule::Init(pObjMap, hInstance, pLibID); + } + + void Term() + { + if((m_hEventShutdown != NULL) && ::CloseHandle(m_hEventShutdown)) + m_hEventShutdown = NULL; + CAppModule::Term(); + } + +// COM Server methods + LONG Unlock() throw() + { + LONG lRet = CComModule::Unlock(); + if(lRet == 0) + { + m_bActivity = true; + ::SetEvent(m_hEventShutdown); // tell monitor that we transitioned to zero + } + return lRet; + } + + void MonitorShutdown() + { + for(;;) + { + ::WaitForSingleObject(m_hEventShutdown, INFINITE); + DWORD dwWait = 0; + do + { + m_bActivity = false; + dwWait = ::WaitForSingleObject(m_hEventShutdown, m_dwTimeOut); + } + while(dwWait == WAIT_OBJECT_0); + // timed out + if(!m_bActivity && (m_nLockCnt == 0)) // if no activity let's really bail + { +#if defined(_WIN32_DCOM) && defined(_ATL_FREE_THREADED) + ::CoSuspendClassObjects(); + if(!m_bActivity && (m_nLockCnt == 0)) +#endif + break; + } + } + // This handle should be valid now. 
If it isn't, + // check if _Module.Term was called first (it shouldn't) + if(::CloseHandle(m_hEventShutdown)) + m_hEventShutdown = NULL; + ::PostThreadMessage(m_dwMainThreadID, WM_QUIT, 0, 0); + } + + bool StartMonitor() + { + m_hEventShutdown = ::CreateEvent(NULL, false, false, NULL); + if(m_hEventShutdown == NULL) + return false; + DWORD dwThreadID = 0; +#ifdef _MT + HANDLE hThread = (HANDLE)_beginthreadex(NULL, 0, (UINT (WINAPI*)(void*))MonitorProc, this, 0, (UINT*)&dwThreadID); +#else + HANDLE hThread = ::CreateThread(NULL, 0, MonitorProc, this, 0, &dwThreadID); +#endif + bool bRet = (hThread != NULL); + if(bRet) + ::CloseHandle(hThread); + return bRet; + } + + static DWORD WINAPI MonitorProc(void* pv) + { + CServerAppModule* p = (CServerAppModule*)pv; + p->MonitorShutdown(); + return 0; + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CRegKeyEx - not used any more, here only for compatibility with old projects + +typedef ATL::CRegKey CRegKeyEx; + +} // namespace WTL + + +/////////////////////////////////////////////////////////////////////////////// +// CString forward reference (enables CString use in atluser.h and atlgdi.h) + +#if (defined(_WTL_USE_CSTRING) || defined(_WTL_FORWARD_DECLARE_CSTRING)) && !defined(__ATLSTR_H__) + #include <atlstr.h> +#endif + +// CString namespace +#define _CSTRING_NS ATL + +// Type classes namespace +#define _WTYPES_NS + + +/////////////////////////////////////////////////////////////////////////////// +// General DLL version helpers (removed in ATL11) + +#if (_ATL_VER >= 0x0B00) + +namespace ATL +{ + +inline HRESULT AtlGetDllVersion(HINSTANCE hInstDLL, DLLVERSIONINFO* pDllVersionInfo) +{ + ATLASSERT(pDllVersionInfo != NULL); + if(pDllVersionInfo == NULL) + return E_INVALIDARG; + + // We must get this function explicitly because some DLLs don't implement it. 
+ DLLGETVERSIONPROC pfnDllGetVersion = (DLLGETVERSIONPROC)::GetProcAddress(hInstDLL, "DllGetVersion"); + if(pfnDllGetVersion == NULL) + return E_NOTIMPL; + + return (*pfnDllGetVersion)(pDllVersionInfo); +} + +inline HRESULT AtlGetDllVersion(LPCTSTR lpstrDllName, DLLVERSIONINFO* pDllVersionInfo) +{ + HINSTANCE hInstDLL = ::LoadLibrary(lpstrDllName); + if(hInstDLL == NULL) + return E_FAIL; + HRESULT hRet = AtlGetDllVersion(hInstDLL, pDllVersionInfo); + ::FreeLibrary(hInstDLL); + return hRet; +} + +// Common Control Versions: +// Win95/WinNT 4.0 maj=4 min=00 +// IE 3.x maj=4 min=70 +// IE 4.0 maj=4 min=71 +inline HRESULT AtlGetCommCtrlVersion(LPDWORD pdwMajor, LPDWORD pdwMinor) +{ + ATLASSERT((pdwMajor != NULL) && (pdwMinor != NULL)); + if((pdwMajor == NULL) || (pdwMinor == NULL)) + return E_INVALIDARG; + + DLLVERSIONINFO dvi; + ::ZeroMemory(&dvi, sizeof(dvi)); + dvi.cbSize = sizeof(dvi); + HRESULT hRet = AtlGetDllVersion(_T("comctl32.dll"), &dvi); + + if(SUCCEEDED(hRet)) + { + *pdwMajor = dvi.dwMajorVersion; + *pdwMinor = dvi.dwMinorVersion; + } + else if(hRet == E_NOTIMPL) + { + // If DllGetVersion is not there, then the DLL is a version + // previous to the one shipped with IE 3.x + *pdwMajor = 4; + *pdwMinor = 0; + hRet = S_OK; + } + + return hRet; +} + +// Shell Versions: +// Win95/WinNT 4.0 maj=4 min=00 +// IE 3.x, IE 4.0 without Web Integrated Desktop maj=4 min=00 +// IE 4.0 with Web Integrated Desktop maj=4 min=71 +// IE 4.01 with Web Integrated Desktop maj=4 min=72 +inline HRESULT AtlGetShellVersion(LPDWORD pdwMajor, LPDWORD pdwMinor) +{ + ATLASSERT((pdwMajor != NULL) && (pdwMinor != NULL)); + if((pdwMajor == NULL) || (pdwMinor == NULL)) + return E_INVALIDARG; + + DLLVERSIONINFO dvi; + ::ZeroMemory(&dvi, sizeof(dvi)); + dvi.cbSize = sizeof(dvi); + HRESULT hRet = AtlGetDllVersion(_T("shell32.dll"), &dvi); + + if(SUCCEEDED(hRet)) + { + *pdwMajor = dvi.dwMajorVersion; + *pdwMinor = dvi.dwMinorVersion; + } + else if(hRet == E_NOTIMPL) + { + // If DllGetVersion is 
not there, then the DLL is a version + // previous to the one shipped with IE 4.x + *pdwMajor = 4; + *pdwMinor = 0; + hRet = S_OK; + } + + return hRet; +} + +} // namespace ATL + +#endif // (_ATL_VER >= 0x0B00) + + +// These are always included +#include "atlwinx.h" +#include "atluser.h" +#include "atlgdi.h" + +#ifndef _WTL_NO_AUTOMATIC_NAMESPACE +using namespace WTL; +#endif // !_WTL_NO_AUTOMATIC_NAMESPACE + +#endif // __ATLAPP_H__ diff --git a/Examples/WhisperDesktop/Utils/WTL/atlcrack.h b/Examples/WhisperDesktop/Utils/WTL/atlcrack.h new file mode 100644 index 0000000..da6a896 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/WTL/atlcrack.h @@ -0,0 +1,2480 @@ +// Windows Template Library - WTL version 10.0 +// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved. +// +// This file is a part of the Windows Template Library. +// The use and distribution terms for this software are covered by the +// Microsoft Public License (http://opensource.org/licenses/MS-PL) +// which can be found in the file MS-PL.txt at the root folder. + +#ifndef __ATLCRACK_H__ +#define __ATLCRACK_H__ + +#pragma once + +#ifndef __ATLAPP_H__ + #error atlcrack.h requires atlapp.h to be included first +#endif + + +/////////////////////////////////////////////////////////////////////////////// +// Message map macro for cracked handlers + +// Note about message maps with cracked handlers: +// You can use BEGIN_MSG_MAP for classes that derive from CWindowImpl/CDialogImpl, +// but must use BEGIN_MSG_MAP_EX for classes that don't. 
+ +#define BEGIN_MSG_MAP_EX(theClass) \ +public: \ + BOOL m_bMsgHandled; \ + /* "handled" management for cracked handlers */ \ + BOOL IsMsgHandled() const \ + { \ + return m_bMsgHandled; \ + } \ + void SetMsgHandled(BOOL bHandled) \ + { \ + m_bMsgHandled = bHandled; \ + } \ + BOOL ProcessWindowMessage(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam, LRESULT& lResult, DWORD dwMsgMapID = 0) \ + { \ + BOOL bOldMsgHandled = m_bMsgHandled; \ + BOOL bRet = _ProcessWindowMessage(hWnd, uMsg, wParam, lParam, lResult, dwMsgMapID); \ + m_bMsgHandled = bOldMsgHandled; \ + return bRet; \ + } \ + BOOL _ProcessWindowMessage(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam, LRESULT& lResult, DWORD dwMsgMapID) \ + { \ + BOOL bHandled = TRUE; \ + (hWnd); \ + (uMsg); \ + (wParam); \ + (lParam); \ + (lResult); \ + (bHandled); \ + switch(dwMsgMapID) \ + { \ + case 0: + + +/////////////////////////////////////////////////////////////////////////////// +// Standard Windows message macros + +// int OnCreate(LPCREATESTRUCT lpCreateStruct) +#define MSG_WM_CREATE(func) \ + if (uMsg == WM_CREATE) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((LPCREATESTRUCT)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnInitDialog(CWindow wndFocus, LPARAM lInitParam) +#define MSG_WM_INITDIALOG(func) \ + if (uMsg == WM_INITDIALOG) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HWND)wParam, lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnCopyData(CWindow wnd, PCOPYDATASTRUCT pCopyDataStruct) +#define MSG_WM_COPYDATA(func) \ + if (uMsg == WM_COPYDATA) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HWND)wParam, (PCOPYDATASTRUCT)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDestroy() +#define MSG_WM_DESTROY(func) \ + if (uMsg == WM_DESTROY) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void 
OnMove(CPoint ptPos) +#define MSG_WM_MOVE(func) \ + if (uMsg == WM_MOVE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSize(UINT nType, CSize size) +#define MSG_WM_SIZE(func) \ + if (uMsg == WM_SIZE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CSize(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnActivate(UINT nState, BOOL bMinimized, CWindow wndOther) +#define MSG_WM_ACTIVATE(func) \ + if (uMsg == WM_ACTIVATE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)LOWORD(wParam), (BOOL)HIWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSetFocus(CWindow wndOld) +#define MSG_WM_SETFOCUS(func) \ + if (uMsg == WM_SETFOCUS) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnKillFocus(CWindow wndFocus) +#define MSG_WM_KILLFOCUS(func) \ + if (uMsg == WM_KILLFOCUS) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnEnable(BOOL bEnable) +#define MSG_WM_ENABLE(func) \ + if (uMsg == WM_ENABLE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((BOOL)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnPaint(CDCHandle dc) +#define MSG_WM_PAINT(func) \ + if (uMsg == WM_PAINT) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HDC)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnClose() +#define MSG_WM_CLOSE(func) \ + if (uMsg == WM_CLOSE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnQueryEndSession(UINT nSource, UINT uLogOff) +#define 
MSG_WM_QUERYENDSESSION(func) \ + if (uMsg == WM_QUERYENDSESSION) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)wParam, (UINT)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnQueryOpen() +#define MSG_WM_QUERYOPEN(func) \ + if (uMsg == WM_QUERYOPEN) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func(); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnEraseBkgnd(CDCHandle dc) +#define MSG_WM_ERASEBKGND(func) \ + if (uMsg == WM_ERASEBKGND) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSysColorChange() +#define MSG_WM_SYSCOLORCHANGE(func) \ + if (uMsg == WM_SYSCOLORCHANGE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnEndSession(BOOL bEnding, UINT uLogOff) +#define MSG_WM_ENDSESSION(func) \ + if (uMsg == WM_ENDSESSION) \ + { \ + this->SetMsgHandled(TRUE); \ + func((BOOL)wParam, (UINT)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnShowWindow(BOOL bShow, UINT nStatus) +#define MSG_WM_SHOWWINDOW(func) \ + if (uMsg == WM_SHOWWINDOW) \ + { \ + this->SetMsgHandled(TRUE); \ + func((BOOL)wParam, (int)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnCtlColorEdit(CDCHandle dc, CEdit edit) +#define MSG_WM_CTLCOLOREDIT(func) \ + if (uMsg == WM_CTLCOLOREDIT) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnCtlColorListBox(CDCHandle dc, CListBox listBox) +#define MSG_WM_CTLCOLORLISTBOX(func) \ + if (uMsg == WM_CTLCOLORLISTBOX) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnCtlColorBtn(CDCHandle dc, CButton button) 
+#define MSG_WM_CTLCOLORBTN(func) \ + if (uMsg == WM_CTLCOLORBTN) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnCtlColorDlg(CDCHandle dc, CWindow wnd) +#define MSG_WM_CTLCOLORDLG(func) \ + if (uMsg == WM_CTLCOLORDLG) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnCtlColorScrollBar(CDCHandle dc, CScrollBar scrollBar) +#define MSG_WM_CTLCOLORSCROLLBAR(func) \ + if (uMsg == WM_CTLCOLORSCROLLBAR) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnCtlColorStatic(CDCHandle dc, CStatic wndStatic) +#define MSG_WM_CTLCOLORSTATIC(func) \ + if (uMsg == WM_CTLCOLORSTATIC) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSettingChange(UINT uFlags, LPCTSTR lpszSection) +#define MSG_WM_SETTINGCHANGE(func) \ + if (uMsg == WM_SETTINGCHANGE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPCTSTR)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDevModeChange(LPCTSTR lpDeviceName) +#define MSG_WM_DEVMODECHANGE(func) \ + if (uMsg == WM_DEVMODECHANGE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((LPCTSTR)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnActivateApp(BOOL bActive, DWORD dwThreadID) +#define MSG_WM_ACTIVATEAPP(func) \ + if (uMsg == WM_ACTIVATEAPP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((BOOL)wParam, (DWORD)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnFontChange() +#define MSG_WM_FONTCHANGE(func) \ + if (uMsg == WM_FONTCHANGE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 
0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnTimeChange() +#define MSG_WM_TIMECHANGE(func) \ + if (uMsg == WM_TIMECHANGE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnCancelMode() +#define MSG_WM_CANCELMODE(func) \ + if (uMsg == WM_CANCELMODE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnSetCursor(CWindow wnd, UINT nHitTest, UINT message) +#define MSG_WM_SETCURSOR(func) \ + if (uMsg == WM_SETCURSOR) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HWND)wParam, (UINT)LOWORD(lParam), (UINT)HIWORD(lParam)); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// int OnMouseActivate(CWindow wndTopLevel, UINT nHitTest, UINT message) +#define MSG_WM_MOUSEACTIVATE(func) \ + if (uMsg == WM_MOUSEACTIVATE) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HWND)wParam, (UINT)LOWORD(lParam), (UINT)HIWORD(lParam)); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnChildActivate() +#define MSG_WM_CHILDACTIVATE(func) \ + if (uMsg == WM_CHILDACTIVATE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnGetMinMaxInfo(LPMINMAXINFO lpMMI) +#define MSG_WM_GETMINMAXINFO(func) \ + if (uMsg == WM_GETMINMAXINFO) \ + { \ + this->SetMsgHandled(TRUE); \ + func((LPMINMAXINFO)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnIconEraseBkgnd(CDCHandle dc) +#define MSG_WM_ICONERASEBKGND(func) \ + if (uMsg == WM_ICONERASEBKGND) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HDC)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSpoolerStatus(UINT nStatus, UINT nJobs) +#define MSG_WM_SPOOLERSTATUS(func) \ + if (uMsg == WM_SPOOLERSTATUS) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, 
(UINT)LOWORD(lParam)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDrawItem(int nIDCtl, LPDRAWITEMSTRUCT lpDrawItemStruct) +#define MSG_WM_DRAWITEM(func) \ + if (uMsg == WM_DRAWITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPDRAWITEMSTRUCT)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMeasureItem(int nIDCtl, LPMEASUREITEMSTRUCT lpMeasureItemStruct) +#define MSG_WM_MEASUREITEM(func) \ + if (uMsg == WM_MEASUREITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPMEASUREITEMSTRUCT)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDeleteItem(int nIDCtl, LPDELETEITEMSTRUCT lpDeleteItemStruct) +#define MSG_WM_DELETEITEM(func) \ + if (uMsg == WM_DELETEITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPDELETEITEMSTRUCT)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +//int OnCharToItem(UINT nChar, UINT nIndex, CListBox listBox) +#define MSG_WM_CHARTOITEM(func) \ + if (uMsg == WM_CHARTOITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// int OnVKeyToItem(UINT nKey, UINT nIndex, CListBox listBox) +#define MSG_WM_VKEYTOITEM(func) \ + if (uMsg == WM_VKEYTOITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HCURSOR OnQueryDragIcon() +#define MSG_WM_QUERYDRAGICON(func) \ + if (uMsg == WM_QUERYDRAGICON) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func(); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// int OnCompareItem(int nIDCtl, LPCOMPAREITEMSTRUCT lpCompareItemStruct) +#define MSG_WM_COMPAREITEM(func) \ + if (uMsg == WM_COMPAREITEM) \ + { \ + this->SetMsgHandled(TRUE); \ 
+ lResult = (LRESULT)func((UINT)wParam, (LPCOMPAREITEMSTRUCT)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnCompacting(UINT nCpuTime) +#define MSG_WM_COMPACTING(func) \ + if (uMsg == WM_COMPACTING) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnNcCreate(LPCREATESTRUCT lpCreateStruct) +#define MSG_WM_NCCREATE(func) \ + if (uMsg == WM_NCCREATE) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((LPCREATESTRUCT)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcDestroy() +#define MSG_WM_NCDESTROY(func) \ + if (uMsg == WM_NCDESTROY) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnNcCalcSize(BOOL bCalcValidRects, LPARAM lParam) +#define MSG_WM_NCCALCSIZE(func) \ + if (uMsg == WM_NCCALCSIZE) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((BOOL)wParam, lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// UINT OnNcHitTest(CPoint point) +#define MSG_WM_NCHITTEST(func) \ + if (uMsg == WM_NCHITTEST) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func(::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcPaint(CRgnHandle rgn) +#define MSG_WM_NCPAINT(func) \ + if (uMsg == WM_NCPAINT) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HRGN)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnNcActivate(BOOL bActive) +#define MSG_WM_NCACTIVATE(func) \ + if (uMsg == WM_NCACTIVATE) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((BOOL)wParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// UINT OnGetDlgCode(LPMSG lpMsg) +#define MSG_WM_GETDLGCODE(func) \ + if (uMsg == WM_GETDLGCODE) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((LPMSG)lParam); \ + 
if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcMouseMove(UINT nHitTest, CPoint point) +#define MSG_WM_NCMOUSEMOVE(func) \ + if (uMsg == WM_NCMOUSEMOVE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcLButtonDown(UINT nHitTest, CPoint point) +#define MSG_WM_NCLBUTTONDOWN(func) \ + if (uMsg == WM_NCLBUTTONDOWN) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcLButtonUp(UINT nHitTest, CPoint point) +#define MSG_WM_NCLBUTTONUP(func) \ + if (uMsg == WM_NCLBUTTONUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcLButtonDblClk(UINT nHitTest, CPoint point) +#define MSG_WM_NCLBUTTONDBLCLK(func) \ + if (uMsg == WM_NCLBUTTONDBLCLK) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcRButtonDown(UINT nHitTest, CPoint point) +#define MSG_WM_NCRBUTTONDOWN(func) \ + if (uMsg == WM_NCRBUTTONDOWN) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcRButtonUp(UINT nHitTest, CPoint point) +#define MSG_WM_NCRBUTTONUP(func) \ + if (uMsg == WM_NCRBUTTONUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcRButtonDblClk(UINT nHitTest, CPoint point) +#define MSG_WM_NCRBUTTONDBLCLK(func) \ + if 
(uMsg == WM_NCRBUTTONDBLCLK) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcMButtonDown(UINT nHitTest, CPoint point) +#define MSG_WM_NCMBUTTONDOWN(func) \ + if (uMsg == WM_NCMBUTTONDOWN) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcMButtonUp(UINT nHitTest, CPoint point) +#define MSG_WM_NCMBUTTONUP(func) \ + if (uMsg == WM_NCMBUTTONUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcMButtonDblClk(UINT nHitTest, CPoint point) +#define MSG_WM_NCMBUTTONDBLCLK(func) \ + if (uMsg == WM_NCMBUTTONDBLCLK) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnKeyDown(UINT nChar, UINT nRepCnt, UINT nFlags) +#define MSG_WM_KEYDOWN(func) \ + if (uMsg == WM_KEYDOWN) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnKeyUp(UINT nChar, UINT nRepCnt, UINT nFlags) +#define MSG_WM_KEYUP(func) \ + if (uMsg == WM_KEYUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnChar(TCHAR chChar, UINT nRepCnt, UINT nFlags) +#define MSG_WM_CHAR(func) \ + if (uMsg == WM_CHAR) \ + { \ + this->SetMsgHandled(TRUE); \ + func((TCHAR)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \ + lResult = 
0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDeadChar(TCHAR chChar, UINT nRepCnt, UINT nFlags) +#define MSG_WM_DEADCHAR(func) \ + if (uMsg == WM_DEADCHAR) \ + { \ + this->SetMsgHandled(TRUE); \ + func((TCHAR)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSysKeyDown(UINT nChar, UINT nRepCnt, UINT nFlags) +#define MSG_WM_SYSKEYDOWN(func) \ + if (uMsg == WM_SYSKEYDOWN) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSysKeyUp(UINT nChar, UINT nRepCnt, UINT nFlags) +#define MSG_WM_SYSKEYUP(func) \ + if (uMsg == WM_SYSKEYUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSysChar(TCHAR chChar, UINT nRepCnt, UINT nFlags) +#define MSG_WM_SYSCHAR(func) \ + if (uMsg == WM_SYSCHAR) \ + { \ + this->SetMsgHandled(TRUE); \ + func((TCHAR)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSysDeadChar(TCHAR chChar, UINT nRepCnt, UINT nFlags) +#define MSG_WM_SYSDEADCHAR(func) \ + if (uMsg == WM_SYSDEADCHAR) \ + { \ + this->SetMsgHandled(TRUE); \ + func((TCHAR)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSysCommand(UINT nID, CPoint point) +#define MSG_WM_SYSCOMMAND(func) \ + if (uMsg == WM_SYSCOMMAND) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnTCard(UINT idAction, DWORD dwActionData) +#define 
MSG_WM_TCARD(func) \ + if (uMsg == WM_TCARD) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (DWORD)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnTimer(UINT_PTR nIDEvent) +#define MSG_WM_TIMER(func) \ + if (uMsg == WM_TIMER) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT_PTR)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnHScroll(UINT nSBCode, UINT nPos, CScrollBar pScrollBar) +#define MSG_WM_HSCROLL(func) \ + if (uMsg == WM_HSCROLL) \ + { \ + this->SetMsgHandled(TRUE); \ + func((int)LOWORD(wParam), (short)HIWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnVScroll(UINT nSBCode, UINT nPos, CScrollBar pScrollBar) +#define MSG_WM_VSCROLL(func) \ + if (uMsg == WM_VSCROLL) \ + { \ + this->SetMsgHandled(TRUE); \ + func((int)LOWORD(wParam), (short)HIWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnInitMenu(CMenuHandle menu) +#define MSG_WM_INITMENU(func) \ + if (uMsg == WM_INITMENU) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HMENU)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnInitMenuPopup(CMenuHandle menuPopup, UINT nIndex, BOOL bSysMenu) +#define MSG_WM_INITMENUPOPUP(func) \ + if (uMsg == WM_INITMENUPOPUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HMENU)wParam, (UINT)LOWORD(lParam), (BOOL)HIWORD(lParam)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMenuSelect(UINT nItemID, UINT nFlags, CMenuHandle menu) +#define MSG_WM_MENUSELECT(func) \ + if (uMsg == WM_MENUSELECT) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (HMENU)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnMenuChar(UINT nChar, UINT nFlags, CMenuHandle menu) +#define MSG_WM_MENUCHAR(func) \ + if (uMsg == 
WM_MENUCHAR) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((TCHAR)LOWORD(wParam), (UINT)HIWORD(wParam), (HMENU)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnNotify(int idCtrl, LPNMHDR pnmh) +#define MSG_WM_NOTIFY(func) \ + if (uMsg == WM_NOTIFY) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((int)wParam, (LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnEnterIdle(UINT nWhy, CWindow wndWho) +#define MSG_WM_ENTERIDLE(func) \ + if (uMsg == WM_ENTERIDLE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMouseMove(UINT nFlags, CPoint point) +#define MSG_WM_MOUSEMOVE(func) \ + if (uMsg == WM_MOUSEMOVE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnMouseWheel(UINT nFlags, short zDelta, CPoint pt) +#define MSG_WM_MOUSEWHEEL(func) \ + if (uMsg == WM_MOUSEWHEEL) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)LOWORD(wParam), (short)HIWORD(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnLButtonDown(UINT nFlags, CPoint point) +#define MSG_WM_LBUTTONDOWN(func) \ + if (uMsg == WM_LBUTTONDOWN) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnLButtonUp(UINT nFlags, CPoint point) +#define MSG_WM_LBUTTONUP(func) \ + if (uMsg == WM_LBUTTONUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnLButtonDblClk(UINT nFlags, CPoint point) +#define 
MSG_WM_LBUTTONDBLCLK(func) \ + if (uMsg == WM_LBUTTONDBLCLK) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnRButtonDown(UINT nFlags, CPoint point) +#define MSG_WM_RBUTTONDOWN(func) \ + if (uMsg == WM_RBUTTONDOWN) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnRButtonUp(UINT nFlags, CPoint point) +#define MSG_WM_RBUTTONUP(func) \ + if (uMsg == WM_RBUTTONUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnRButtonDblClk(UINT nFlags, CPoint point) +#define MSG_WM_RBUTTONDBLCLK(func) \ + if (uMsg == WM_RBUTTONDBLCLK) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMButtonDown(UINT nFlags, CPoint point) +#define MSG_WM_MBUTTONDOWN(func) \ + if (uMsg == WM_MBUTTONDOWN) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMButtonUp(UINT nFlags, CPoint point) +#define MSG_WM_MBUTTONUP(func) \ + if (uMsg == WM_MBUTTONUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMButtonDblClk(UINT nFlags, CPoint point) +#define MSG_WM_MBUTTONDBLCLK(func) \ + if (uMsg == WM_MBUTTONDBLCLK) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult 
= 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnParentNotify(UINT message, UINT nChildID, LPARAM lParam) +#define MSG_WM_PARENTNOTIFY(func) \ + if (uMsg == WM_PARENTNOTIFY) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMDIActivate(CWindow wndDeactivate, CWindow wndActivate) +#define MSG_WM_MDIACTIVATE(func) \ + if (uMsg == WM_MDIACTIVATE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam, (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnRenderFormat(UINT nFormat) +#define MSG_WM_RENDERFORMAT(func) \ + if (uMsg == WM_RENDERFORMAT) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnRenderAllFormats() +#define MSG_WM_RENDERALLFORMATS(func) \ + if (uMsg == WM_RENDERALLFORMATS) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDestroyClipboard() +#define MSG_WM_DESTROYCLIPBOARD(func) \ + if (uMsg == WM_DESTROYCLIPBOARD) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDrawClipboard() +#define MSG_WM_DRAWCLIPBOARD(func) \ + if (uMsg == WM_DRAWCLIPBOARD) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnPaintClipboard(CWindow wndViewer, const LPPAINTSTRUCT lpPaintStruct) +#define MSG_WM_PAINTCLIPBOARD(func) \ + if (uMsg == WM_PAINTCLIPBOARD) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam, (const LPPAINTSTRUCT)::GlobalLock((HGLOBAL)lParam)); \ + ::GlobalUnlock((HGLOBAL)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnVScrollClipboard(CWindow wndViewer, UINT nSBCode, UINT nPos) 
+#define MSG_WM_VSCROLLCLIPBOARD(func) \ + if (uMsg == WM_VSCROLLCLIPBOARD) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam, (UINT)LOWORD(lParam), (UINT)HIWORD(lParam)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnContextMenu(CWindow wnd, CPoint point) +#define MSG_WM_CONTEXTMENU(func) \ + if (uMsg == WM_CONTEXTMENU) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSizeClipboard(CWindow wndViewer, const LPRECT lpRect) +#define MSG_WM_SIZECLIPBOARD(func) \ + if (uMsg == WM_SIZECLIPBOARD) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam, (const LPRECT)::GlobalLock((HGLOBAL)lParam)); \ + ::GlobalUnlock((HGLOBAL)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnAskCbFormatName(UINT nMaxCount, LPTSTR lpszString) +#define MSG_WM_ASKCBFORMATNAME(func) \ + if (uMsg == WM_ASKCBFORMATNAME) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPTSTR)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnChangeCbChain(CWindow wndRemove, CWindow wndAfter) +#define MSG_WM_CHANGECBCHAIN(func) \ + if (uMsg == WM_CHANGECBCHAIN) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam, (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnHScrollClipboard(CWindow wndViewer, UINT nSBCode, UINT nPos) +#define MSG_WM_HSCROLLCLIPBOARD(func) \ + if (uMsg == WM_HSCROLLCLIPBOARD) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam, (UINT)LOWORD(lParam), (UINT)HIWORD(lParam)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnQueryNewPalette() +#define MSG_WM_QUERYNEWPALETTE(func) \ + if (uMsg == WM_QUERYNEWPALETTE) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func(); \ + if(this->IsMsgHandled()) \ + return 
TRUE; \ + } + +// void OnPaletteChanged(CWindow wndFocus) +#define MSG_WM_PALETTECHANGED(func) \ + if (uMsg == WM_PALETTECHANGED) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnPaletteIsChanging(CWindow wndPalChg) +#define MSG_WM_PALETTEISCHANGING(func) \ + if (uMsg == WM_PALETTEISCHANGING) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDropFiles(HDROP hDropInfo) +#define MSG_WM_DROPFILES(func) \ + if (uMsg == WM_DROPFILES) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HDROP)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnWindowPosChanging(LPWINDOWPOS lpWndPos) +#define MSG_WM_WINDOWPOSCHANGING(func) \ + if (uMsg == WM_WINDOWPOSCHANGING) \ + { \ + this->SetMsgHandled(TRUE); \ + func((LPWINDOWPOS)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnWindowPosChanged(LPWINDOWPOS lpWndPos) +#define MSG_WM_WINDOWPOSCHANGED(func) \ + if (uMsg == WM_WINDOWPOSCHANGED) \ + { \ + this->SetMsgHandled(TRUE); \ + func((LPWINDOWPOS)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnExitMenuLoop(BOOL fIsTrackPopupMenu) +#define MSG_WM_EXITMENULOOP(func) \ + if (uMsg == WM_EXITMENULOOP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((BOOL)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnEnterMenuLoop(BOOL fIsTrackPopupMenu) +#define MSG_WM_ENTERMENULOOP(func) \ + if (uMsg == WM_ENTERMENULOOP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((BOOL)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnStyleChanged(int nStyleType, LPSTYLESTRUCT lpStyleStruct) +#define MSG_WM_STYLECHANGED(func) \ + if (uMsg == WM_STYLECHANGED) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, 
(LPSTYLESTRUCT)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnStyleChanging(int nStyleType, LPSTYLESTRUCT lpStyleStruct) +#define MSG_WM_STYLECHANGING(func) \ + if (uMsg == WM_STYLECHANGING) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPSTYLESTRUCT)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSizing(UINT fwSide, LPRECT pRect) +#define MSG_WM_SIZING(func) \ + if (uMsg == WM_SIZING) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPRECT)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMoving(UINT fwSide, LPRECT pRect) +#define MSG_WM_MOVING(func) \ + if (uMsg == WM_MOVING) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPRECT)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnCaptureChanged(CWindow wnd) +#define MSG_WM_CAPTURECHANGED(func) \ + if (uMsg == WM_CAPTURECHANGED) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnDeviceChange(UINT nEventType, DWORD_PTR dwData) +#define MSG_WM_DEVICECHANGE(func) \ + if (uMsg == WM_DEVICECHANGE) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)wParam, (DWORD_PTR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnCommand(UINT uNotifyCode, int nID, CWindow wndCtl) +#define MSG_WM_COMMAND(func) \ + if (uMsg == WM_COMMAND) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDisplayChange(UINT uBitsPerPixel, CSize sizeScreen) +#define MSG_WM_DISPLAYCHANGE(func) \ + if (uMsg == WM_DISPLAYCHANGE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CSize(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + 
if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnEnterSizeMove() +#define MSG_WM_ENTERSIZEMOVE(func) \ + if (uMsg == WM_ENTERSIZEMOVE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnExitSizeMove() +#define MSG_WM_EXITSIZEMOVE(func) \ + if (uMsg == WM_EXITSIZEMOVE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HFONT OnGetFont() +#define MSG_WM_GETFONT(func) \ + if (uMsg == WM_GETFONT) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func(); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnGetHotKey() +#define MSG_WM_GETHOTKEY(func) \ + if (uMsg == WM_GETHOTKEY) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func(); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HICON OnGetIcon(UINT uType) +#define MSG_WM_GETICON(func) \ + if (uMsg == WM_GETICON) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)wParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// int OnGetText(int cchTextMax, LPTSTR lpszText) +#define MSG_WM_GETTEXT(func) \ + if (uMsg == WM_GETTEXT) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((int)wParam, (LPTSTR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// int OnGetTextLength() +#define MSG_WM_GETTEXTLENGTH(func) \ + if (uMsg == WM_GETTEXTLENGTH) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func(); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnHelp(LPHELPINFO lpHelpInfo) +#define MSG_WM_HELP(func) \ + if (uMsg == WM_HELP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((LPHELPINFO)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnHotKey(int nHotKeyID, UINT uModifiers, UINT uVirtKey) +#define MSG_WM_HOTKEY(func) \ + if (uMsg == WM_HOTKEY) \ + { \ + this->SetMsgHandled(TRUE); \ + func((int)wParam, 
(UINT)LOWORD(lParam), (UINT)HIWORD(lParam)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnInputLangChange(DWORD dwCharSet, HKL hKbdLayout) +#define MSG_WM_INPUTLANGCHANGE(func) \ + if (uMsg == WM_INPUTLANGCHANGE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((DWORD)wParam, (HKL)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnInputLangChangeRequest(BOOL bSysCharSet, HKL hKbdLayout) +#define MSG_WM_INPUTLANGCHANGEREQUEST(func) \ + if (uMsg == WM_INPUTLANGCHANGEREQUEST) \ + { \ + this->SetMsgHandled(TRUE); \ + func((BOOL)wParam, (HKL)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNextDlgCtl(BOOL bHandle, WPARAM wCtlFocus) +#define MSG_WM_NEXTDLGCTL(func) \ + if (uMsg == WM_NEXTDLGCTL) \ + { \ + this->SetMsgHandled(TRUE); \ + func((BOOL)LOWORD(lParam), wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNextMenu(int nVirtKey, LPMDINEXTMENU lpMdiNextMenu) +#define MSG_WM_NEXTMENU(func) \ + if (uMsg == WM_NEXTMENU) \ + { \ + this->SetMsgHandled(TRUE); \ + func((int)wParam, (LPMDINEXTMENU)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// int OnNotifyFormat(CWindow wndFrom, int nCommand) +#define MSG_WM_NOTIFYFORMAT(func) \ + if (uMsg == WM_NOTIFYFORMAT) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HWND)wParam, (int)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnPowerBroadcast(DWORD dwPowerEvent, DWORD_PTR dwData) +#define MSG_WM_POWERBROADCAST(func) \ + if (uMsg == WM_POWERBROADCAST) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((DWORD)wParam, (DWORD_PTR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnPrint(CDCHandle dc, UINT uFlags) +#define MSG_WM_PRINT(func) \ + if (uMsg == WM_PRINT) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HDC)wParam, (UINT)lParam); \ + lResult = 0; \ + 
if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnPrintClient(CDCHandle dc, UINT uFlags) +#define MSG_WM_PRINTCLIENT(func) \ + if (uMsg == WM_PRINTCLIENT) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HDC)wParam, (UINT)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnRasDialEvent(RASCONNSTATE rasconnstate, DWORD dwError) +#define MSG_WM_RASDIALEVENT(func) \ + if (uMsg == WM_RASDIALEVENT) \ + { \ + this->SetMsgHandled(TRUE); \ + func((RASCONNSTATE)wParam, (DWORD)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSetFont(CFontHandle font, BOOL bRedraw) +#define MSG_WM_SETFONT(func) \ + if (uMsg == WM_SETFONT) \ + { \ + this->SetMsgHandled(TRUE); \ + func((HFONT)wParam, (BOOL)LOWORD(lParam)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// int OnSetHotKey(int nVirtKey, UINT uFlags) +#define MSG_WM_SETHOTKEY(func) \ + if (uMsg == WM_SETHOTKEY) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((int)LOBYTE(LOWORD(wParam)), (UINT)HIBYTE(LOWORD(wParam))); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HICON OnSetIcon(UINT uType, HICON hIcon) +#define MSG_WM_SETICON(func) \ + if (uMsg == WM_SETICON) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)wParam, (HICON)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnSetRedraw(BOOL bRedraw) +#define MSG_WM_SETREDRAW(func) \ + if (uMsg == WM_SETREDRAW) \ + { \ + this->SetMsgHandled(TRUE); \ + func((BOOL)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// int OnSetText(LPCTSTR lpstrText) +#define MSG_WM_SETTEXT(func) \ + if (uMsg == WM_SETTEXT) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((LPCTSTR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnUserChanged() +#define MSG_WM_USERCHANGED(func) \ + if (uMsg == WM_USERCHANGED) \ + { \ + this->SetMsgHandled(TRUE); \ + 
func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +/////////////////////////////////////////////////////////////////////////////// +// Newer Windows messages + +// void OnMouseHover(WPARAM wParam, CPoint ptPos) +#define MSG_WM_MOUSEHOVER(func) \ + if (uMsg == WM_MOUSEHOVER) \ + { \ + this->SetMsgHandled(TRUE); \ + func(wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMouseLeave() +#define MSG_WM_MOUSELEAVE(func) \ + if (uMsg == WM_MOUSELEAVE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcMouseHover(UINT nHitTest, CPoint ptPos) +#define MSG_WM_NCMOUSEHOVER(func) \ + if (uMsg == WM_NCMOUSEHOVER) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, ::CPoint(MAKEPOINTS(lParam).x, MAKEPOINTS(lParam).y)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNcMouseLeave() +#define MSG_WM_NCMOUSELEAVE(func) \ + if (uMsg == WM_NCMOUSELEAVE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMenuRButtonUp(WPARAM wParam, CMenuHandle menu) +#define MSG_WM_MENURBUTTONUP(func) \ + if (uMsg == WM_MENURBUTTONUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func(wParam, (HMENU)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnMenuDrag(WPARAM wParam, CMenuHandle menu) +#define MSG_WM_MENUDRAG(func) \ + if (uMsg == WM_MENUDRAG) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func(wParam, (HMENU)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnMenuGetObject(PMENUGETOBJECTINFO info) +#define MSG_WM_MENUGETOBJECT(func) \ + if (uMsg == WM_MENUGETOBJECT) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((PMENUGETOBJECTINFO)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// 
void OnUnInitMenuPopup(UINT nID, CMenuHandle menu) +#define MSG_WM_UNINITMENUPOPUP(func) \ + if (uMsg == WM_UNINITMENUPOPUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(lParam), (HMENU)wParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnMenuCommand(WPARAM nIndex, CMenuHandle menu) +#define MSG_WM_MENUCOMMAND(func) \ + if (uMsg == WM_MENUCOMMAND) \ + { \ + this->SetMsgHandled(TRUE); \ + func(wParam, (HMENU)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnAppCommand(CWindow wndFocus, short cmd, WORD uDevice, int dwKeys) +#define MSG_WM_APPCOMMAND(func) \ + if (uMsg == WM_APPCOMMAND) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HWND)wParam, GET_APPCOMMAND_LPARAM(lParam), GET_DEVICE_LPARAM(lParam), GET_KEYSTATE_LPARAM(lParam)); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNCXButtonDown(int fwButton, short nHittest, CPoint ptPos) +#define MSG_WM_NCXBUTTONDOWN(func) \ + if (uMsg == WM_NCXBUTTONDOWN) \ + { \ + this->SetMsgHandled(TRUE); \ + func(GET_XBUTTON_WPARAM(wParam), GET_NCHITTEST_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = (LRESULT)TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNCXButtonUp(int fwButton, short nHittest, CPoint ptPos) +#define MSG_WM_NCXBUTTONUP(func) \ + if (uMsg == WM_NCXBUTTONUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func(GET_XBUTTON_WPARAM(wParam), GET_NCHITTEST_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = (LRESULT)TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnNCXButtonDblClk(int fwButton, short nHittest, CPoint ptPos) +#define MSG_WM_NCXBUTTONDBLCLK(func) \ + if (uMsg == WM_NCXBUTTONDBLCLK) \ + { \ + this->SetMsgHandled(TRUE); \ + func(GET_XBUTTON_WPARAM(wParam), GET_NCHITTEST_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = (LRESULT)TRUE; \ + 
if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnXButtonDown(int fwButton, int dwKeys, CPoint ptPos) +#define MSG_WM_XBUTTONDOWN(func) \ + if (uMsg == WM_XBUTTONDOWN) \ + { \ + this->SetMsgHandled(TRUE); \ + func(GET_XBUTTON_WPARAM(wParam), GET_KEYSTATE_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = (LRESULT)TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnXButtonUp(int fwButton, int dwKeys, CPoint ptPos) +#define MSG_WM_XBUTTONUP(func) \ + if (uMsg == WM_XBUTTONUP) \ + { \ + this->SetMsgHandled(TRUE); \ + func(GET_XBUTTON_WPARAM(wParam), GET_KEYSTATE_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = (LRESULT)TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnXButtonDblClk(int fwButton, int dwKeys, CPoint ptPos) +#define MSG_WM_XBUTTONDBLCLK(func) \ + if (uMsg == WM_XBUTTONDBLCLK) \ + { \ + this->SetMsgHandled(TRUE); \ + func(GET_XBUTTON_WPARAM(wParam), GET_KEYSTATE_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + lResult = (LRESULT)TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnChangeUIState(WORD nAction, WORD nState) +#define MSG_WM_CHANGEUISTATE(func) \ + if (uMsg == WM_CHANGEUISTATE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(LOWORD(wParam), HIWORD(wParam)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnUpdateUIState(WORD nAction, WORD nState) +#define MSG_WM_UPDATEUISTATE(func) \ + if (uMsg == WM_UPDATEUISTATE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(LOWORD(wParam), HIWORD(wParam)); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnQueryUIState() +#define MSG_WM_QUERYUISTATE(func) \ + if (uMsg == WM_QUERYUISTATE) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func(); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnInput(WPARAM RawInputCode, HRAWINPUT hRawInput) +#define 
MSG_WM_INPUT(func) \ + if (uMsg == WM_INPUT) \ + { \ + this->SetMsgHandled(TRUE); \ + func(GET_RAWINPUT_CODE_WPARAM(wParam), (HRAWINPUT)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnUniChar(TCHAR nChar, UINT nRepCnt, UINT nFlags) +#define MSG_WM_UNICHAR(func) \ + if (uMsg == WM_UNICHAR) \ + { \ + this->SetMsgHandled(TRUE); \ + func((TCHAR)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \ + if(this->IsMsgHandled()) \ + { \ + lResult = (wParam == UNICODE_NOCHAR) ? TRUE : FALSE; \ + return TRUE; \ + } \ + } + +// void OnWTSSessionChange(WPARAM nStatusCode, DWORD dwSessionID) +#define MSG_WM_WTSSESSION_CHANGE(func) \ + if (uMsg == WM_WTSSESSION_CHANGE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(wParam, (DWORD)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnThemeChanged() +#define MSG_WM_THEMECHANGED(func) \ + if (uMsg == WM_THEMECHANGED) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +#if (_WIN32_WINNT >= 0x0600) + +// BOOL OnMouseHWheel(UINT nFlags, short zDelta, CPoint pt) +#define MSG_WM_MOUSEHWHEEL(func) \ + if (uMsg == WM_MOUSEHWHEEL) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)LOWORD(wParam), (short)HIWORD(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +#endif // (_WIN32_WINNT >= 0x0600) + +#if (WINVER >= 0x0601) + +// void OnGesture(ULONGLONG ullArguments, HGESTUREINFO hGestureInfo) +#define MSG_WM_GESTURE(func) \ + if (uMsg == WM_GESTURE) \ + { \ + this->SetMsgHandled(TRUE); \ + func((ULONGLONG)wParam, (HGESTUREINFO)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// NOTE: this cracker only calls func; it does not call SetMsgHandled, does not assign lResult and never returns TRUE, so processing always continues past the macro. +// NOTE(review): presumably deliberate so that default processing still receives WM_GESTURENOTIFY - confirm against the message documentation. +// void OnGestureNotify(PGESTURENOTIFYSTRUCT pGestureNotifyStruct) +#define MSG_WM_GESTURENOTIFY(func) \ + if (uMsg == WM_GESTURENOTIFY) \ + { \ + func((PGESTURENOTIFYSTRUCT)lParam); \ + } + 
+// void OnDpiChanged(UINT nDpiX, UINT nDpiY, PRECT pRect) +#define MSG_WM_DPICHANGED(func) \ + if (uMsg == WM_DPICHANGED) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (PRECT)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +#endif // (WINVER >= 0x0601) + +#if (WINVER >= 0x0605) + +// void OnDpiChangedBeforeParent() +#define MSG_WM_DPICHANGED_BEFOREPARENT(func) \ + if (uMsg == WM_DPICHANGED_BEFOREPARENT) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDpiChangedAfterParent() +#define MSG_WM_DPICHANGED_AFTERPARENT(func) \ + if (uMsg == WM_DPICHANGED_AFTERPARENT) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// BOOL OnGetDpiScaledSize(UINT uDpi, PSIZE pSize) +#define MSG_WM_GETDPISCALEDSIZE(func) \ +if (uMsg == WM_GETDPISCALEDSIZE) \ +{ \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)wParam, (PSIZE)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ +} + +#endif // (WINVER >= 0x0605) + +/////////////////////////////////////////////////////////////////////////////// +// ATL defined messages + +// BOOL OnForwardMsg(LPMSG Msg, DWORD nUserData) +#define MSG_WM_FORWARDMSG(func) \ + if (uMsg == WM_FORWARDMSG) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((LPMSG)lParam, (DWORD)wParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +/////////////////////////////////////////////////////////////////////////////// +// Dialog specific messages + +// LRESULT OnDMGetDefID() +#define MSG_DM_GETDEFID(func) \ + if (uMsg == DM_GETDEFID) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func(); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDMSetDefID(UINT DefID) +#define MSG_DM_SETDEFID(func) \ + if (uMsg == DM_SETDEFID) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam); \ + 
lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnDMReposition() +#define MSG_DM_REPOSITION(func) \ + if (uMsg == DM_REPOSITION) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +/////////////////////////////////////////////////////////////////////////////// +// Reflected messages + +// void OnReflectedCommand(UINT uNotifyCode, int nID, CWindow wndCtl) +#define MSG_OCM_COMMAND(func) \ + if (uMsg == OCM_COMMAND) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnReflectedNotify(int idCtrl, LPNMHDR pnmh) +#define MSG_OCM_NOTIFY(func) \ + if (uMsg == OCM_NOTIFY) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((int)wParam, (LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedParentNotify(UINT message, UINT nChildID, LPARAM lParam) +#define MSG_OCM_PARENTNOTIFY(func) \ + if (uMsg == OCM_PARENTNOTIFY) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedDrawItem(int nIDCtl, LPDRAWITEMSTRUCT lpDrawItemStruct) +#define MSG_OCM_DRAWITEM(func) \ + if (uMsg == OCM_DRAWITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPDRAWITEMSTRUCT)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedMeasureItem(int nIDCtl, LPMEASUREITEMSTRUCT lpMeasureItemStruct) +#define MSG_OCM_MEASUREITEM(func) \ + if (uMsg == OCM_MEASUREITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPMEASUREITEMSTRUCT)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// int OnReflectedCompareItem(int nIDCtl, LPCOMPAREITEMSTRUCT lpCompareItemStruct) +#define MSG_OCM_COMPAREITEM(func) 
\ + if (uMsg == OCM_COMPAREITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)wParam, (LPCOMPAREITEMSTRUCT)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedDeleteItem(int nIDCtl, LPDELETEITEMSTRUCT lpDeleteItemStruct) +#define MSG_OCM_DELETEITEM(func) \ + if (uMsg == OCM_DELETEITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)wParam, (LPDELETEITEMSTRUCT)lParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// int OnReflectedVKeyToItem(UINT nKey, UINT nIndex, CListBox listBox) +#define MSG_OCM_VKEYTOITEM(func) \ + if (uMsg == OCM_VKEYTOITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +//int OnReflectedCharToItem(UINT nChar, UINT nIndex, CListBox listBox) +#define MSG_OCM_CHARTOITEM(func) \ + if (uMsg == OCM_CHARTOITEM) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedHScroll(UINT nSBCode, UINT nPos, CScrollBar pScrollBar) +#define MSG_OCM_HSCROLL(func) \ + if (uMsg == OCM_HSCROLL) \ + { \ + this->SetMsgHandled(TRUE); \ + func((int)LOWORD(wParam), (short)HIWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedVScroll(UINT nSBCode, UINT nPos, CScrollBar pScrollBar) +#define MSG_OCM_VSCROLL(func) \ + if (uMsg == OCM_VSCROLL) \ + { \ + this->SetMsgHandled(TRUE); \ + func((int)LOWORD(wParam), (short)HIWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnReflectedCtlColorEdit(CDCHandle dc, CEdit edit) +#define MSG_OCM_CTLCOLOREDIT(func) \ + if (uMsg == OCM_CTLCOLOREDIT) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + 
if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnReflectedCtlColorListBox(CDCHandle dc, CListBox listBox) +#define MSG_OCM_CTLCOLORLISTBOX(func) \ + if (uMsg == OCM_CTLCOLORLISTBOX) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnReflectedCtlColorBtn(CDCHandle dc, CButton button) +#define MSG_OCM_CTLCOLORBTN(func) \ + if (uMsg == OCM_CTLCOLORBTN) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnReflectedCtlColorDlg(CDCHandle dc, CWindow wnd) +#define MSG_OCM_CTLCOLORDLG(func) \ + if (uMsg == OCM_CTLCOLORDLG) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnReflectedCtlColorScrollBar(CDCHandle dc, CScrollBar scrollBar) +#define MSG_OCM_CTLCOLORSCROLLBAR(func) \ + if (uMsg == OCM_CTLCOLORSCROLLBAR) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// HBRUSH OnReflectedCtlColorStatic(CDCHandle dc, CStatic wndStatic) +#define MSG_OCM_CTLCOLORSTATIC(func) \ + if (uMsg == OCM_CTLCOLORSTATIC) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +/////////////////////////////////////////////////////////////////////////////// +// Edit specific messages + +// void OnClear() +#define MSG_WM_CLEAR(func) \ + if (uMsg == WM_CLEAR) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnCopy() +#define MSG_WM_COPY(func) \ + if (uMsg == WM_COPY) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void 
OnCut() +#define MSG_WM_CUT(func) \ + if (uMsg == WM_CUT) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnPaste() +#define MSG_WM_PASTE(func) \ + if (uMsg == WM_PASTE) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnUndo() +#define MSG_WM_UNDO(func) \ + if (uMsg == WM_UNDO) \ + { \ + this->SetMsgHandled(TRUE); \ + func(); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +/////////////////////////////////////////////////////////////////////////////// +// Generic message handlers +// +// Every macro below follows the same shape: test uMsg (and wParam/lParam where relevant), +// call SetMsgHandled(TRUE), invoke func, then return TRUE from the enclosing function +// if IsMsgHandled() is still true after the call. +// Macro arguments (msg, id, code, cd, range bounds, cmd) are parenthesized at each use +// so that an expression argument such as (A | B) is compared as a single value. + +// LRESULT OnMessageHandlerEX(UINT uMsg, WPARAM wParam, LPARAM lParam) +#define MESSAGE_HANDLER_EX(msg, func) \ + if(uMsg == (msg)) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func(uMsg, wParam, lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnMessageRangeHandlerEX(UINT uMsg, WPARAM wParam, LPARAM lParam) +#define MESSAGE_RANGE_HANDLER_EX(msgFirst, msgLast, func) \ + if((uMsg >= (msgFirst)) && (uMsg <= (msgLast))) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func(uMsg, wParam, lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +/////////////////////////////////////////////////////////////////////////////// +// Commands and notifications +// +// Handlers documented as void have their return value ignored; lResult is set to 0. 
+// Handlers documented as LRESULT have their return value stored in lResult. + +// void OnCommandHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl) +#define COMMAND_HANDLER_EX(id, code, func) \ + if ((uMsg == WM_COMMAND) && ((code) == HIWORD(wParam)) && ((id) == LOWORD(wParam))) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnCommandIDHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl) +#define COMMAND_ID_HANDLER_EX(id, func) \ + if ((uMsg == WM_COMMAND) && ((id) == LOWORD(wParam))) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), 
(int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnCommandCodeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl) +#define COMMAND_CODE_HANDLER_EX(code, func) \ + if ((uMsg == WM_COMMAND) && ((code) == HIWORD(wParam))) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnNotifyHandlerEX(LPNMHDR pnmh) +#define NOTIFY_HANDLER_EX(id, cd, func) \ + if ((uMsg == WM_NOTIFY) && ((cd) == ((LPNMHDR)lParam)->code) && ((id) == ((LPNMHDR)lParam)->idFrom)) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnNotifyIDHandlerEX(LPNMHDR pnmh) +#define NOTIFY_ID_HANDLER_EX(id, func) \ + if ((uMsg == WM_NOTIFY) && ((id) == ((LPNMHDR)lParam)->idFrom)) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnNotifyCodeHandlerEX(LPNMHDR pnmh) +#define NOTIFY_CODE_HANDLER_EX(cd, func) \ + if ((uMsg == WM_NOTIFY) && ((cd) == ((LPNMHDR)lParam)->code)) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnCommandRangeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl) +#define COMMAND_RANGE_HANDLER_EX(idFirst, idLast, func) \ + if((uMsg == WM_COMMAND) && (LOWORD(wParam) >= (idFirst)) && (LOWORD(wParam) <= (idLast))) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnCommandRangeCodeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl) +#define COMMAND_RANGE_CODE_HANDLER_EX(idFirst, idLast, code, func) \ + if((uMsg == WM_COMMAND) && ((code) == HIWORD(wParam)) && (LOWORD(wParam) >= (idFirst)) && (LOWORD(wParam) <= (idLast))) 
\ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnNotifyRangeHandlerEX(LPNMHDR pnmh) +#define NOTIFY_RANGE_HANDLER_EX(idFirst, idLast, func) \ + if((uMsg == WM_NOTIFY) && (((LPNMHDR)lParam)->idFrom >= (idFirst)) && (((LPNMHDR)lParam)->idFrom <= (idLast))) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnNotifyRangeCodeHandlerEX(LPNMHDR pnmh) +#define NOTIFY_RANGE_CODE_HANDLER_EX(idFirst, idLast, cd, func) \ + if((uMsg == WM_NOTIFY) && ((cd) == ((LPNMHDR)lParam)->code) && (((LPNMHDR)lParam)->idFrom >= (idFirst)) && (((LPNMHDR)lParam)->idFrom <= (idLast))) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedCommandHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl) +#define REFLECTED_COMMAND_HANDLER_EX(id, code, func) \ + if ((uMsg == OCM_COMMAND) && ((code) == HIWORD(wParam)) && ((id) == LOWORD(wParam))) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedCommandIDHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl) +#define REFLECTED_COMMAND_ID_HANDLER_EX(id, func) \ + if ((uMsg == OCM_COMMAND) && ((id) == LOWORD(wParam))) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedCommandCodeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl) +#define REFLECTED_COMMAND_CODE_HANDLER_EX(code, func) \ + if ((uMsg == OCM_COMMAND) && ((code) == HIWORD(wParam))) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ 
+ if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnReflectedNotifyHandlerEX(LPNMHDR pnmh) +#define REFLECTED_NOTIFY_HANDLER_EX(id, cd, func) \ + if ((uMsg == OCM_NOTIFY) && ((cd) == ((LPNMHDR)lParam)->code) && ((id) == ((LPNMHDR)lParam)->idFrom)) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnReflectedNotifyIDHandlerEX(LPNMHDR pnmh) +#define REFLECTED_NOTIFY_ID_HANDLER_EX(id, func) \ + if ((uMsg == OCM_NOTIFY) && ((id) == ((LPNMHDR)lParam)->idFrom)) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnReflectedNotifyCodeHandlerEX(LPNMHDR pnmh) +#define REFLECTED_NOTIFY_CODE_HANDLER_EX(cd, func) \ + if ((uMsg == OCM_NOTIFY) && ((cd) == ((LPNMHDR)lParam)->code)) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedCommandRangeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl) +#define REFLECTED_COMMAND_RANGE_HANDLER_EX(idFirst, idLast, func) \ + if((uMsg == OCM_COMMAND) && (LOWORD(wParam) >= (idFirst)) && (LOWORD(wParam) <= (idLast))) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnReflectedCommandRangeCodeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl) +#define REFLECTED_COMMAND_RANGE_CODE_HANDLER_EX(idFirst, idLast, code, func) \ + if((uMsg == OCM_COMMAND) && ((code) == HIWORD(wParam)) && (LOWORD(wParam) >= (idFirst)) && (LOWORD(wParam) <= (idLast))) \ + { \ + this->SetMsgHandled(TRUE); \ + func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \ + lResult = 0; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnReflectedNotifyRangeHandlerEX(LPNMHDR pnmh) +#define REFLECTED_NOTIFY_RANGE_HANDLER_EX(idFirst, idLast, func) \ + 
if((uMsg == OCM_NOTIFY) && (((LPNMHDR)lParam)->idFrom >= (idFirst)) && (((LPNMHDR)lParam)->idFrom <= (idLast))) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// LRESULT OnReflectedNotifyRangeCodeHandlerEX(LPNMHDR pnmh) +#define REFLECTED_NOTIFY_RANGE_CODE_HANDLER_EX(idFirst, idLast, cd, func) \ + if((uMsg == OCM_NOTIFY) && ((cd) == ((LPNMHDR)lParam)->code) && (((LPNMHDR)lParam)->idFrom >= (idFirst)) && (((LPNMHDR)lParam)->idFrom <= (idLast))) \ + { \ + this->SetMsgHandled(TRUE); \ + lResult = func((LPNMHDR)lParam); \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +// void OnAppCommandHandler(UINT uDevice, DWORD dwKeys, CWindow wndFocus) +// Unlike the command handlers above, lResult is set to TRUE after the call. +#define APPCOMMAND_HANDLER_EX(cmd, func) \ + if((uMsg == WM_APPCOMMAND) && ((cmd) == GET_APPCOMMAND_LPARAM(lParam))) \ + { \ + this->SetMsgHandled(TRUE); \ + func(GET_DEVICE_LPARAM(lParam), GET_KEYSTATE_LPARAM(lParam), (HWND)wParam); \ + lResult = TRUE; \ + if(this->IsMsgHandled()) \ + return TRUE; \ + } + +#endif // __ATLCRACK_H__ diff --git a/Examples/WhisperDesktop/Utils/WTL/atlctrls.h b/Examples/WhisperDesktop/Utils/WTL/atlctrls.h new file mode 100644 index 0000000..61df427 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/WTL/atlctrls.h @@ -0,0 +1,9764 @@ +// Windows Template Library - WTL version 10.0 +// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved. +// +// This file is a part of the Windows Template Library. +// The use and distribution terms for this software are covered by the +// Microsoft Public License (http://opensource.org/licenses/MS-PL) +// which can be found in the file MS-PL.txt at the root folder. 
+ +#ifndef __ATLCTRLS_H__ +#define __ATLCTRLS_H__ + +#pragma once + +#ifndef __ATLAPP_H__ + #error atlctrls.h requires atlapp.h to be included first +#endif + +#ifndef __ATLWIN_H__ + #error atlctrls.h requires atlwin.h to be included first +#endif + +#include <richedit.h> +#include <richole.h> + +#if (_RICHEDIT_VER < 0x0300) + #error WTL10 requires _RICHEDIT_VER >= 0x0300 +#endif + +// protect template members from windowsx.h macros +#ifdef _INC_WINDOWSX + #undef GetNextSibling + #undef GetPrevSibling +#endif // _INC_WINDOWSX + + +/////////////////////////////////////////////////////////////////////////////// +// Classes in this file: +// +// CStaticT<TBase> - CStatic +// CButtonT<TBase> - CButton +// CListBoxT<TBase> - CListBox +// CComboBoxT<TBase> - CComboBox +// CEditT<TBase> - CEdit +// CEditCommands<T> +// CScrollBarT<TBase> - CScrollBar +// +// CImageListT<t_bManaged> - CImageList, CImageListManaged +// CListViewCtrlT<TBase> - CListViewCtrl +// CTreeViewCtrlT<TBase> - CTreeViewCtrl +// CTreeItemT<TBase> - CTreeItem +// CTreeViewCtrlExT<TBase> - CTreeViewCtrlEx +// CHeaderCtrlT<TBase> - CHeaderCtrl +// CToolBarCtrlT<TBase> - CToolBarCtrl +// CStatusBarCtrlT<TBase> - CStatusBarCtrl +// CTabCtrlT<TBase> - CTabCtrl +// CToolInfo +// CToolTipCtrlT<TBase> - CToolTipCtrl +// CTrackBarCtrlT<TBase> - CTrackBarCtrl +// CUpDownCtrlT<TBase> - CUpDownCtrl +// CProgressBarCtrlT<TBase> - CProgressBarCtrl +// CHotKeyCtrlT<TBase> - CHotKeyCtrl +// CAnimateCtrlT<TBase> - CAnimateCtrl +// CRichEditCtrlT<TBase> - CRichEditCtrl +// CRichEditCommands<T> +// CDragListBoxT<TBase> - CDragListBox +// CDragListNotifyImpl<T> +// CReBarCtrlT<TBase> - CReBarCtrl +// CComboBoxExT<TBase> - CComboBoxEx +// CDateTimePickerCtrlT<TBase> - CDateTimePickerCtrl +// CMonthCalendarCtrlT<TBase> - CMonthCalendarCtrl +// CFlatScrollBarImpl<T> +// CFlatScrollBarT<TBase> - CFlatScrollBar +// CIPAddressCtrlT<TBase> - CIPAddressCtrl +// CPagerCtrlT<TBase> - CPagerCtrl +// CLinkCtrlT<TBase> - CLinkCtrl +// 
+// CCustomDraw<T> + + +namespace WTL +{ + +// These are wrapper classes for Windows standard and common controls. +// To implement a window based on a control, use following: +// Example: Implementing a window based on a list box +// +// class CMyListBox : CWindowImpl<CMyListBox, CListBox> +// { +// public: +// BEGIN_MSG_MAP(CMyListBox) +// // put your message handler entries here +// END_MSG_MAP() +// }; + + + +// --- Standard Windows controls --- + +/////////////////////////////////////////////////////////////////////////////// +// CStatic - client side for a Windows STATIC control + +template <class TBase> +class CStaticT : public TBase +{ +public: +// Constructors + CStaticT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CStaticT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return _T("STATIC"); + } + + HICON GetIcon() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HICON)::SendMessage(this->m_hWnd, STM_GETICON, 0, 0L); + } + + HICON SetIcon(HICON hIcon) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HICON)::SendMessage(this->m_hWnd, STM_SETICON, (WPARAM)hIcon, 0L); + } + + HENHMETAFILE GetEnhMetaFile() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HENHMETAFILE)::SendMessage(this->m_hWnd, STM_GETIMAGE, IMAGE_ENHMETAFILE, 0L); + } + + HENHMETAFILE SetEnhMetaFile(HENHMETAFILE hMetaFile) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HENHMETAFILE)::SendMessage(this->m_hWnd, STM_SETIMAGE, IMAGE_ENHMETAFILE, (LPARAM)hMetaFile); + } + + CBitmapHandle GetBitmap() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + 
return CBitmapHandle((HBITMAP)::SendMessage(this->m_hWnd, STM_GETIMAGE, IMAGE_BITMAP, 0L)); + } + + CBitmapHandle SetBitmap(HBITMAP hBitmap) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CBitmapHandle((HBITMAP)::SendMessage(this->m_hWnd, STM_SETIMAGE, IMAGE_BITMAP, (LPARAM)hBitmap)); + } + + HCURSOR GetCursor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HCURSOR)::SendMessage(this->m_hWnd, STM_GETIMAGE, IMAGE_CURSOR, 0L); + } + + HCURSOR SetCursor(HCURSOR hCursor) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HCURSOR)::SendMessage(this->m_hWnd, STM_SETIMAGE, IMAGE_CURSOR, (LPARAM)hCursor); + } +}; + +typedef CStaticT<ATL::CWindow> CStatic; + + +/////////////////////////////////////////////////////////////////////////////// +// CButton - client side for a Windows BUTTON control + +template <class TBase> +class CButtonT : public TBase +{ +public: +// Constructors + CButtonT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CButtonT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return _T("BUTTON"); + } + + UINT GetState() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, BM_GETSTATE, 0, 0L); + } + + void SetState(BOOL bHighlight) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, BM_SETSTATE, bHighlight, 0L); + } + + int GetCheck() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, BM_GETCHECK, 0, 0L); + } + + void SetCheck(int nCheck) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, BM_SETCHECK, 
nCheck, 0L); + } + + UINT GetButtonStyle() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::GetWindowLong(this->m_hWnd, GWL_STYLE) & 0xFFFF; + } + + void SetButtonStyle(UINT nStyle, BOOL bRedraw = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, BM_SETSTYLE, nStyle, (LPARAM)bRedraw); + } + + HICON GetIcon() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HICON)::SendMessage(this->m_hWnd, BM_GETIMAGE, IMAGE_ICON, 0L); + } + + HICON SetIcon(HICON hIcon) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HICON)::SendMessage(this->m_hWnd, BM_SETIMAGE, IMAGE_ICON, (LPARAM)hIcon); + } + + CBitmapHandle GetBitmap() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CBitmapHandle((HBITMAP)::SendMessage(this->m_hWnd, BM_GETIMAGE, IMAGE_BITMAP, 0L)); + } + + CBitmapHandle SetBitmap(HBITMAP hBitmap) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CBitmapHandle((HBITMAP)::SendMessage(this->m_hWnd, BM_SETIMAGE, IMAGE_BITMAP, (LPARAM)hBitmap)); + } + + BOOL GetIdealSize(LPSIZE lpSize) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, BCM_GETIDEALSIZE, 0, (LPARAM)lpSize); + } + + BOOL GetImageList(PBUTTON_IMAGELIST pButtonImagelist) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, BCM_GETIMAGELIST, 0, (LPARAM)pButtonImagelist); + } + + BOOL SetImageList(PBUTTON_IMAGELIST pButtonImagelist) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, BCM_SETIMAGELIST, 0, (LPARAM)pButtonImagelist); + } + + BOOL GetTextMargin(LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, BCM_GETTEXTMARGIN, 0, (LPARAM)lpRect); + } + + BOOL SetTextMargin(LPRECT lpRect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, BCM_SETTEXTMARGIN, 0, (LPARAM)lpRect); + } + +#if (WINVER >= 0x0600) + void SetDontClick(BOOL bDontClick) + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, BM_SETDONTCLICK, (WPARAM)bDontClick, 0L); + } +#endif // (WINVER >= 0x0600) + +#if (_WIN32_WINNT >= 0x0600) + BOOL SetDropDownState(BOOL bDropDown) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (BS_SPLITBUTTON | BS_DEFSPLITBUTTON)) != 0); + return (BOOL)::SendMessage(this->m_hWnd, BCM_SETDROPDOWNSTATE, (WPARAM)bDropDown, 0L); + } + + BOOL GetSplitInfo(PBUTTON_SPLITINFO pSplitInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (BS_SPLITBUTTON | BS_DEFSPLITBUTTON)) != 0); + return (BOOL)::SendMessage(this->m_hWnd, BCM_GETSPLITINFO, 0, (LPARAM)pSplitInfo); + } + + BOOL SetSplitInfo(PBUTTON_SPLITINFO pSplitInfo) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (BS_SPLITBUTTON | BS_DEFSPLITBUTTON)) != 0); + return (BOOL)::SendMessage(this->m_hWnd, BCM_SETSPLITINFO, 0, (LPARAM)pSplitInfo); + } + + int GetNoteLength() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (BS_COMMANDLINK | BS_DEFCOMMANDLINK)) != 0); + return (int)::SendMessage(this->m_hWnd, BCM_GETNOTELENGTH, 0, 0L); + } + + BOOL GetNote(LPWSTR lpstrNoteText, int cchNoteText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (BS_COMMANDLINK | BS_DEFCOMMANDLINK)) != 0); + return (BOOL)::SendMessage(this->m_hWnd, BCM_GETNOTE, cchNoteText, (LPARAM)lpstrNoteText); + } + + BOOL SetNote(LPCWSTR lpstrNoteText) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (BS_COMMANDLINK | BS_DEFCOMMANDLINK)) != 0); + return (BOOL)::SendMessage(this->m_hWnd, BCM_SETNOTE, 0, (LPARAM)lpstrNoteText); + } + + LRESULT SetElevationRequiredState(BOOL bSet) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::SendMessage(this->m_hWnd, BCM_SETSHIELD, 0, (LPARAM)bSet); + } +#endif // (_WIN32_WINNT >= 0x0600) + +// Operations + void Click() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + 
::SendMessage(this->m_hWnd, BM_CLICK, 0, 0L); + } +}; + +typedef CButtonT<ATL::CWindow> CButton; + + +/////////////////////////////////////////////////////////////////////////////// +// CListBox - client side for a Windows LISTBOX control + +template <class TBase> +class CListBoxT : public TBase +{ +public: +// Constructors + CListBoxT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CListBoxT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return _T("LISTBOX"); + } + + // for entire listbox + int GetCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_GETCOUNT, 0, 0L); + } + + int SetCount(int cItems) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(((this->GetStyle() & LBS_NODATA) != 0) && ((this->GetStyle() & LBS_HASSTRINGS) == 0)); + return (int)::SendMessage(this->m_hWnd, LB_SETCOUNT, cItems, 0L); + } + + int GetHorizontalExtent() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_GETHORIZONTALEXTENT, 0, 0L); + } + + void SetHorizontalExtent(int cxExtent) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LB_SETHORIZONTALEXTENT, cxExtent, 0L); + } + + int GetTopIndex() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_GETTOPINDEX, 0, 0L); + } + + int SetTopIndex(int nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_SETTOPINDEX, nIndex, 0L); + } + + LCID GetLocale() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return 
(LCID)::SendMessage(this->m_hWnd, LB_GETLOCALE, 0, 0L); + } + + LCID SetLocale(LCID nNewLocale) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (LCID)::SendMessage(this->m_hWnd, LB_SETLOCALE, (WPARAM)nNewLocale, 0L); + } + + DWORD GetListBoxInfo() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, LB_GETLISTBOXINFO, 0, 0L); + } + + // for single-selection listboxes + int GetCurSel() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) == 0); + return (int)::SendMessage(this->m_hWnd, LB_GETCURSEL, 0, 0L); + } + + int SetCurSel(int nSelect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) == 0); + return (int)::SendMessage(this->m_hWnd, LB_SETCURSEL, nSelect, 0L); + } + + // for multiple-selection listboxes + int GetSel(int nIndex) const // also works for single-selection + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_GETSEL, nIndex, 0L); + } + + int SetSel(int nIndex, BOOL bSelect = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0); + return (int)::SendMessage(this->m_hWnd, LB_SETSEL, bSelect, nIndex); + } + + int GetSelCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0); + return (int)::SendMessage(this->m_hWnd, LB_GETSELCOUNT, 0, 0L); + } + + int GetSelItems(int nMaxItems, LPINT rgIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0); + return (int)::SendMessage(this->m_hWnd, LB_GETSELITEMS, nMaxItems, (LPARAM)rgIndex); + } + + int GetAnchorIndex() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0); + return (int)::SendMessage(this->m_hWnd, 
LB_GETANCHORINDEX, 0, 0L); + } + + void SetAnchorIndex(int nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0); + ::SendMessage(this->m_hWnd, LB_SETANCHORINDEX, nIndex, 0L); + } + + int GetCaretIndex() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_GETCARETINDEX, 0, 0); + } + + int SetCaretIndex(int nIndex, BOOL bScroll = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_SETCARETINDEX, nIndex, MAKELONG(bScroll, 0)); + } + + // for listbox items + DWORD_PTR GetItemData(int nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD_PTR)::SendMessage(this->m_hWnd, LB_GETITEMDATA, nIndex, 0L); + } + + int SetItemData(int nIndex, DWORD_PTR dwItemData) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_SETITEMDATA, nIndex, (LPARAM)dwItemData); + } + + void* GetItemDataPtr(int nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (void*)::SendMessage(this->m_hWnd, LB_GETITEMDATA, nIndex, 0L); + } + + int SetItemDataPtr(int nIndex, void* pData) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return SetItemData(nIndex, (DWORD_PTR)pData); + } + + int GetItemRect(int nIndex, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_GETITEMRECT, nIndex, (LPARAM)lpRect); + } + + int GetText(int nIndex, LPTSTR lpszBuffer) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_GETTEXT, nIndex, (LPARAM)lpszBuffer); + } + +#ifdef _OLEAUTO_H_ + BOOL GetTextBSTR(int nIndex, BSTR& bstrText) const + { + USES_CONVERSION; + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(bstrText == NULL); + + int nLen = GetTextLen(nIndex); + if(nLen == LB_ERR) + return FALSE; + + ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff; + LPTSTR lpstrText = buff.Allocate(nLen + 1); + 
if(lpstrText == NULL) + return FALSE; + + if(GetText(nIndex, lpstrText) == LB_ERR) + return FALSE; + + bstrText = ::SysAllocString(T2OLE(lpstrText)); + return (bstrText != NULL) ? TRUE : FALSE; + } +#endif // _OLEAUTO_H_ + +#ifdef __ATLSTR_H__ + int GetText(int nIndex, ATL::CString& strText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + int cchLen = GetTextLen(nIndex); + if(cchLen == LB_ERR) + return LB_ERR; + int nRet = LB_ERR; + LPTSTR lpstr = strText.GetBufferSetLength(cchLen); + if(lpstr != NULL) + { + nRet = GetText(nIndex, lpstr); + strText.ReleaseBuffer(); + } + return nRet; + } +#endif // __ATLSTR_H__ + + int GetTextLen(int nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_GETTEXTLEN, nIndex, 0L); + } + + int GetItemHeight(int nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_GETITEMHEIGHT, nIndex, 0L); + } + + int SetItemHeight(int nIndex, UINT cyItemHeight) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_SETITEMHEIGHT, nIndex, MAKELONG(cyItemHeight, 0)); + } + + // Settable only attributes + void SetColumnWidth(int cxWidth) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LB_SETCOLUMNWIDTH, cxWidth, 0L); + } + + BOOL SetTabStops(int nTabStops, LPINT rgTabStops) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & LBS_USETABSTOPS) != 0); + return (BOOL)::SendMessage(this->m_hWnd, LB_SETTABSTOPS, nTabStops, (LPARAM)rgTabStops); + } + + BOOL SetTabStops() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & LBS_USETABSTOPS) != 0); + return (BOOL)::SendMessage(this->m_hWnd, LB_SETTABSTOPS, 0, 0L); + } + + BOOL SetTabStops(const int& cxEachStop) // takes an 'int' + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & LBS_USETABSTOPS) != 0); + return (BOOL)::SendMessage(this->m_hWnd, LB_SETTABSTOPS, 1, 
(LPARAM)(LPINT)&cxEachStop); + } + +// Operations + int InitStorage(int nItems, UINT nBytes) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_INITSTORAGE, (WPARAM)nItems, nBytes); + } + + void ResetContent() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LB_RESETCONTENT, 0, 0L); + } + + UINT ItemFromPoint(POINT pt, BOOL& bOutside) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dw = (DWORD)::SendMessage(this->m_hWnd, LB_ITEMFROMPOINT, 0, MAKELPARAM(pt.x, pt.y)); + bOutside = (BOOL)HIWORD(dw); + return (UINT)LOWORD(dw); + } + + // manipulating listbox items + int AddString(LPCTSTR lpszItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_ADDSTRING, 0, (LPARAM)lpszItem); + } + + int DeleteString(UINT nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_DELETESTRING, nIndex, 0L); + } + + int InsertString(int nIndex, LPCTSTR lpszItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_INSERTSTRING, nIndex, (LPARAM)lpszItem); + } + + int Dir(UINT attr, LPCTSTR lpszWildCard) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_DIR, attr, (LPARAM)lpszWildCard); + } + + int AddFile(LPCTSTR lpstrFileName) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_ADDFILE, 0, (LPARAM)lpstrFileName); + } + + // selection helpers + int FindString(int nStartAfter, LPCTSTR lpszItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_FINDSTRING, nStartAfter, (LPARAM)lpszItem); + } + + int FindStringExact(int nIndexStart, LPCTSTR lpszFind) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LB_FINDSTRINGEXACT, nIndexStart, (LPARAM)lpszFind); + } + + int SelectString(int nStartAfter, LPCTSTR lpszItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + 
return (int)::SendMessage(this->m_hWnd, LB_SELECTSTRING, nStartAfter, (LPARAM)lpszItem); + } + + int SelItemRange(BOOL bSelect, int nFirstItem, int nLastItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0); + ATLASSERT(nFirstItem <= nLastItem); + return bSelect ? (int)::SendMessage(this->m_hWnd, LB_SELITEMRANGEEX, nFirstItem, nLastItem) : (int)::SendMessage(this->m_hWnd, LB_SELITEMRANGEEX, nLastItem, nFirstItem); + } +}; + +typedef CListBoxT<ATL::CWindow> CListBox; + + +/////////////////////////////////////////////////////////////////////////////// +// CComboBox - client side for a Windows COMBOBOX control + +template <class TBase> +class CComboBoxT : public TBase +{ +public: +// Constructors + CComboBoxT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CComboBoxT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return _T("COMBOBOX"); + } + + // for entire combo box + int GetCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_GETCOUNT, 0, 0L); + } + + int GetCurSel() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_GETCURSEL, 0, 0L); + } + + int SetCurSel(int nSelect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_SETCURSEL, nSelect, 0L); + } + + LCID GetLocale() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (LCID)::SendMessage(this->m_hWnd, CB_GETLOCALE, 0, 0L); + } + + LCID SetLocale(LCID nNewLocale) + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + return (LCID)::SendMessage(this->m_hWnd, CB_SETLOCALE, (WPARAM)nNewLocale, 0L); + } + + int GetTopIndex() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_GETTOPINDEX, 0, 0L); + } + + int SetTopIndex(int nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_SETTOPINDEX, nIndex, 0L); + } + + UINT GetHorizontalExtent() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, CB_GETHORIZONTALEXTENT, 0, 0L); + } + + void SetHorizontalExtent(UINT nExtent) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, CB_SETHORIZONTALEXTENT, nExtent, 0L); + } + + int GetDroppedWidth() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_GETDROPPEDWIDTH, 0, 0L); + } + + int SetDroppedWidth(UINT nWidth) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_SETDROPPEDWIDTH, nWidth, 0L); + } + + BOOL GetComboBoxInfo(PCOMBOBOXINFO pComboBoxInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CB_GETCOMBOBOXINFO, 0, (LPARAM)pComboBoxInfo); + } + + // for edit control + DWORD GetEditSel() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, CB_GETEDITSEL, 0, 0L); + } + + BOOL SetEditSel(int nStartChar, int nEndChar) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CB_SETEDITSEL, 0, MAKELONG(nStartChar, nEndChar)); + } + + // for combobox item + DWORD_PTR GetItemData(int nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD_PTR)::SendMessage(this->m_hWnd, CB_GETITEMDATA, nIndex, 0L); + } + + int SetItemData(int nIndex, DWORD_PTR dwItemData) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_SETITEMDATA, nIndex, (LPARAM)dwItemData); + } + + void* GetItemDataPtr(int 
nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (void*)GetItemData(nIndex); + } + + int SetItemDataPtr(int nIndex, void* pData) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return SetItemData(nIndex, (DWORD_PTR)pData); + } + + int GetLBText(int nIndex, LPTSTR lpszText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_GETLBTEXT, nIndex, (LPARAM)lpszText); + } + + BOOL GetLBTextBSTR(int nIndex, BSTR& bstrText) const + { + USES_CONVERSION; + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(bstrText == NULL); + + int nLen = GetLBTextLen(nIndex); + if(nLen == CB_ERR) + return FALSE; + + ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff; + LPTSTR lpstrText = buff.Allocate(nLen + 1); + if(lpstrText == NULL) + return FALSE; + + if(GetLBText(nIndex, lpstrText) == CB_ERR) + return FALSE; + + bstrText = ::SysAllocString(T2OLE(lpstrText)); + return (bstrText != NULL) ? TRUE : FALSE; + } + +#ifdef __ATLSTR_H__ + int GetLBText(int nIndex, ATL::CString& strText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + int cchLen = GetLBTextLen(nIndex); + if(cchLen == CB_ERR) + return CB_ERR; + int nRet = CB_ERR; + LPTSTR lpstr = strText.GetBufferSetLength(cchLen); + if(lpstr != NULL) + { + nRet = GetLBText(nIndex, lpstr); + strText.ReleaseBuffer(); + } + return nRet; + } +#endif // __ATLSTR_H__ + + int GetLBTextLen(int nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_GETLBTEXTLEN, nIndex, 0L); + } + + int GetItemHeight(int nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_GETITEMHEIGHT, nIndex, 0L); + } + + int SetItemHeight(int nIndex, UINT cyItemHeight) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_SETITEMHEIGHT, nIndex, MAKELONG(cyItemHeight, 0)); + } + + BOOL GetExtendedUI() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, 
CB_GETEXTENDEDUI, 0, 0L); + } + + int SetExtendedUI(BOOL bExtended = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_SETEXTENDEDUI, bExtended, 0L); + } + + void GetDroppedControlRect(LPRECT lprect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, CB_GETDROPPEDCONTROLRECT, 0, (LPARAM)lprect); + } + + BOOL GetDroppedState() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CB_GETDROPPEDSTATE, 0, 0L); + } + + int GetMinVisible() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_GETMINVISIBLE, 0, 0L); + } + + BOOL SetMinVisible(int nMinVisible) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CB_SETMINVISIBLE, nMinVisible, 0L); + } + + // Vista only + BOOL GetCueBannerText(LPWSTR lpwText, int cchText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CB_GETCUEBANNER, (WPARAM)lpwText, cchText); + } + + // Vista only + BOOL SetCueBannerText(LPCWSTR lpcwText) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CB_SETCUEBANNER, 0, (LPARAM)lpcwText); + } + +// Operations + int InitStorage(int nItems, UINT nBytes) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_INITSTORAGE, (WPARAM)nItems, nBytes); + } + + void ResetContent() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, CB_RESETCONTENT, 0, 0L); + } + + // for edit control + BOOL LimitText(int nMaxChars) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CB_LIMITTEXT, nMaxChars, 0L); + } + + // for drop-down combo boxes + void ShowDropDown(BOOL bShowIt = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, CB_SHOWDROPDOWN, bShowIt, 0L); + } + + // manipulating listbox items + int AddString(LPCTSTR lpszString) + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_ADDSTRING, 0, (LPARAM)lpszString); + } + + int DeleteString(UINT nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_DELETESTRING, nIndex, 0L); + } + + int InsertString(int nIndex, LPCTSTR lpszString) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_INSERTSTRING, nIndex, (LPARAM)lpszString); + } + + int Dir(UINT attr, LPCTSTR lpszWildCard) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_DIR, attr, (LPARAM)lpszWildCard); + } + + // selection helpers + int FindString(int nStartAfter, LPCTSTR lpszString) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_FINDSTRING, nStartAfter, (LPARAM)lpszString); + } + + int FindStringExact(int nIndexStart, LPCTSTR lpszFind) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_FINDSTRINGEXACT, nIndexStart, (LPARAM)lpszFind); + } + + int SelectString(int nStartAfter, LPCTSTR lpszString) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CB_SELECTSTRING, nStartAfter, (LPARAM)lpszString); + } + + // Clipboard operations + void Clear() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_CLEAR, 0, 0L); + } + + void Copy() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_COPY, 0, 0L); + } + + void Cut() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_CUT, 0, 0L); + } + + void Paste() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_PASTE, 0, 0L); + } +}; + +typedef CComboBoxT<ATL::CWindow> CComboBox; + + +/////////////////////////////////////////////////////////////////////////////// +// CEdit - client side for a Windows EDIT control + +template <class TBase> +class CEditT : public TBase +{ +public: +// Constructors 
+ CEditT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CEditT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return _T("EDIT"); + } + + BOOL CanUndo() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_CANUNDO, 0, 0L); + } + + int GetLineCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_GETLINECOUNT, 0, 0L); + } + + BOOL GetModify() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_GETMODIFY, 0, 0L); + } + + void SetModify(BOOL bModified = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETMODIFY, bModified, 0L); + } + + void GetRect(LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_GETRECT, 0, (LPARAM)lpRect); + } + + DWORD GetSel() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETSEL, 0, 0L); + } + + void GetSel(int& nStartChar, int& nEndChar) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_GETSEL, (WPARAM)&nStartChar, (LPARAM)&nEndChar); + } + + HLOCAL GetHandle() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HLOCAL)::SendMessage(this->m_hWnd, EM_GETHANDLE, 0, 0L); + } + + void SetHandle(HLOCAL hBuffer) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETHANDLE, (WPARAM)hBuffer, 0L); + } + + DWORD GetMargins() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, 
EM_GETMARGINS, 0, 0L); + } + + void GetMargins(UINT& nLeft, UINT& nRight) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, EM_GETMARGINS, 0, 0L); + nLeft = LOWORD(dwRet); + nRight = HIWORD(dwRet); + } + + void SetMargins(UINT nLeft, UINT nRight, WORD wFlags = EC_LEFTMARGIN | EC_RIGHTMARGIN) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETMARGINS, wFlags, MAKELONG(nLeft, nRight)); + } + + UINT GetLimitText() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, EM_GETLIMITTEXT, 0, 0L); + } + + void SetLimitText(UINT nMax) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETLIMITTEXT, nMax, 0L); + } + + POINT PosFromChar(UINT nChar) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, EM_POSFROMCHAR, nChar, 0); + POINT point = { GET_X_LPARAM(dwRet), GET_Y_LPARAM(dwRet) }; + return point; + } + + int CharFromPos(POINT pt, int* pLine = NULL) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, EM_CHARFROMPOS, 0, MAKELPARAM(pt.x, pt.y)); + if(pLine != NULL) + *pLine = (int)(short)HIWORD(dwRet); + return (int)(short)LOWORD(dwRet); + } + + // NOTE: first word in lpszBuffer must contain the size of the buffer! 
+ int GetLine(int nIndex, LPTSTR lpszBuffer) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_GETLINE, nIndex, (LPARAM)lpszBuffer); + } + + int GetLine(int nIndex, LPTSTR lpszBuffer, int nMaxLength) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + *(LPWORD)lpszBuffer = (WORD)nMaxLength; + return (int)::SendMessage(this->m_hWnd, EM_GETLINE, nIndex, (LPARAM)lpszBuffer); + } + + TCHAR GetPasswordChar() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (TCHAR)::SendMessage(this->m_hWnd, EM_GETPASSWORDCHAR, 0, 0L); + } + + void SetPasswordChar(TCHAR ch) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETPASSWORDCHAR, ch, 0L); + } + + EDITWORDBREAKPROC GetWordBreakProc() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (EDITWORDBREAKPROC)::SendMessage(this->m_hWnd, EM_GETWORDBREAKPROC, 0, 0L); + } + + void SetWordBreakProc(EDITWORDBREAKPROC ewbprc) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETWORDBREAKPROC, 0, (LPARAM)ewbprc); + } + + int GetFirstVisibleLine() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_GETFIRSTVISIBLELINE, 0, 0L); + } + + int GetThumb() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & ES_MULTILINE) != 0); + return (int)::SendMessage(this->m_hWnd, EM_GETTHUMB, 0, 0L); + } + + BOOL SetReadOnly(BOOL bReadOnly = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETREADONLY, bReadOnly, 0L); + } + + UINT GetImeStatus(UINT uStatus) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, EM_GETIMESTATUS, uStatus, 0L); + } + + UINT SetImeStatus(UINT uStatus, UINT uData) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, EM_SETIMESTATUS, uStatus, uData); + } + + BOOL GetCueBannerText(LPCWSTR lpstrText, int cchText) const + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_GETCUEBANNER, (WPARAM)lpstrText, cchText); + } + + // bKeepWithFocus - Vista only + BOOL SetCueBannerText(LPCWSTR lpstrText, BOOL bKeepWithFocus = FALSE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETCUEBANNER, (WPARAM)bKeepWithFocus, (LPARAM)(lpstrText)); + } + +// Operations + void EmptyUndoBuffer() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_EMPTYUNDOBUFFER, 0, 0L); + } + + BOOL FmtLines(BOOL bAddEOL) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_FMTLINES, bAddEOL, 0L); + } + + void LimitText(int nChars = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_LIMITTEXT, nChars, 0L); + } + + int LineFromChar(int nIndex = -1) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_LINEFROMCHAR, nIndex, 0L); + } + + int LineIndex(int nLine = -1) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_LINEINDEX, nLine, 0L); + } + + int LineLength(int nLine = -1) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_LINELENGTH, nLine, 0L); + } + + void LineScroll(int nLines, int nChars = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_LINESCROLL, nChars, nLines); + } + + void ReplaceSel(LPCTSTR lpszNewText, BOOL bCanUndo = FALSE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_REPLACESEL, (WPARAM) bCanUndo, (LPARAM)lpszNewText); + } + + void SetRect(LPCRECT lpRect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETRECT, 0, (LPARAM)lpRect); + } + + void SetRectNP(LPCRECT lpRect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETRECTNP, 0, (LPARAM)lpRect); + } + + void SetSel(DWORD dwSelection, BOOL bNoScroll = 
FALSE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETSEL, LOWORD(dwSelection), HIWORD(dwSelection)); + if(!bNoScroll) + ::SendMessage(this->m_hWnd, EM_SCROLLCARET, 0, 0L); + } + + void SetSel(int nStartChar, int nEndChar, BOOL bNoScroll = FALSE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETSEL, nStartChar, nEndChar); + if(!bNoScroll) + ::SendMessage(this->m_hWnd, EM_SCROLLCARET, 0, 0L); + } + + void SetSelAll(BOOL bNoScroll = FALSE) + { + SetSel(0, -1, bNoScroll); + } + + void SetSelNone(BOOL bNoScroll = FALSE) + { + SetSel(-1, 0, bNoScroll); + } + + BOOL SetTabStops(int nTabStops, LPINT rgTabStops) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, nTabStops, (LPARAM)rgTabStops); + } + + BOOL SetTabStops() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, 0, 0L); + } + + BOOL SetTabStops(const int& cxEachStop) // takes an 'int' + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, 1, (LPARAM)(LPINT)&cxEachStop); + } + + void ScrollCaret() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SCROLLCARET, 0, 0L); + } + + int Scroll(int nScrollAction) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & ES_MULTILINE) != 0); + LRESULT lRet = ::SendMessage(this->m_hWnd, EM_SCROLL, nScrollAction, 0L); + if(!(BOOL)HIWORD(lRet)) + return -1; // failed + return (int)(short)LOWORD(lRet); + + } + + void InsertText(int nInsertAfterChar, LPCTSTR lpstrText, BOOL bNoScroll = FALSE, BOOL bCanUndo = FALSE) + { + SetSel(nInsertAfterChar, nInsertAfterChar, bNoScroll); + ReplaceSel(lpstrText, bCanUndo); + } + + void AppendText(LPCTSTR lpstrText, BOOL bNoScroll = FALSE, BOOL bCanUndo = FALSE) + { + InsertText(this->GetWindowTextLength(), lpstrText, bNoScroll, bCanUndo); + } + + BOOL ShowBalloonTip(PEDITBALLOONTIP 
pEditBaloonTip) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SHOWBALLOONTIP, 0, (LPARAM)pEditBaloonTip); + } + + BOOL HideBalloonTip() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_HIDEBALLOONTIP, 0, 0L); + } + +#if (_WIN32_WINNT >= 0x0600) + DWORD GetHilite() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETHILITE, 0, 0L); + } + + void GetHilite(int& nStartChar, int& nEndChar) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, EM_GETHILITE, 0, 0L); + nStartChar = (int)(short)LOWORD(dwRet); + nEndChar = (int)(short)HIWORD(dwRet); + } + + void SetHilite(int nStartChar, int nEndChar) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETHILITE, nStartChar, nEndChar); + } +#endif // (_WIN32_WINNT >= 0x0600) + + // Clipboard operations + BOOL Undo() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_UNDO, 0, 0L); + } + + void Clear() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_CLEAR, 0, 0L); + } + + void Copy() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_COPY, 0, 0L); + } + + void Cut() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_CUT, 0, 0L); + } + + void Paste() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_PASTE, 0, 0L); + } + + // New messages added in Windows 10.0.17763 +#if defined(NTDDI_VERSION) && defined(NTDDI_WIN10_RS5) && (NTDDI_VERSION >= NTDDI_WIN10_RS5) + DWORD SetExtendedStyle(DWORD dwStyle, DWORD dwMask) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::SendMessage(this->m_hWnd, EM_SETEXTENDEDSTYLE, dwMask, dwStyle); + } + + DWORD GetExtendedStyle() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::SendMessage(this->m_hWnd, EM_GETEXTENDEDSTYLE, 0, 0L); + } + + BOOL 
SetEndOfLine(EC_ENDOFLINE eolType) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETENDOFLINE, eolType, 0L); + } + + EC_ENDOFLINE GetEndOfLine() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (EC_ENDOFLINE)::SendMessage(this->m_hWnd, EM_GETENDOFLINE, 0, 0L); + } + + BOOL EnableSearchWeb(BOOL bEnable) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_ENABLESEARCHWEB, (WPARAM)bEnable, 0L); + } + + void SearchWeb() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SEARCHWEB, 0, 0L); + } + + BOOL SetCaretIndex(DWORD dwCaretIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETCARETINDEX, dwCaretIndex, 0L); + } + + DWORD GetCaretIndex() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::SendMessage(this->m_hWnd, EM_GETCARETINDEX, 0, 0L); + } + + BOOL GetZoom(int& nNum, int& nDen) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_GETZOOM, (WPARAM)&nNum, (LPARAM)&nDen); + } + + BOOL SetZoom(int nNum, int nDen) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((nNum >= 0) && (nNum <= 64)); + ATLASSERT((nDen >= 0) && (nDen <= 64)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETZOOM, nNum, nDen); + } + + DWORD GetFileLineFromChar(DWORD dwCharIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::SendMessage(this->m_hWnd, EM_FILELINEFROMCHAR, dwCharIndex, 0L); + } + + DWORD GetFileLineIndex(DWORD dwLineNum) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::SendMessage(this->m_hWnd, EM_FILELINEINDEX, dwLineNum, 0L); + } + + DWORD GetFileLineLength(DWORD dwCharIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::SendMessage(this->m_hWnd, EM_FILELINELENGTH, dwCharIndex, 0L); + } + + DWORD GetFileLine(DWORD dwLineNum, LPTSTR lpstrLine, WORD wLen) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + WORD* pw = 
(WORD*)lpstrLine; + *pw = wLen; + return ::SendMessage(this->m_hWnd, EM_GETFILELINE, dwLineNum, (LPARAM)lpstrLine); + } + +#ifdef __ATLSTR_H__ + ATL::CString GetFileLine(DWORD dwLineNum) const + { + ATL::CString strLine; + DWORD dwCharIndex = GetFileLineIndex(dwLineNum); + if(dwCharIndex != (DWORD)-1) + { + DWORD dwLen = GetFileLineLength(dwCharIndex); + if(dwLen > 0) + { + LPTSTR lpstrLine = strLine.GetBufferSetLength(dwLen); + ATLVERIFY(GetFileLine(dwLineNum, lpstrLine, (WORD)dwLen) == dwLen); + strLine.ReleaseBuffer(); + } + } + + return strLine; + } +#endif // __ATLSTR_H__ + + DWORD GetFileLineCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::SendMessage(this->m_hWnd, EM_GETFILELINECOUNT, 0, 0L); + } +#endif // defined(NTDDI_VERSION) && defined(NTDDI_WIN10_RS5) && (NTDDI_VERSION >= NTDDI_WIN10_RS5) +}; + +typedef CEditT<ATL::CWindow> CEdit; + + +/////////////////////////////////////////////////////////////////////////////// +// CEditCommands - message handlers for standard EDIT commands + +// Chain to CEditCommands message map. Your class must also derive from CEdit. +// Example: +// class CMyEdit : public CWindowImpl<CMyEdit, CEdit>, +// public CEditCommands<CMyEdit> +// { +// public: +// BEGIN_MSG_MAP(CMyEdit) +// // your handlers... +// CHAIN_MSG_MAP_ALT(CEditCommands<CMyEdit>, 1) +// END_MSG_MAP() +// // other stuff... 
+// }; + +template <class T> +class CEditCommands +{ +public: + BEGIN_MSG_MAP(CEditCommands< T >) + ALT_MSG_MAP(1) + COMMAND_ID_HANDLER(ID_EDIT_CLEAR, OnEditClear) + COMMAND_ID_HANDLER(ID_EDIT_CLEAR_ALL, OnEditClearAll) + COMMAND_ID_HANDLER(ID_EDIT_COPY, OnEditCopy) + COMMAND_ID_HANDLER(ID_EDIT_CUT, OnEditCut) + COMMAND_ID_HANDLER(ID_EDIT_PASTE, OnEditPaste) + COMMAND_ID_HANDLER(ID_EDIT_SELECT_ALL, OnEditSelectAll) + COMMAND_ID_HANDLER(ID_EDIT_UNDO, OnEditUndo) + END_MSG_MAP() + + LRESULT OnEditClear(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/) + { + T* pT = static_cast<T*>(this); + pT->Clear(); + return 0; + } + + LRESULT OnEditClearAll(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/) + { + T* pT = static_cast<T*>(this); + pT->SetSel(0, -1); + pT->Clear(); + return 0; + } + + LRESULT OnEditCopy(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/) + { + T* pT = static_cast<T*>(this); + pT->Copy(); + return 0; + } + + LRESULT OnEditCut(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/) + { + T* pT = static_cast<T*>(this); + pT->Cut(); + return 0; + } + + LRESULT OnEditPaste(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/) + { + T* pT = static_cast<T*>(this); + pT->Paste(); + return 0; + } + + LRESULT OnEditSelectAll(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/) + { + T* pT = static_cast<T*>(this); + pT->SetSel(0, -1); + return 0; + } + + LRESULT OnEditUndo(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/) + { + T* pT = static_cast<T*>(this); + pT->Undo(); + return 0; + } + +// State (update UI) helpers + BOOL CanCut() const + { return HasSelection(); } + + BOOL CanCopy() const + { return HasSelection(); } + + BOOL CanClear() const + { return HasSelection(); } + + BOOL CanSelectAll() const + { return HasText(); } + + BOOL CanFind() const + { return HasText(); } + + BOOL CanRepeat() const + { 
return HasText(); } + + BOOL CanReplace() const + { return HasText(); } + + BOOL CanClearAll() const + { return HasText(); } + +// Implementation + BOOL HasSelection() const + { + const T* pT = static_cast<const T*>(this); + int nMin = 0, nMax = 0; + ::SendMessage(pT->m_hWnd, EM_GETSEL, (WPARAM)&nMin, (LPARAM)&nMax); + return (nMin != nMax); + } + + BOOL HasText() const + { + const T* pT = static_cast<const T*>(this); + return (pT->GetWindowTextLength() > 0); + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CScrollBar - client side for a Windows SCROLLBAR control + +template <class TBase> +class CScrollBarT : public TBase +{ +public: +// Constructors + CScrollBarT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CScrollBarT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return _T("SCROLLBAR"); + } + + int GetScrollPos() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::GetScrollPos(this->m_hWnd, SB_CTL); + } + + int SetScrollPos(int nPos, BOOL bRedraw = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::SetScrollPos(this->m_hWnd, SB_CTL, nPos, bRedraw); + } + + void GetScrollRange(LPINT lpMinPos, LPINT lpMaxPos) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::GetScrollRange(this->m_hWnd, SB_CTL, lpMinPos, lpMaxPos); + } + + void SetScrollRange(int nMinPos, int nMaxPos, BOOL bRedraw = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SetScrollRange(this->m_hWnd, SB_CTL, nMinPos, nMaxPos, bRedraw); + } + + BOOL GetScrollInfo(LPSCROLLINFO lpScrollInfo) const + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + return ::GetScrollInfo(this->m_hWnd, SB_CTL, lpScrollInfo); + } + + int SetScrollInfo(LPSCROLLINFO lpScrollInfo, BOOL bRedraw = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::SetScrollInfo(this->m_hWnd, SB_CTL, lpScrollInfo, bRedraw); + } + + // Returns nMax, reduced by (nPage - 1) when the page size is greater than 1 + int GetScrollLimit() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + SCROLLINFO info = { sizeof(SCROLLINFO), SIF_RANGE | SIF_PAGE }; + ::GetScrollInfo(this->m_hWnd, SB_CTL, &info); + if(info.nPage > 1) + info.nMax -= info.nPage - 1; + + return info.nMax; + } + + BOOL GetScrollBarInfo(PSCROLLBARINFO pScrollBarInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, SBM_GETSCROLLBARINFO, 0, (LPARAM)pScrollBarInfo); + } + +// Operations + void ShowScrollBar(BOOL bShow = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::ShowScrollBar(this->m_hWnd, SB_CTL, bShow); + } + + BOOL EnableScrollBar(UINT nArrowFlags = ESB_ENABLE_BOTH) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::EnableScrollBar(this->m_hWnd, SB_CTL, nArrowFlags); + } +}; + +typedef CScrollBarT<ATL::CWindow> CScrollBar; + + +// --- Windows Common Controls --- + +/////////////////////////////////////////////////////////////////////////////// +// CImageList + +// forward declarations +template <bool t_bManaged> class CImageListT; +typedef CImageListT<false> CImageList; +typedef CImageListT<true> CImageListManaged; + + +template <bool t_bManaged> +class CImageListT +{ +public: +// Data members + HIMAGELIST m_hImageList; + +// Constructor/destructor/operators + CImageListT(HIMAGELIST hImageList = NULL) : m_hImageList(hImageList) + { } + + ~CImageListT() + { + if(t_bManaged && (m_hImageList != NULL)) + Destroy(); + } + + CImageListT<t_bManaged>& operator =(HIMAGELIST hImageList) + { + Attach(hImageList); + return *this; + } + + void Attach(HIMAGELIST hImageList) + { + if(t_bManaged && (m_hImageList != NULL) && (m_hImageList != hImageList)) + ImageList_Destroy(m_hImageList); + m_hImageList = hImageList; + } + + 
HIMAGELIST Detach() + { + HIMAGELIST hImageList = m_hImageList; + m_hImageList = NULL; + return hImageList; + } + + operator HIMAGELIST() const { return m_hImageList; } + + bool IsNull() const { return (m_hImageList == NULL); } + +// Attributes + int GetImageCount() const + { + ATLASSERT(m_hImageList != NULL); + return ImageList_GetImageCount(m_hImageList); + } + + COLORREF GetBkColor() const + { + ATLASSERT(m_hImageList != NULL); + return ImageList_GetBkColor(m_hImageList); + } + + COLORREF SetBkColor(COLORREF cr) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_SetBkColor(m_hImageList, cr); + } + + BOOL GetImageInfo(int nImage, IMAGEINFO* pImageInfo) const + { + ATLASSERT(m_hImageList != NULL); + return ImageList_GetImageInfo(m_hImageList, nImage, pImageInfo); + } + + HICON GetIcon(int nIndex, UINT uFlags = ILD_NORMAL) const + { + ATLASSERT(m_hImageList != NULL); + return ImageList_GetIcon(m_hImageList, nIndex, uFlags); + } + + BOOL GetIconSize(int& cx, int& cy) const + { + ATLASSERT(m_hImageList != NULL); + return ImageList_GetIconSize(m_hImageList, &cx, &cy); + } + + BOOL GetIconSize(SIZE& size) const + { + ATLASSERT(m_hImageList != NULL); + return ImageList_GetIconSize(m_hImageList, (int*)&size.cx, (int*)&size.cy); + } + + BOOL SetIconSize(int cx, int cy) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_SetIconSize(m_hImageList, cx, cy); + } + + BOOL SetIconSize(SIZE size) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_SetIconSize(m_hImageList, size.cx, size.cy); + } + + BOOL SetImageCount(UINT uNewCount) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_SetImageCount(m_hImageList, uNewCount); + } + + BOOL SetOverlayImage(int nImage, int nOverlay) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_SetOverlayImage(m_hImageList, nImage, nOverlay); + } + +// Operations + BOOL Create(int cx, int cy, UINT nFlags, int nInitial, int nGrow) + { + ATLASSERT(m_hImageList == NULL); + m_hImageList = ImageList_Create(cx, cy, 
nFlags, nInitial, nGrow); + return (m_hImageList != NULL) ? TRUE : FALSE; + } + + BOOL Create(ATL::_U_STRINGorID bitmap, int cx, int nGrow, COLORREF crMask) + { + ATLASSERT(m_hImageList == NULL); + m_hImageList = ImageList_LoadBitmap(ModuleHelper::GetResourceInstance(), bitmap.m_lpstr, cx, nGrow, crMask); + return (m_hImageList != NULL) ? TRUE : FALSE; + } + + BOOL CreateFromImage(ATL::_U_STRINGorID image, int cx, int nGrow, COLORREF crMask, UINT uType, UINT uFlags = LR_DEFAULTCOLOR | LR_DEFAULTSIZE) + { + ATLASSERT(m_hImageList == NULL); + m_hImageList = ImageList_LoadImage(ModuleHelper::GetResourceInstance(), image.m_lpstr, cx, nGrow, crMask, uType, uFlags); + return (m_hImageList != NULL) ? TRUE : FALSE; + } + + BOOL Merge(HIMAGELIST hImageList1, int nImage1, HIMAGELIST hImageList2, int nImage2, int dx, int dy) + { + ATLASSERT(m_hImageList == NULL); + m_hImageList = ImageList_Merge(hImageList1, nImage1, hImageList2, nImage2, dx, dy); + return (m_hImageList != NULL) ? TRUE : FALSE; + } + +#ifdef __IStream_INTERFACE_DEFINED__ + BOOL CreateFromStream(LPSTREAM lpStream) + { + ATLASSERT(m_hImageList == NULL); + m_hImageList = ImageList_Read(lpStream); + return (m_hImageList != NULL) ? 
TRUE : FALSE; + } +#endif // __IStream_INTERFACE_DEFINED__ + + BOOL Destroy() + { + if (m_hImageList == NULL) + return FALSE; + BOOL bRet = ImageList_Destroy(m_hImageList); + if(bRet) + m_hImageList = NULL; + return bRet; + } + + int Add(HBITMAP hBitmap, HBITMAP hBitmapMask = NULL) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_Add(m_hImageList, hBitmap, hBitmapMask); + } + + int Add(HBITMAP hBitmap, COLORREF crMask) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_AddMasked(m_hImageList, hBitmap, crMask); + } + + BOOL Remove(int nImage) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_Remove(m_hImageList, nImage); + } + + BOOL RemoveAll() + { + ATLASSERT(m_hImageList != NULL); + return ImageList_RemoveAll(m_hImageList); + } + + BOOL Replace(int nImage, HBITMAP hBitmap, HBITMAP hBitmapMask) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_Replace(m_hImageList, nImage, hBitmap, hBitmapMask); + } + + int AddIcon(HICON hIcon) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_AddIcon(m_hImageList, hIcon); + } + + int ReplaceIcon(int nImage, HICON hIcon) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_ReplaceIcon(m_hImageList, nImage, hIcon); + } + + HICON ExtractIcon(int nImage) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_ExtractIcon(NULL, m_hImageList, nImage); + } + + BOOL Draw(HDC hDC, int nImage, int x, int y, UINT nStyle) + { + ATLASSERT(m_hImageList != NULL); + ATLASSERT(hDC != NULL); + return ImageList_Draw(m_hImageList, nImage, hDC, x, y, nStyle); + } + + BOOL Draw(HDC hDC, int nImage, POINT pt, UINT nStyle) + { + ATLASSERT(m_hImageList != NULL); + ATLASSERT(hDC != NULL); + return ImageList_Draw(m_hImageList, nImage, hDC, pt.x, pt.y, nStyle); + } + + BOOL DrawEx(int nImage, HDC hDC, int x, int y, int dx, int dy, COLORREF rgbBk, COLORREF rgbFg, UINT fStyle) + { + ATLASSERT(m_hImageList != NULL); + ATLASSERT(hDC != NULL); + return ImageList_DrawEx(m_hImageList, nImage, hDC, x, y, dx, dy, rgbBk, 
rgbFg, fStyle); + } + + BOOL DrawEx(int nImage, HDC hDC, RECT& rect, COLORREF rgbBk, COLORREF rgbFg, UINT fStyle) + { + ATLASSERT(m_hImageList != NULL); + ATLASSERT(hDC != NULL); + return ImageList_DrawEx(m_hImageList, nImage, hDC, rect.left, rect.top, rect.right - rect.left, rect.bottom - rect.top, rgbBk, rgbFg, fStyle); + } + + static BOOL DrawIndirect(IMAGELISTDRAWPARAMS* pimldp) + { + return ImageList_DrawIndirect(pimldp); + } + + BOOL Copy(int nSrc, int nDst, UINT uFlags = ILCF_MOVE) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_Copy(m_hImageList, nDst, m_hImageList, nSrc, uFlags); + } + +#ifdef __IStream_INTERFACE_DEFINED__ + static HIMAGELIST Read(LPSTREAM lpStream) + { + return ImageList_Read(lpStream); + } + + BOOL Write(LPSTREAM lpStream) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_Write(m_hImageList, lpStream); + } + + static HRESULT ReadEx(DWORD dwFlags, LPSTREAM lpStream, REFIID riid, PVOID* ppv) + { + return ImageList_ReadEx(dwFlags, lpStream, riid, ppv); + } + + HRESULT WriteEx(DWORD dwFlags, LPSTREAM lpStream) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_WriteEx(m_hImageList, dwFlags, lpStream); + } +#endif // __IStream_INTERFACE_DEFINED__ + + // Drag operations + BOOL BeginDrag(int nImage, POINT ptHotSpot) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_BeginDrag(m_hImageList, nImage, ptHotSpot.x, ptHotSpot.y); + } + + BOOL BeginDrag(int nImage, int xHotSpot, int yHotSpot) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_BeginDrag(m_hImageList, nImage, xHotSpot, yHotSpot); + } + + static void EndDrag() + { + ImageList_EndDrag(); + } + + static BOOL DragMove(POINT pt) + { + return ImageList_DragMove(pt.x, pt.y); + } + + static BOOL DragMove(int x, int y) + { + return ImageList_DragMove(x, y); + } + + BOOL SetDragCursorImage(int nDrag, POINT ptHotSpot) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_SetDragCursorImage(m_hImageList, nDrag, ptHotSpot.x, ptHotSpot.y); + } + + BOOL 
SetDragCursorImage(int nDrag, int xHotSpot, int yHotSpot) + { + ATLASSERT(m_hImageList != NULL); + return ImageList_SetDragCursorImage(m_hImageList, nDrag, xHotSpot, yHotSpot); + } + + static BOOL DragShowNolock(BOOL bShow = TRUE) + { + return ImageList_DragShowNolock(bShow); + } + + static CImageList GetDragImage(LPPOINT lpPoint, LPPOINT lpPointHotSpot) + { + return CImageList(ImageList_GetDragImage(lpPoint, lpPointHotSpot)); + } + + static BOOL DragEnter(HWND hWnd, POINT point) + { + return ImageList_DragEnter(hWnd, point.x, point.y); + } + + static BOOL DragEnter(HWND hWnd, int x, int y) + { + return ImageList_DragEnter(hWnd, x, y); + } + + static BOOL DragLeave(HWND hWnd) + { + return ImageList_DragLeave(hWnd); + } + + CImageList Duplicate() const + { + ATLASSERT(m_hImageList != NULL); + return CImageList(ImageList_Duplicate(m_hImageList)); + } + + static CImageList Duplicate(HIMAGELIST hImageList) + { + ATLASSERT(hImageList != NULL); + return CImageList(ImageList_Duplicate(hImageList)); + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CToolTipCtrl + +class CToolInfo : public TOOLINFO +{ +public: + CToolInfo(UINT nFlags, HWND hWnd, UINT_PTR nIDTool = 0, LPRECT lpRect = NULL, LPTSTR lpstrText = LPSTR_TEXTCALLBACK, LPARAM lUserParam = NULL) + { + Init(nFlags, hWnd, nIDTool, lpRect, lpstrText, lUserParam); + } + + operator LPTOOLINFO() { return this; } + + operator LPARAM() { return (LPARAM)this; } + + void Init(UINT nFlags, HWND hWnd, UINT_PTR nIDTool = 0, LPRECT lpRect = NULL, LPTSTR lpstrText = LPSTR_TEXTCALLBACK, LPARAM lUserParam = NULL) + { + ATLASSERT(::IsWindow(hWnd)); + memset(this, 0, sizeof(TOOLINFO)); + cbSize = RunTimeHelper::SizeOf_TOOLINFO(); + uFlags = nFlags; + if(nIDTool == 0) + { + hwnd = ::GetParent(hWnd); + uFlags |= TTF_IDISHWND; + uId = (UINT_PTR)hWnd; + } + else + { + hwnd = hWnd; + uId = nIDTool; + } + if(lpRect != NULL) + rect = *lpRect; + hinst = ModuleHelper::GetResourceInstance(); + 
lpszText = lpstrText; + lParam = lUserParam; + } +}; + +template <class TBase> +class CToolTipCtrlT : public TBase +{ +public: +// Constructors + CToolTipCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CToolTipCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return TOOLTIPS_CLASS; + } + + // Sends TTM_GETTEXT for the tool described by the caller-filled TOOLINFO. + // The lParam must be the TOOLINFO pointer itself (as in GetToolInfo/SetToolInfo), + // not the address of the local pointer variable. + void GetText(LPTOOLINFO lpToolInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_GETTEXT, 0, (LPARAM)lpToolInfo); + } + + void GetText(LPTSTR lpstrText, HWND hWnd, UINT_PTR nIDTool = 0) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(hWnd != NULL); + CToolInfo ti(0, hWnd, nIDTool, NULL, lpstrText); + ::SendMessage(this->m_hWnd, TTM_GETTEXT, 0, ti); + } + + BOOL GetToolInfo(LPTOOLINFO lpToolInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TTM_GETTOOLINFO, 0, (LPARAM)lpToolInfo); + } + + BOOL GetToolInfo(HWND hWnd, UINT_PTR nIDTool, UINT* puFlags, LPRECT lpRect, LPTSTR lpstrText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(hWnd != NULL); + ATLASSERT(puFlags != NULL); + ATLASSERT(lpRect != NULL); + CToolInfo ti(0, hWnd, nIDTool, NULL, lpstrText); + BOOL bRet = (BOOL)::SendMessage(this->m_hWnd, TTM_GETTOOLINFO, 0, ti); + if(bRet != FALSE) + { + *puFlags = ti.uFlags; + *lpRect = ti.rect; + } + return bRet; + } + + void SetToolInfo(LPTOOLINFO lpToolInfo) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_SETTOOLINFO, 0, (LPARAM)lpToolInfo); + } + + void SetToolRect(LPTOOLINFO lpToolInfo) + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_NEWTOOLRECT, 0, (LPARAM)lpToolInfo); + } + + void SetToolRect(HWND hWnd, UINT_PTR nIDTool, LPCRECT lpRect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(hWnd != NULL); + ATLASSERT(nIDTool != 0); + + CToolInfo ti(0, hWnd, nIDTool, (LPRECT)lpRect, NULL); + ::SendMessage(this->m_hWnd, TTM_NEWTOOLRECT, 0, ti); + } + + int GetToolCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TTM_GETTOOLCOUNT, 0, 0L); + } + + int GetDelayTime(DWORD dwType) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TTM_GETDELAYTIME, dwType, 0L); + } + + void SetDelayTime(DWORD dwType, int nTime) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_SETDELAYTIME, dwType, MAKELPARAM(nTime, 0)); + } + + void GetMargin(LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_GETMARGIN, 0, (LPARAM)lpRect); + } + + void SetMargin(LPRECT lpRect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_SETMARGIN, 0, (LPARAM)lpRect); + } + + int GetMaxTipWidth() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TTM_GETMAXTIPWIDTH, 0, 0L); + } + + int SetMaxTipWidth(int nWidth) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TTM_SETMAXTIPWIDTH, 0, nWidth); + } + + COLORREF GetTipBkColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, TTM_GETTIPBKCOLOR, 0, 0L); + } + + void SetTipBkColor(COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_SETTIPBKCOLOR, (WPARAM)clr, 0L); + } + + COLORREF GetTipTextColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, TTM_GETTIPTEXTCOLOR, 0, 0L); + } + + void SetTipTextColor(COLORREF clr) + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_SETTIPTEXTCOLOR, (WPARAM)clr, 0L); + } + + BOOL GetCurrentTool(LPTOOLINFO lpToolInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TTM_GETCURRENTTOOL, 0, (LPARAM)lpToolInfo); + } + + SIZE GetBubbleSize(LPTOOLINFO lpToolInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, TTM_GETBUBBLESIZE, 0, (LPARAM)lpToolInfo); + SIZE size = { GET_X_LPARAM(dwRet), GET_Y_LPARAM(dwRet) }; + return size; + } + + BOOL SetTitle(UINT_PTR uIcon, LPCTSTR lpstrTitle) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TTM_SETTITLE, uIcon, (LPARAM)lpstrTitle); + } + + + BOOL SetTitle(HICON hIcon, LPCTSTR lpstrTitle) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TTM_SETTITLE, (WPARAM)hIcon, (LPARAM)lpstrTitle); + } + + void GetTitle(PTTGETTITLE pTTGetTitle) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_GETTITLE, 0, (LPARAM)pTTGetTitle); + } + + void SetWindowTheme(LPCWSTR lpstrTheme) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_SETWINDOWTHEME, 0, (LPARAM)lpstrTheme); + } + +// Operations + void Activate(BOOL bActivate) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_ACTIVATE, bActivate, 0L); + } + + BOOL AddTool(LPTOOLINFO lpToolInfo) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TTM_ADDTOOL, 0, (LPARAM)lpToolInfo); + } + + BOOL AddTool(HWND hWnd, ATL::_U_STRINGorID text = LPSTR_TEXTCALLBACK, LPCRECT lpRectTool = NULL, UINT_PTR nIDTool = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(hWnd != NULL); + // the toolrect and toolid must both be zero or both valid + ATLASSERT(((lpRectTool != NULL) && (nIDTool != 0)) || ((lpRectTool == NULL) && (nIDTool == 0))); + + CToolInfo ti(0, hWnd, nIDTool, 
(LPRECT)lpRectTool, (LPTSTR)text.m_lpstr); + return (BOOL)::SendMessage(this->m_hWnd, TTM_ADDTOOL, 0, ti); + } + + void DelTool(LPTOOLINFO lpToolInfo) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_DELTOOL, 0, (LPARAM)lpToolInfo); + } + + void DelTool(HWND hWnd, UINT_PTR nIDTool = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(hWnd != NULL); + + CToolInfo ti(0, hWnd, nIDTool, NULL, NULL); + ::SendMessage(this->m_hWnd, TTM_DELTOOL, 0, ti); + } + + BOOL HitTest(LPTTHITTESTINFO lpHitTestInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TTM_HITTEST, 0, (LPARAM)lpHitTestInfo); + } + + BOOL HitTest(HWND hWnd, POINT pt, LPTOOLINFO lpToolInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(hWnd != NULL); + ATLASSERT(lpToolInfo != NULL); + + TTHITTESTINFO hti = {}; + hti.ti.cbSize = RunTimeHelper::SizeOf_TOOLINFO(); + hti.hwnd = hWnd; + hti.pt.x = pt.x; + hti.pt.y = pt.y; + if((BOOL)::SendMessage(this->m_hWnd, TTM_HITTEST, 0, (LPARAM)&hti) != FALSE) + { + *lpToolInfo = hti.ti; + return TRUE; + } + return FALSE; + } + + void RelayEvent(LPMSG lpMsg) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_RELAYEVENT, 0, (LPARAM)lpMsg); + } + + void UpdateTipText(LPTOOLINFO lpToolInfo) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_UPDATETIPTEXT, 0, (LPARAM)lpToolInfo); + } + + void UpdateTipText(ATL::_U_STRINGorID text, HWND hWnd, UINT_PTR nIDTool = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(hWnd != NULL); + + CToolInfo ti(0, hWnd, nIDTool, NULL, (LPTSTR)text.m_lpstr); + ::SendMessage(this->m_hWnd, TTM_UPDATETIPTEXT, 0, ti); + } + + BOOL EnumTools(UINT_PTR nTool, LPTOOLINFO lpToolInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TTM_ENUMTOOLS, nTool, (LPARAM)lpToolInfo); + } + + void Pop() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, 
TTM_POP, 0, 0L); + } + + void TrackActivate(LPTOOLINFO lpToolInfo, BOOL bActivate) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_TRACKACTIVATE, bActivate, (LPARAM)lpToolInfo); + } + + void TrackActivate(HWND hWnd, UINT_PTR nIDTool, BOOL bActivate) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(hWnd != NULL); + + CToolInfo ti(0, hWnd, nIDTool); + ::SendMessage(this->m_hWnd, TTM_TRACKACTIVATE, bActivate, ti); + } + + void TrackPosition(int xPos, int yPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_TRACKPOSITION, 0, MAKELPARAM(xPos, yPos)); + } + + void Update() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_UPDATE, 0, 0L); + } + + BOOL AdjustRect(LPRECT lpRect, BOOL bLarger /*= TRUE*/) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TTM_ADJUSTRECT, bLarger, (LPARAM)lpRect); + } + + void Popup() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TTM_POPUP, 0, 0L); + } +}; + +typedef CToolTipCtrlT<ATL::CWindow> CToolTipCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CHeaderCtrl + +template <class TBase> +class CHeaderCtrlT : public TBase +{ +public: +// Constructors + CHeaderCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CHeaderCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return WC_HEADER; + } + + int GetItemCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_GETITEMCOUNT, 0, 0L); + } 
+ + BOOL GetItem(int nIndex, LPHDITEM pHeaderItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_GETITEM, nIndex, (LPARAM)pHeaderItem); + } + + BOOL SetItem(int nIndex, LPHDITEM pHeaderItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_SETITEM, nIndex, (LPARAM)pHeaderItem); + } + + CImageList GetImageList() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, HDM_GETIMAGELIST, 0, 0L)); + } + + CImageList SetImageList(HIMAGELIST hImageList) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, HDM_SETIMAGELIST, 0, (LPARAM)hImageList)); + } + + BOOL GetOrderArray(int nSize, int* lpnArray) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_GETORDERARRAY, nSize, (LPARAM)lpnArray); + } + + BOOL SetOrderArray(int nSize, int* lpnArray) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_SETORDERARRAY, nSize, (LPARAM)lpnArray); + } + + BOOL GetItemRect(int nIndex, LPRECT lpItemRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_GETITEMRECT, nIndex, (LPARAM)lpItemRect); + } + + int SetHotDivider(BOOL bPos, DWORD dwInputValue) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_SETHOTDIVIDER, bPos, dwInputValue); + } + + BOOL GetUnicodeFormat() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_GETUNICODEFORMAT, 0, 0L); + } + + BOOL SetUnicodeFormat(BOOL bUnicode = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_SETUNICODEFORMAT, bUnicode, 0L); + } + + int GetBitmapMargin() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_GETBITMAPMARGIN, 0, 0L); + } + + int 
SetBitmapMargin(int nWidth) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_SETBITMAPMARGIN, nWidth, 0L); + } + + int SetFilterChangeTimeout(DWORD dwTimeOut) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_SETFILTERCHANGETIMEOUT, 0, dwTimeOut); + } + +#if (_WIN32_WINNT >= 0x0600) + BOOL GetItemDropDownRect(int nIndex, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_GETITEMDROPDOWNRECT, nIndex, (LPARAM)lpRect); + } + + BOOL GetOverflowRect(LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_GETOVERFLOWRECT, 0, (LPARAM)lpRect); + } + + int GetFocusedItem() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_GETFOCUSEDITEM, 0, 0L); + } + + BOOL SetFocusedItem(int nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_SETFOCUSEDITEM, 0, nIndex); + } +#endif // (_WIN32_WINNT >= 0x0600) + +// Operations + int InsertItem(int nIndex, LPHDITEM phdi) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_INSERTITEM, nIndex, (LPARAM)phdi); + } + + int AddItem(LPHDITEM phdi) + { + return InsertItem(GetItemCount(), phdi); + } + + BOOL DeleteItem(int nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_DELETEITEM, nIndex, 0L); + } + + BOOL Layout(HD_LAYOUT* pHeaderLayout) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, HDM_LAYOUT, 0, (LPARAM)pHeaderLayout); + } + + int HitTest(LPHDHITTESTINFO lpHitTestInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_HITTEST, 0, (LPARAM)lpHitTestInfo); + } + + int OrderToIndex(int nOrder) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_ORDERTOINDEX, 
nOrder, 0L); + } + + CImageList CreateDragImage(int nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, HDM_CREATEDRAGIMAGE, nIndex, 0L)); + } + + int EditFilter(int nColumn, BOOL bDiscardChanges) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_EDITFILTER, nColumn, MAKELPARAM(bDiscardChanges, 0)); + } + + int ClearFilter(int nColumn) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_CLEARFILTER, nColumn, 0L); + } + + int ClearAllFilters() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, HDM_CLEARFILTER, (WPARAM)-1, 0L); + } +}; + +typedef CHeaderCtrlT<ATL::CWindow> CHeaderCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CListViewCtrl + +template <class TBase> +class CListViewCtrlT : public TBase +{ +public: +// Constructors + CListViewCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CListViewCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return WC_LISTVIEW; + } + + COLORREF GetBkColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, LVM_GETBKCOLOR, 0, 0L); + } + + BOOL SetBkColor(COLORREF cr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETBKCOLOR, 0, cr); + } + + CImageList GetImageList(int nImageListType) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, LVM_GETIMAGELIST, 
nImageListType, 0L)); + } + + CImageList SetImageList(HIMAGELIST hImageList, int nImageList) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, LVM_SETIMAGELIST, nImageList, (LPARAM)hImageList)); + } + + int GetItemCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETITEMCOUNT, 0, 0L); + } + + BOOL SetItemCount(int nItems) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMCOUNT, nItems, 0L); + } + + BOOL GetItem(LPLVITEM pItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEM, 0, (LPARAM)pItem); + } + + BOOL SetItem(const LVITEM* pItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEM, 0, (LPARAM)pItem); + } + + BOOL SetItem(int nItem, int nSubItem, UINT nMask, LPCTSTR lpszItem, + int nImage, UINT nState, UINT nStateMask, LPARAM lParam) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVITEM lvi = {}; + lvi.mask = nMask; + lvi.iItem = nItem; + lvi.iSubItem = nSubItem; + lvi.stateMask = nStateMask; + lvi.state = nState; + lvi.pszText = (LPTSTR) lpszItem; + lvi.iImage = nImage; + lvi.lParam = lParam; + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEM, 0, (LPARAM)&lvi); + } + + UINT GetItemState(int nItem, UINT nMask) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, LVM_GETITEMSTATE, nItem, nMask); + } + + BOOL SetItemState(int nItem, UINT nState, UINT nStateMask) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVITEM lvi = {}; + lvi.state = nState; + lvi.stateMask = nStateMask; + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMSTATE, nItem, (LPARAM)&lvi); + } + + BOOL SetItemState(int nItem, LPLVITEM pItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMSTATE, nItem, (LPARAM)pItem); + } + + BOOL GetItemText(int nItem, int 
nSubItem, BSTR& bstrText) const + { + USES_CONVERSION; + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(bstrText == NULL); + LVITEM lvi = {}; + lvi.iSubItem = nSubItem; + + LPTSTR lpstrText = NULL; + int nRes = 0; + for(int nLen = 256; ; nLen *= 2) + { + ATLTRY(lpstrText = new TCHAR[nLen]); + if(lpstrText == NULL) + break; + lpstrText[0] = NULL; + lvi.cchTextMax = nLen; + lvi.pszText = lpstrText; + nRes = (int)::SendMessage(this->m_hWnd, LVM_GETITEMTEXT, (WPARAM)nItem, (LPARAM)&lvi); + if(nRes < nLen - 1) + break; + delete [] lpstrText; + lpstrText = NULL; + } + + if(lpstrText != NULL) + { + if(nRes != 0) + bstrText = ::SysAllocString(T2OLE(lpstrText)); + delete [] lpstrText; + } + + return (bstrText != NULL) ? TRUE : FALSE; + } + +#ifdef __ATLSTR_H__ + int GetItemText(int nItem, int nSubItem, ATL::CString& strText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVITEM lvi = {}; + lvi.iSubItem = nSubItem; + + strText.Empty(); + int nRes = 0; + for(int nLen = 256; ; nLen *= 2) + { + lvi.cchTextMax = nLen; + lvi.pszText = strText.GetBufferSetLength(nLen); + if(lvi.pszText == NULL) + { + nRes = 0; + break; + } + nRes = (int)::SendMessage(this->m_hWnd, LVM_GETITEMTEXT, (WPARAM)nItem, (LPARAM)&lvi); + if(nRes < nLen - 1) + break; + } + strText.ReleaseBuffer(); + return nRes; + } +#endif // __ATLSTR_H__ + + int GetItemText(int nItem, int nSubItem, LPTSTR lpszText, int nLen) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVITEM lvi = {}; + lvi.iSubItem = nSubItem; + lvi.cchTextMax = nLen; + lvi.pszText = lpszText; + return (int)::SendMessage(this->m_hWnd, LVM_GETITEMTEXT, (WPARAM)nItem, (LPARAM)&lvi); + } + + BOOL SetItemText(int nItem, int nSubItem, LPCTSTR lpszText) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return SetItem(nItem, nSubItem, LVIF_TEXT, lpszText, 0, 0, 0, 0); + } + + DWORD_PTR GetItemData(int nItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVITEM lvi = {}; + lvi.iItem = nItem; + lvi.mask = LVIF_PARAM; + BOOL bRet = 
(BOOL)::SendMessage(this->m_hWnd, LVM_GETITEM, 0, (LPARAM)&lvi); + return (DWORD_PTR)(bRet ? lvi.lParam : NULL); + } + + BOOL SetItemData(int nItem, DWORD_PTR dwData) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return SetItem(nItem, 0, LVIF_PARAM, NULL, 0, 0, 0, (LPARAM)dwData); + } + + UINT GetCallbackMask() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, LVM_GETCALLBACKMASK, 0, 0L); + } + + BOOL SetCallbackMask(UINT nMask) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETCALLBACKMASK, nMask, 0L); + } + + BOOL GetItemPosition(int nItem, LPPOINT lpPoint) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEMPOSITION, nItem, (LPARAM)lpPoint); + } + + BOOL SetItemPosition(int nItem, POINT pt) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(((this->GetStyle() & LVS_TYPEMASK) == LVS_ICON) || ((this->GetStyle() & LVS_TYPEMASK) == LVS_SMALLICON)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMPOSITION32, nItem, (LPARAM)&pt); + } + + BOOL SetItemPosition(int nItem, int x, int y) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(((this->GetStyle() & LVS_TYPEMASK) == LVS_ICON) || ((this->GetStyle() & LVS_TYPEMASK) == LVS_SMALLICON)); + POINT pt = { x, y }; + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMPOSITION32, nItem, (LPARAM)&pt); + } + + int GetStringWidth(LPCTSTR lpsz) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETSTRINGWIDTH, 0, (LPARAM)lpsz); + } + + CEdit GetEditControl() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CEdit((HWND)::SendMessage(this->m_hWnd, LVM_GETEDITCONTROL, 0, 0L)); + } + + BOOL GetColumn(int nCol, LVCOLUMN* pColumn) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETCOLUMN, nCol, (LPARAM)pColumn); + } + + BOOL SetColumn(int nCol, const LVCOLUMN* pColumn) + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETCOLUMN, nCol, (LPARAM)pColumn); + } + + int GetColumnWidth(int nCol) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETCOLUMNWIDTH, nCol, 0L); + } + + BOOL SetColumnWidth(int nCol, int cx) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETCOLUMNWIDTH, nCol, MAKELPARAM(cx, 0)); + } + + BOOL GetViewRect(LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETVIEWRECT, 0, (LPARAM)lpRect); + } + + COLORREF GetTextColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, LVM_GETTEXTCOLOR, 0, 0L); + } + + BOOL SetTextColor(COLORREF cr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETTEXTCOLOR, 0, cr); + } + + COLORREF GetTextBkColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, LVM_GETTEXTBKCOLOR, 0, 0L); + } + + BOOL SetTextBkColor(COLORREF cr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETTEXTBKCOLOR, 0, cr); + } + + int GetTopIndex() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETTOPINDEX, 0, 0L); + } + + int GetCountPerPage() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETCOUNTPERPAGE, 0, 0L); + } + + BOOL GetOrigin(LPPOINT lpPoint) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETORIGIN, 0, (LPARAM)lpPoint); + } + + UINT GetSelectedCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, LVM_GETSELECTEDCOUNT, 0, 0L); + } + + BOOL GetItemRect(int nItem, LPRECT lpRect, UINT nCode) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + lpRect->left = nCode; + 
return (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEMRECT, (WPARAM)nItem, (LPARAM)lpRect); + } + + HCURSOR GetHotCursor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HCURSOR)::SendMessage(this->m_hWnd, LVM_GETHOTCURSOR, 0, 0L); + } + + HCURSOR SetHotCursor(HCURSOR hHotCursor) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HCURSOR)::SendMessage(this->m_hWnd, LVM_SETHOTCURSOR, 0, (LPARAM)hHotCursor); + } + + int GetHotItem() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETHOTITEM, 0, 0L); + } + + int SetHotItem(int nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_SETHOTITEM, nIndex, 0L); + } + + BOOL GetColumnOrderArray(int nCount, int* lpnArray) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETCOLUMNORDERARRAY, nCount, (LPARAM)lpnArray); + } + + BOOL SetColumnOrderArray(int nCount, int* lpnArray) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETCOLUMNORDERARRAY, nCount, (LPARAM)lpnArray); + } + + CHeaderCtrl GetHeader() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CHeaderCtrl((HWND)::SendMessage(this->m_hWnd, LVM_GETHEADER, 0, 0L)); + } + + BOOL GetSubItemRect(int nItem, int nSubItem, int nFlag, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & LVS_TYPEMASK) == LVS_REPORT); + ATLASSERT(lpRect != NULL); + lpRect->top = nSubItem; + lpRect->left = nFlag; + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETSUBITEMRECT, nItem, (LPARAM)lpRect); + } + + DWORD SetIconSpacing(int cx, int cy) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & LVS_TYPEMASK) == LVS_ICON); + return (DWORD)::SendMessage(this->m_hWnd, LVM_SETICONSPACING, 0, MAKELPARAM(cx, cy)); + } + + int GetISearchString(LPTSTR lpstr) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, 
LVM_GETISEARCHSTRING, 0, (LPARAM)lpstr); + } + + void GetItemSpacing(SIZE& sizeSpacing, BOOL bSmallIconView = FALSE) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, LVM_GETITEMSPACING, bSmallIconView, 0L); + sizeSpacing.cx = GET_X_LPARAM(dwRet); + sizeSpacing.cy = GET_Y_LPARAM(dwRet); + } + + // single-selection only + int GetSelectedIndex() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & LVS_SINGLESEL) != 0); + return (int)::SendMessage(this->m_hWnd, LVM_GETNEXTITEM, (WPARAM)-1, MAKELPARAM(LVNI_ALL | LVNI_SELECTED, 0)); + } + + BOOL GetSelectedItem(LPLVITEM pItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & LVS_SINGLESEL) != 0); + ATLASSERT(pItem != NULL); + pItem->iItem = (int)::SendMessage(this->m_hWnd, LVM_GETNEXTITEM, (WPARAM)-1, MAKELPARAM(LVNI_ALL | LVNI_SELECTED, 0)); + if(pItem->iItem == -1) + return FALSE; + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEM, 0, (LPARAM)pItem); + } + + // extended list view styles + DWORD GetExtendedListViewStyle() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, LVM_GETEXTENDEDLISTVIEWSTYLE, 0, 0L); + } + + // dwExMask = 0 means all styles + DWORD SetExtendedListViewStyle(DWORD dwExStyle, DWORD dwExMask = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, LVM_SETEXTENDEDLISTVIEWSTYLE, dwExMask, dwExStyle); + } + + // checkboxes only + BOOL GetCheckState(int nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((GetExtendedListViewStyle() & LVS_EX_CHECKBOXES) != 0); + UINT uRet = GetItemState(nIndex, LVIS_STATEIMAGEMASK); + return (uRet >> 12) - 1; + } + + BOOL SetCheckState(int nItem, BOOL bCheck) + { + int nCheck = bCheck ? 
2 : 1; // one based index + return SetItemState(nItem, INDEXTOSTATEIMAGEMASK(nCheck), LVIS_STATEIMAGEMASK); + } + + // view type + DWORD GetViewType() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (this->GetStyle() & LVS_TYPEMASK); + } + + DWORD SetViewType(DWORD dwType) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((dwType == LVS_ICON) || (dwType == LVS_SMALLICON) || (dwType == LVS_LIST) || (dwType == LVS_REPORT)); + DWORD dwOldType = GetViewType(); + if(dwType != dwOldType) + this->ModifyStyle(LVS_TYPEMASK, (dwType & LVS_TYPEMASK)); + return dwOldType; + } + + BOOL GetBkImage(LPLVBKIMAGE plvbki) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETBKIMAGE, 0, (LPARAM)plvbki); + } + + BOOL SetBkImage(LPLVBKIMAGE plvbki) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETBKIMAGE, 0, (LPARAM)plvbki); + } + + int GetSelectionMark() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETSELECTIONMARK, 0, 0L); + } + + int SetSelectionMark(int nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_SETSELECTIONMARK, 0, nIndex); + } + + BOOL GetWorkAreas(int nWorkAreas, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETWORKAREAS, nWorkAreas, (LPARAM)lpRect); + } + + BOOL SetWorkAreas(int nWorkAreas, LPRECT lpRect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETWORKAREAS, nWorkAreas, (LPARAM)lpRect); + } + + DWORD GetHoverTime() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((GetExtendedListViewStyle() & (LVS_EX_TRACKSELECT | LVS_EX_ONECLICKACTIVATE | LVS_EX_TWOCLICKACTIVATE)) != 0); + return (DWORD)::SendMessage(this->m_hWnd, LVM_GETHOVERTIME, 0, 0L); + } + + DWORD SetHoverTime(DWORD dwHoverTime) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + 
ATLASSERT((GetExtendedListViewStyle() & (LVS_EX_TRACKSELECT | LVS_EX_ONECLICKACTIVATE | LVS_EX_TWOCLICKACTIVATE)) != 0); + return (DWORD)::SendMessage(this->m_hWnd, LVM_SETHOVERTIME, 0, dwHoverTime); + } + + BOOL GetNumberOfWorkAreas(int* pnWorkAreas) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETNUMBEROFWORKAREAS, 0, (LPARAM)pnWorkAreas); + } + + BOOL SetItemCountEx(int nItems, DWORD dwFlags) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(((this->GetStyle() & LVS_OWNERDATA) != 0) && (((this->GetStyle() & LVS_TYPEMASK) == LVS_REPORT) || ((this->GetStyle() & LVS_TYPEMASK) == LVS_LIST))); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMCOUNT, nItems, dwFlags); + } + + CToolTipCtrl GetToolTips() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, LVM_GETTOOLTIPS, 0, 0L)); + } + + CToolTipCtrl SetToolTips(HWND hWndTT) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, LVM_SETTOOLTIPS, (WPARAM)hWndTT, 0L)); + } + + BOOL GetUnicodeFormat() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETUNICODEFORMAT, 0, 0L); + } + + BOOL SetUnicodeFormat(BOOL bUnicode = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETUNICODEFORMAT, bUnicode, 0L); + } + + int GetSelectedColumn() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETSELECTEDCOLUMN, 0, 0L); + } + + void SetSelectedColumn(int nColumn) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LVM_SETSELECTEDCOLUMN, nColumn, 0L); + } + + DWORD GetView() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, LVM_GETVIEW, 0, 0L); + } + + int SetView(DWORD dwView) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_SETVIEW, dwView, 
0L); + } + + BOOL IsGroupViewEnabled() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_ISGROUPVIEWENABLED, 0, 0L); + } + + int GetGroupInfo(int nGroupID, PLVGROUP pGroup) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETGROUPINFO, nGroupID, (LPARAM)pGroup); + } + + int SetGroupInfo(int nGroupID, PLVGROUP pGroup) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_SETGROUPINFO, nGroupID, (LPARAM)pGroup); + } + + void GetGroupMetrics(PLVGROUPMETRICS pGroupMetrics) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LVM_GETGROUPMETRICS, 0, (LPARAM)pGroupMetrics); + } + + void SetGroupMetrics(PLVGROUPMETRICS pGroupMetrics) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LVM_SETGROUPMETRICS, 0, (LPARAM)pGroupMetrics); + } + + void GetTileViewInfo(PLVTILEVIEWINFO pTileViewInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LVM_GETTILEVIEWINFO, 0, (LPARAM)pTileViewInfo); + } + + BOOL SetTileViewInfo(PLVTILEVIEWINFO pTileViewInfo) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETTILEVIEWINFO, 0, (LPARAM)pTileViewInfo); + } + + void GetTileInfo(PLVTILEINFO pTileInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LVM_GETTILEINFO, 0, (LPARAM)pTileInfo); + } + + BOOL SetTileInfo(PLVTILEINFO pTileInfo) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETTILEINFO, 0, (LPARAM)pTileInfo); + } + + BOOL GetInsertMark(LPLVINSERTMARK pInsertMark) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETINSERTMARK, 0, (LPARAM)pInsertMark); + } + + BOOL SetInsertMark(LPLVINSERTMARK pInsertMark) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETINSERTMARK, 0, 
(LPARAM)pInsertMark); + } + + int GetInsertMarkRect(LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETINSERTMARKRECT, 0, (LPARAM)lpRect); + } + + COLORREF GetInsertMarkColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, LVM_GETINSERTMARKCOLOR, 0, 0L); + } + + COLORREF SetInsertMarkColor(COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, LVM_SETINSERTMARKCOLOR, 0, clr); + } + + COLORREF GetOutlineColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, LVM_GETOUTLINECOLOR, 0, 0L); + } + + COLORREF SetOutlineColor(COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, LVM_SETOUTLINECOLOR, 0, clr); + } + +#if (_WIN32_WINNT >= 0x0600) + int GetGroupCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETGROUPCOUNT, 0, 0L); + } + + BOOL GetGroupInfoByIndex(int nIndex, PLVGROUP pGroup) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETGROUPINFOBYINDEX, nIndex, (LPARAM)pGroup); + } + + BOOL GetGroupRect(int nGroupID, int nType, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(lpRect != NULL); + if(lpRect != NULL) + lpRect->top = nType; + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETGROUPRECT, nGroupID, (LPARAM)lpRect); + } + + UINT GetGroupState(int nGroupID, UINT uMask) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, LVM_GETGROUPSTATE, nGroupID, (LPARAM)uMask); + } + + int GetFocusedGroup() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETFOCUSEDGROUP, 0, 0L); + } + + BOOL GetEmptyText(LPWSTR lpstrText, int cchText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return 
(BOOL)::SendMessage(this->m_hWnd, LVM_GETEMPTYTEXT, cchText, (LPARAM)lpstrText); + } + + BOOL GetFooterRect(LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETFOOTERRECT, 0, (LPARAM)lpRect); + } + + BOOL GetFooterInfo(LPLVFOOTERINFO lpFooterInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETFOOTERINFO, 0, (LPARAM)lpFooterInfo); + } + + BOOL GetFooterItemRect(int nItem, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETFOOTERITEMRECT, nItem, (LPARAM)lpRect); + } + + BOOL GetFooterItem(int nItem, LPLVFOOTERITEM lpFooterItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETFOOTERITEM, nItem, (LPARAM)lpFooterItem); + } + + BOOL GetItemIndexRect(PLVITEMINDEX pItemIndex, int nSubItem, int nType, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(pItemIndex != NULL); + ATLASSERT(lpRect != NULL); + if(lpRect != NULL) + { + lpRect->top = nSubItem; + lpRect->left = nType; + } + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEMINDEXRECT, (WPARAM)pItemIndex, (LPARAM)lpRect); + } + + BOOL SetItemIndexState(PLVITEMINDEX pItemIndex, UINT uState, UINT dwMask) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVITEM lvi = {}; + lvi.state = uState; + lvi.stateMask = dwMask; + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMINDEXSTATE, (WPARAM)pItemIndex, (LPARAM)&lvi); + } + + BOOL GetNextItemIndex(PLVITEMINDEX pItemIndex, WORD wFlags) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_GETNEXTITEMINDEX, (WPARAM)pItemIndex, MAKELPARAM(wFlags, 0)); + } +#endif // (_WIN32_WINNT >= 0x0600) + +// Operations + int InsertColumn(int nCol, const LVCOLUMN* pColumn) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_INSERTCOLUMN, nCol, (LPARAM)pColumn); + } 
+ + int InsertColumn(int nCol, LPCTSTR lpszColumnHeading, int nFormat = LVCFMT_LEFT, + int nWidth = -1, int nSubItem = -1, int iImage = -1, int iOrder = -1) + { + LVCOLUMN column = {}; + column.mask = LVCF_TEXT | LVCF_FMT; + column.pszText = (LPTSTR)lpszColumnHeading; + column.fmt = nFormat; + if (nWidth != -1) + { + column.mask |= LVCF_WIDTH; + column.cx = nWidth; + } + if (nSubItem != -1) + { + column.mask |= LVCF_SUBITEM; + column.iSubItem = nSubItem; + } + if (iImage != -1) + { + column.mask |= LVCF_IMAGE; + column.iImage = iImage; + } + if (iOrder != -1) + { + column.mask |= LVCF_ORDER; + column.iOrder = iOrder; + } + return InsertColumn(nCol, &column); + } + + BOOL DeleteColumn(int nCol) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_DELETECOLUMN, nCol, 0L); + } + + int InsertItem(UINT nMask, int nItem, LPCTSTR lpszItem, UINT nState, UINT nStateMask, int nImage, LPARAM lParam) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVITEM item = {}; + item.mask = nMask; + item.iItem = nItem; + item.iSubItem = 0; + item.pszText = (LPTSTR)lpszItem; + item.state = nState; + item.stateMask = nStateMask; + item.iImage = nImage; + item.lParam = lParam; + return InsertItem(&item); + } + + int InsertItem(const LVITEM* pItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_INSERTITEM, 0, (LPARAM)pItem); + } + + int InsertItem(int nItem, LPCTSTR lpszItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return InsertItem(LVIF_TEXT, nItem, lpszItem, 0, 0, 0, 0); + } + + int InsertItem(int nItem, LPCTSTR lpszItem, int nImage) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return InsertItem(LVIF_TEXT|LVIF_IMAGE, nItem, lpszItem, 0, 0, nImage, 0); + } + + int GetNextItem(int nItem, int nFlags) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_GETNEXTITEM, nItem, MAKELPARAM(nFlags, 0)); + } + + BOOL DeleteItem(int nItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); 
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_DELETEITEM, nItem, 0L); + } + + BOOL DeleteAllItems() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_DELETEALLITEMS, 0, 0L); + } + + int FindItem(LVFINDINFO* pFindInfo, int nStart = -1) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_FINDITEM, nStart, (LPARAM)pFindInfo); + } + + int FindItem(LPCTSTR lpstrFind, bool bPartial = true, bool bWrap = false, int nStart = -1) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVFINDINFO lvfi = {}; + lvfi.flags = LVFI_STRING | (bWrap ? LVFI_WRAP : 0) | (bPartial ? LVFI_PARTIAL : 0); + lvfi.psz = lpstrFind; + return (int)::SendMessage(this->m_hWnd, LVM_FINDITEM, nStart, (LPARAM)&lvfi); + } + + int HitTest(LVHITTESTINFO* pHitTestInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_HITTEST, 0, (LPARAM)pHitTestInfo); + } + + int HitTest(POINT pt, UINT* pFlags) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVHITTESTINFO hti = {}; + hti.pt = pt; + int nRes = (int)::SendMessage(this->m_hWnd, LVM_HITTEST, 0, (LPARAM)&hti); + if (pFlags != NULL) + *pFlags = hti.flags; + return nRes; + } + + BOOL EnsureVisible(int nItem, BOOL bPartialOK) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_ENSUREVISIBLE, nItem, MAKELPARAM(bPartialOK, 0)); + } + + BOOL Scroll(int cx, int cy) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SCROLL, cx, cy); + } + + BOOL Scroll(SIZE size) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SCROLL, size.cx, size.cy); + } + + BOOL RedrawItems(int nFirst, int nLast) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_REDRAWITEMS, nFirst, nLast); + } + + BOOL Arrange(UINT nCode) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, 
LVM_ARRANGE, nCode, 0L); + } + + CEdit EditLabel(int nItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CEdit((HWND)::SendMessage(this->m_hWnd, LVM_EDITLABEL, nItem, 0L)); + } + + BOOL Update(int nItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_UPDATE, nItem, 0L); + } + + BOOL SortItems(PFNLVCOMPARE pfnCompare, LPARAM lParamSort) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SORTITEMS, (WPARAM)lParamSort, (LPARAM)pfnCompare); + } + + CImageList RemoveImageList(int nImageList) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, LVM_SETIMAGELIST, (WPARAM)nImageList, NULL)); + } + + CImageList CreateDragImage(int nItem, LPPOINT lpPoint) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, LVM_CREATEDRAGIMAGE, nItem, (LPARAM)lpPoint)); + } + + DWORD ApproximateViewRect(int cx = -1, int cy = -1, int nCount = -1) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, LVM_APPROXIMATEVIEWRECT, nCount, MAKELPARAM(cx, cy)); + } + + int SubItemHitTest(LPLVHITTESTINFO lpInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_SUBITEMHITTEST, 0, (LPARAM)lpInfo); + } + + int AddColumn(LPCTSTR strColumn, int nItem, int nSubItem = -1, + int nMask = LVCF_FMT | LVCF_WIDTH | LVCF_TEXT | LVCF_SUBITEM, + int nFmt = LVCFMT_LEFT) + { + const int cxOffset = 15; + ATLASSERT(::IsWindow(this->m_hWnd)); + LVCOLUMN lvc = {}; + lvc.mask = nMask; + lvc.fmt = nFmt; + lvc.pszText = (LPTSTR)strColumn; + lvc.cx = GetStringWidth(lvc.pszText) + cxOffset; + if(nMask & LVCF_SUBITEM) + lvc.iSubItem = (nSubItem != -1) ? 
nSubItem : nItem; + return InsertColumn(nItem, &lvc); + } + + int AddItem(int nItem, int nSubItem, LPCTSTR strItem, int nImageIndex = -3) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVITEM lvItem = {}; + lvItem.mask = LVIF_TEXT; + lvItem.iItem = nItem; + lvItem.iSubItem = nSubItem; + lvItem.pszText = (LPTSTR)strItem; + if(nImageIndex != -3) + { + lvItem.mask |= LVIF_IMAGE; + lvItem.iImage = nImageIndex; + } + if(nSubItem == 0) + return InsertItem(&lvItem); + return SetItem(&lvItem) ? nItem : -1; + } + + BOOL SortItemsEx(PFNLVCOMPARE pfnCompare, LPARAM lParamSort) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SORTITEMSEX, (WPARAM)lParamSort, (LPARAM)pfnCompare); + } + + int InsertGroup(int nItem, PLVGROUP pGroup) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_INSERTGROUP, nItem, (LPARAM)pGroup); + } + + int AddGroup(PLVGROUP pGroup) + { + return InsertGroup(-1, pGroup); + } + + int RemoveGroup(int nGroupID) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_REMOVEGROUP, nGroupID, 0L); + } + + void MoveGroup(int nGroupID, int nItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LVM_MOVEGROUP, nGroupID, nItem); + } + + void MoveItemToGroup(int nItem, int nGroupID) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LVM_MOVEITEMTOGROUP, nItem, nGroupID); + } + + int EnableGroupView(BOOL bEnable) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_ENABLEGROUPVIEW, bEnable, 0L); + } + + int SortGroups(PFNLVGROUPCOMPARE pCompareFunc, LPVOID lpVoid = NULL) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_SORTGROUPS, (WPARAM)pCompareFunc, (LPARAM)lpVoid); + } + + void InsertGroupSorted(PLVINSERTGROUPSORTED pInsertGroupSorted) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LVM_INSERTGROUPSORTED, 
(WPARAM)pInsertGroupSorted, 0L); + } + + void RemoveAllGroups() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LVM_REMOVEALLGROUPS, 0, 0L); + } + + BOOL HasGroup(int nGroupID) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_HASGROUP, nGroupID, 0L); + } + + BOOL InsertMarkHitTest(LPPOINT lpPoint, LPLVINSERTMARK pInsertMark) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_INSERTMARKHITTEST, (WPARAM)lpPoint, (LPARAM)pInsertMark); + } + + BOOL SetInfoTip(PLVSETINFOTIP pSetInfoTip) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_SETINFOTIP, 0, (LPARAM)pSetInfoTip); + } + + void CancelEditLabel() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, LVM_CANCELEDITLABEL, 0, 0L); + } + + UINT MapIndexToID(int nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, LVM_MAPINDEXTOID, nIndex, 0L); + } + + int MapIDToIndex(UINT uID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_MAPIDTOINDEX, uID, 0L); + } + + BOOL IsItemVisible(int nItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LVM_ISITEMVISIBLE, nItem, 0L); + } + +#if (_WIN32_WINNT >= 0x0600) + int HitTestEx(LPLVHITTESTINFO lpHitTestInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LVM_HITTEST, (WPARAM)-1, (LPARAM)lpHitTestInfo); + } + + int HitTestEx(POINT pt, UINT* pFlags) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + LVHITTESTINFO hti = {}; + hti.pt = pt; + int nRes = (int)::SendMessage(this->m_hWnd, LVM_HITTEST, (WPARAM)-1, (LPARAM)&hti); + if (pFlags != NULL) + *pFlags = hti.flags; + return nRes; + } + + int SubItemHitTestEx(LPLVHITTESTINFO lpHitTestInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, 
LVM_SUBITEMHITTEST, (WPARAM)-1, (LPARAM)lpHitTestInfo); + } +#endif // (_WIN32_WINNT >= 0x0600) + + // Note: selects only one item + BOOL SelectItem(int nIndex) // -1 to select none + { + ATLASSERT(::IsWindow(this->m_hWnd)); + + BOOL bRet = FALSE; + if(nIndex != -1) + { + // multi-selection only: de-select all items + if((this->GetStyle() & LVS_SINGLESEL) == 0) + SetItemState(-1, 0, LVIS_SELECTED); + + bRet = SetItemState(nIndex, LVIS_SELECTED | LVIS_FOCUSED, LVIS_SELECTED | LVIS_FOCUSED); + if(bRet) + { + SetSelectionMark(nIndex); + bRet = EnsureVisible(nIndex, FALSE); + } + } + else // no item specified, just de-select + { + bRet = SetItemState(-1, 0, LVIS_SELECTED); + } + + return bRet; + } + + // multi-selection only + BOOL SelectAllItems(bool bSelect = true) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & LVS_SINGLESEL) == 0); + + return SetItemState(-1, bSelect ? LVIS_SELECTED : 0, LVIS_SELECTED); + } +}; + +typedef CListViewCtrlT<ATL::CWindow> CListViewCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CTreeViewCtrl + +template <class TBase> +class CTreeViewCtrlT : public TBase +{ +public: +// Constructors + CTreeViewCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CTreeViewCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return WC_TREEVIEW; + } + + UINT GetCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, TVM_GETCOUNT, 0, 0L); + } + + UINT GetIndent() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return 
(UINT)::SendMessage(this->m_hWnd, TVM_GETINDENT, 0, 0L); + } + + void SetIndent(UINT nIndent) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TVM_SETINDENT, nIndent, 0L); + } + + CImageList GetImageList(int nImageListType = TVSIL_NORMAL) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TVM_GETIMAGELIST, (WPARAM)nImageListType, 0L)); + } + + CImageList SetImageList(HIMAGELIST hImageList, int nImageListType = TVSIL_NORMAL) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TVM_SETIMAGELIST, (WPARAM)nImageListType, (LPARAM)hImageList)); + } + + BOOL GetItem(LPTVITEM pItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)pItem); + } + + BOOL SetItem(LPTVITEM pItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SETITEM, 0, (LPARAM)pItem); + } + + BOOL SetItem(HTREEITEM hItem, UINT nMask, LPCTSTR lpszItem, int nImage, + int nSelectedImage, UINT nState, UINT nStateMask, LPARAM lParam) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TVITEM item = {}; + item.hItem = hItem; + item.mask = nMask; + item.pszText = (LPTSTR) lpszItem; + item.iImage = nImage; + item.iSelectedImage = nSelectedImage; + item.state = nState; + item.stateMask = nStateMask; + item.lParam = lParam; + return (BOOL)::SendMessage(this->m_hWnd, TVM_SETITEM, 0, (LPARAM)&item); + } + + BOOL GetItemText(HTREEITEM hItem, LPTSTR lpstrText, int nLen) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(lpstrText != NULL); + + TVITEM item = {}; + item.hItem = hItem; + item.mask = TVIF_TEXT; + item.pszText = lpstrText; + item.cchTextMax = nLen; + + return (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item); + } + + BOOL GetItemText(HTREEITEM hItem, BSTR& bstrText) const + { + USES_CONVERSION; + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(bstrText == 
NULL); + TVITEM item = {}; + item.hItem = hItem; + item.mask = TVIF_TEXT; + + LPTSTR lpstrText = NULL; + BOOL bRet = FALSE; + for(int nLen = 256; ; nLen *= 2) + { + ATLTRY(lpstrText = new TCHAR[nLen]); + if(lpstrText == NULL) + break; + lpstrText[0] = NULL; + item.pszText = lpstrText; + item.cchTextMax = nLen; + bRet = (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item); + if(!bRet || (lstrlen(item.pszText) < (nLen - 1))) + break; + delete [] lpstrText; + lpstrText = NULL; + } + + if(lpstrText != NULL) + { + if(bRet) + bstrText = ::SysAllocString(T2OLE(lpstrText)); + delete [] lpstrText; + } + + return (bstrText != NULL) ? TRUE : FALSE; + } + +#ifdef __ATLSTR_H__ + BOOL GetItemText(HTREEITEM hItem, ATL::CString& strText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TVITEM item = {}; + item.hItem = hItem; + item.mask = TVIF_TEXT; + + strText.Empty(); + BOOL bRet = FALSE; + for(int nLen = 256; ; nLen *= 2) + { + item.pszText = strText.GetBufferSetLength(nLen); + if(item.pszText == NULL) + { + bRet = FALSE; + break; + } + item.cchTextMax = nLen; + bRet = (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item); + if(!bRet || (lstrlen(item.pszText) < (nLen - 1))) + break; + } + strText.ReleaseBuffer(); + return bRet; + } +#endif // __ATLSTR_H__ + + BOOL SetItemText(HTREEITEM hItem, LPCTSTR lpszItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return SetItem(hItem, TVIF_TEXT, lpszItem, 0, 0, 0, 0, NULL); + } + + BOOL GetItemImage(HTREEITEM hItem, int& nImage, int& nSelectedImage) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TVITEM item = {}; + item.hItem = hItem; + item.mask = TVIF_IMAGE|TVIF_SELECTEDIMAGE; + BOOL bRes = (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item); + if (bRes) + { + nImage = item.iImage; + nSelectedImage = item.iSelectedImage; + } + return bRes; + } + + BOOL SetItemImage(HTREEITEM hItem, int nImage, int nSelectedImage) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return SetItem(hItem, 
TVIF_IMAGE|TVIF_SELECTEDIMAGE, NULL, nImage, nSelectedImage, 0, 0, NULL); + } + + UINT GetItemState(HTREEITEM hItem, UINT nStateMask) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (((UINT)::SendMessage(this->m_hWnd, TVM_GETITEMSTATE, (WPARAM)hItem, (LPARAM)nStateMask)) & nStateMask); + } + + BOOL SetItemState(HTREEITEM hItem, UINT nState, UINT nStateMask) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return SetItem(hItem, TVIF_STATE, NULL, 0, 0, nState, nStateMask, NULL); + } + + DWORD_PTR GetItemData(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TVITEM item = {}; + item.hItem = hItem; + item.mask = TVIF_PARAM; + BOOL bRet = (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item); + return (DWORD_PTR)(bRet ? item.lParam : NULL); + } + + BOOL SetItemData(HTREEITEM hItem, DWORD_PTR dwData) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return SetItem(hItem, TVIF_PARAM, NULL, 0, 0, 0, 0, (LPARAM)dwData); + } + + CEdit GetEditControl() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CEdit((HWND)::SendMessage(this->m_hWnd, TVM_GETEDITCONTROL, 0, 0L)); + } + + UINT GetVisibleCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, TVM_GETVISIBLECOUNT, 0, 0L); + } + + BOOL GetItemRect(HTREEITEM hItem, LPRECT lpRect, BOOL bTextOnly) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + *(HTREEITEM*)lpRect = hItem; + return (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEMRECT, (WPARAM)bTextOnly, (LPARAM)lpRect); + } + + BOOL ItemHasChildren(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TVITEM item = {}; + item.hItem = hItem; + item.mask = TVIF_CHILDREN; + ::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item); + return item.cChildren; + } + + CToolTipCtrl GetToolTips() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, TVM_GETTOOLTIPS, 0, 0L)); + } + + CToolTipCtrl SetToolTips(HWND hWndTT) + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, TVM_SETTOOLTIPS, (WPARAM)hWndTT, 0L)); + } + + int GetISearchString(LPTSTR lpstr) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TVM_GETISEARCHSTRING, 0, (LPARAM)lpstr); + } + + // checkboxes only + BOOL GetCheckState(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & TVS_CHECKBOXES) != 0); + UINT uRet = GetItemState(hItem, TVIS_STATEIMAGEMASK); + return (uRet >> 12) - 1; + } + + BOOL SetCheckState(HTREEITEM hItem, BOOL bCheck) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & TVS_CHECKBOXES) != 0); + int nCheck = bCheck ? 2 : 1; // one based index + return SetItemState(hItem, INDEXTOSTATEIMAGEMASK(nCheck), TVIS_STATEIMAGEMASK); + } + + // for standard and extended checkboxes (0 = no checkbox, 1 = unchecked, 2 = checked, >2 = optional extended check states) + UINT GetCheckStateEx(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(this->GetImageList(TVSIL_STATE) != NULL); + UINT uRet = GetItemState(hItem, TVIS_STATEIMAGEMASK); + return (uRet >> 12); + } + + BOOL SetCheckStateEx(HTREEITEM hItem, UINT uCheckState) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(this->GetImageList(TVSIL_STATE) != NULL); + ATLASSERT(uCheckState < (UINT)::ImageList_GetImageCount(this->GetImageList(TVSIL_STATE))); + return SetItemState(hItem, INDEXTOSTATEIMAGEMASK(uCheckState), TVIS_STATEIMAGEMASK); + } + + COLORREF GetBkColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, TVM_GETBKCOLOR, 0, 0L); + } + + COLORREF SetBkColor(COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, TVM_SETBKCOLOR, 0, (LPARAM)clr); + } + + COLORREF GetInsertMarkColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, 
TVM_GETINSERTMARKCOLOR, 0, 0L); + } + + COLORREF SetInsertMarkColor(COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, TVM_SETINSERTMARKCOLOR, 0, (LPARAM)clr); + } + + int GetItemHeight() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TVM_GETITEMHEIGHT, 0, 0L); + } + + int SetItemHeight(int cyHeight) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TVM_SETITEMHEIGHT, cyHeight, 0L); + } + + int GetScrollTime() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TVM_GETSCROLLTIME, 0, 0L); + } + + int SetScrollTime(int nScrollTime) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TVM_SETSCROLLTIME, nScrollTime, 0L); + } + + COLORREF GetTextColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, TVM_GETTEXTCOLOR, 0, 0L); + } + + COLORREF SetTextColor(COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, TVM_SETTEXTCOLOR, 0, (LPARAM)clr); + } + + BOOL GetUnicodeFormat() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_GETUNICODEFORMAT, 0, 0L); + } + + BOOL SetUnicodeFormat(BOOL bUnicode = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SETUNICODEFORMAT, bUnicode, 0L); + } + + COLORREF GetLineColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, TVM_GETLINECOLOR, 0, 0L); + } + + COLORREF SetLineColor(COLORREF clrNew /*= CLR_DEFAULT*/) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, TVM_SETLINECOLOR, 0, (LPARAM)clrNew); + } + + BOOL GetItem(LPTVITEMEX pItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)pItem); + } + + 
BOOL SetItem(LPTVITEMEX pItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SETITEM, 0, (LPARAM)pItem); + } + + DWORD GetExtendedStyle() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, TVM_GETEXTENDEDSTYLE, 0, 0L); + } + + DWORD SetExtendedStyle(DWORD dwStyle, DWORD dwMask) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, TVM_SETEXTENDEDSTYLE, dwMask, dwStyle); + } + +#if (_WIN32_WINNT >= 0x0600) + BOOL SetAutoScrollInfo(UINT uPixPerSec, UINT uUpdateTime) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SETAUTOSCROLLINFO, (WPARAM)uPixPerSec, (LPARAM)uUpdateTime); + } + + DWORD GetSelectedCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, TVM_GETSELECTEDCOUNT, 0, 0L); + } + + BOOL GetItemPartRect(HTREEITEM hItem, TVITEMPART partID, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TVGETITEMPARTRECTINFO gipri = { hItem, lpRect, partID }; + return (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEMPARTRECT, 0, (LPARAM)&gipri); + } +#endif // (_WIN32_WINNT >= 0x0600) + +// Operations + HTREEITEM InsertItem(LPTVINSERTSTRUCT lpInsertStruct) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_INSERTITEM, 0, (LPARAM)lpInsertStruct); + } + + HTREEITEM InsertItem(LPCTSTR lpszItem, int nImage, + int nSelectedImage, HTREEITEM hParent, HTREEITEM hInsertAfter) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return InsertItem(TVIF_TEXT | TVIF_IMAGE | TVIF_SELECTEDIMAGE, lpszItem, nImage, nSelectedImage, 0, 0, 0, hParent, hInsertAfter); + } + + HTREEITEM InsertItem(LPCTSTR lpszItem, HTREEITEM hParent, HTREEITEM hInsertAfter) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return InsertItem(TVIF_TEXT, lpszItem, 0, 0, 0, 0, 0, hParent, hInsertAfter); + } + + HTREEITEM InsertItem(UINT nMask, LPCTSTR lpszItem, int nImage, + 
int nSelectedImage, UINT nState, UINT nStateMask, LPARAM lParam, + HTREEITEM hParent, HTREEITEM hInsertAfter) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TVINSERTSTRUCT tvis = {}; /* zero-initialised; only the fields assigned below are filled in */ + tvis.hParent = hParent; + tvis.hInsertAfter = hInsertAfter; + tvis.item.mask = nMask; + tvis.item.pszText = (LPTSTR) lpszItem; /* cast drops const from the LPCTSTR argument */ + tvis.item.iImage = nImage; + tvis.item.iSelectedImage = nSelectedImage; + tvis.item.state = nState; + tvis.item.stateMask = nStateMask; + tvis.item.lParam = lParam; + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_INSERTITEM, 0, (LPARAM)&tvis); + } + + BOOL DeleteItem(HTREEITEM hItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_DELETEITEM, 0, (LPARAM)hItem); + } + + BOOL DeleteAllItems() /* same message as DeleteItem, with TVI_ROOT as the item */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_DELETEITEM, 0, (LPARAM)TVI_ROOT); + } + + BOOL Expand(HTREEITEM hItem, UINT nCode = TVE_EXPAND) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_EXPAND, nCode, (LPARAM)hItem); + } + + HTREEITEM GetNextItem(HTREEITEM hItem, UINT nCode) const /* general form; the Get*Item helpers that follow send the same message with a fixed TVGN_* code */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, nCode, (LPARAM)hItem); + } + + HTREEITEM GetChildItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_CHILD, (LPARAM)hItem); + } + + HTREEITEM GetNextSiblingItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXT, (LPARAM)hItem); + } + + HTREEITEM GetPrevSiblingItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PREVIOUS, (LPARAM)hItem); + } + + HTREEITEM GetParentItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, 
TVM_GETNEXTITEM, TVGN_PARENT, (LPARAM)hItem); + } + + HTREEITEM GetFirstVisibleItem() const /* takes no item: lParam is sent as 0 */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_FIRSTVISIBLE, 0L); + } + + HTREEITEM GetNextVisibleItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXTVISIBLE, (LPARAM)hItem); + } + + HTREEITEM GetPrevVisibleItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PREVIOUSVISIBLE, (LPARAM)hItem); + } + + HTREEITEM GetSelectedItem() const /* the selection is queried with the TVGN_CARET code */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_CARET, 0L); + } + + HTREEITEM GetDropHilightItem() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_DROPHILITE, 0L); + } + + HTREEITEM GetRootItem() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_ROOT, 0L); + } + + HTREEITEM GetLastVisibleItem() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_LASTVISIBLE, 0L); + } + + HTREEITEM GetNextSelectedItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXTSELECTED, (LPARAM)hItem); + } + + BOOL Select(HTREEITEM hItem, UINT nCode) /* general form; the Select* helpers that follow send the same message with a fixed TVGN_* code */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SELECTITEM, nCode, (LPARAM)hItem); + } + + BOOL SelectItem(HTREEITEM hItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SELECTITEM, TVGN_CARET, (LPARAM)hItem); + } + + BOOL SelectDropTarget(HTREEITEM hItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, 
TVM_SELECTITEM, TVGN_DROPHILITE, (LPARAM)hItem); + } + + BOOL SelectSetFirstVisible(HTREEITEM hItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SELECTITEM, TVGN_FIRSTVISIBLE, (LPARAM)hItem); + } + + CEdit EditLabel(HTREEITEM hItem) /* wraps the HWND produced by TVM_EDITLABEL in a CEdit */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CEdit((HWND)::SendMessage(this->m_hWnd, TVM_EDITLABEL, 0, (LPARAM)hItem)); + } + + BOOL EndEditLabelNow(BOOL bCancel) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_ENDEDITLABELNOW, bCancel, 0L); + } + + HTREEITEM HitTest(TVHITTESTINFO* pHitTestInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_HITTEST, 0, (LPARAM)pHitTestInfo); + } + + HTREEITEM HitTest(POINT pt, UINT* pFlags) const /* convenience form: pFlags may be NULL, otherwise it receives hti.flags */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TVHITTESTINFO hti = {}; + hti.pt = pt; + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_HITTEST, 0, (LPARAM)&hti); + if (pFlags != NULL) + *pFlags = hti.flags; + return hTreeItem; + } + + BOOL SortChildren(HTREEITEM hItem, BOOL bRecurse = FALSE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SORTCHILDREN, (WPARAM)bRecurse, (LPARAM)hItem); + } + + BOOL EnsureVisible(HTREEITEM hItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_ENSUREVISIBLE, 0, (LPARAM)hItem); + } + + BOOL SortChildrenCB(LPTVSORTCB pSort, BOOL bRecurse = FALSE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SORTCHILDRENCB, (WPARAM)bRecurse, (LPARAM)pSort); + } + + CImageList RemoveImageList(int nImageList) /* sends TVM_SETIMAGELIST with a NULL list for nImageList and wraps the handle the message returns */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TVM_SETIMAGELIST, (WPARAM)nImageList, NULL)); + } + + CImageList CreateDragImage(HTREEITEM hItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, 
TVM_CREATEDRAGIMAGE, 0, (LPARAM)hItem)); + } + + BOOL SetInsertMark(HTREEITEM hTreeItem, BOOL bAfter) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SETINSERTMARK, bAfter, (LPARAM)hTreeItem); + } + + BOOL RemoveInsertMark() /* same message as SetInsertMark, with both parameters zero */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TVM_SETINSERTMARK, 0, 0L); + } + + HTREEITEM MapAccIDToHTREEITEM(UINT uID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_MAPACCIDTOHTREEITEM, uID, 0L); + } + + UINT MapHTREEITEMToAccID(HTREEITEM hTreeItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, TVM_MAPHTREEITEMTOACCID, (WPARAM)hTreeItem, 0L); + } + +#if (_WIN32_WINNT >= 0x0600) + void ShowInfoTip(HTREEITEM hItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TVM_SHOWINFOTIP, 0, (LPARAM)hItem); + } +#endif // (_WIN32_WINNT >= 0x0600) +}; + +typedef CTreeViewCtrlT<ATL::CWindow> CTreeViewCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CTreeViewCtrlEx + +// forward declaration +template <class TBase> class CTreeViewCtrlExT; + +// Note: TBase here is for CTreeViewCtrlExT, and not for CTreeItemT itself +template <class TBase> +class CTreeItemT /* pairs an HTREEITEM with a pointer to the CTreeViewCtrlExT it belongs to; neither member is owned */ +{ +public: + HTREEITEM m_hTreeItem; + CTreeViewCtrlExT<TBase>* m_pTreeView; + +// Construction + CTreeItemT(HTREEITEM hTreeItem = NULL, CTreeViewCtrlExT<TBase>* pTreeView = NULL) : m_hTreeItem(hTreeItem), m_pTreeView(pTreeView) + { } + + CTreeItemT(const CTreeItemT<TBase>& posSrc) /* copies both members through operator = */ + { + *this = posSrc; + } + + operator HTREEITEM() const { return m_hTreeItem; } /* const so that a const CTreeItemT converts too; the conversion only reads m_hTreeItem */ + + CTreeItemT<TBase>& operator =(const CTreeItemT<TBase>& itemSrc) + { + m_hTreeItem = itemSrc.m_hTreeItem; + m_pTreeView = itemSrc.m_pTreeView; + return *this; + } + +// Attributes + CTreeViewCtrlExT<TBase>* GetTreeView() const { return m_pTreeView; } + + BOOL operator !() const { return m_hTreeItem == 
NULL; } + + BOOL IsNull() const { return m_hTreeItem == NULL; } /* only the item handle is tested; m_pTreeView is not consulted */ + + BOOL GetRect(LPRECT lpRect, BOOL bTextOnly) const; + BOOL GetText(LPTSTR lpstrText, int nLen) const; + BOOL GetText(BSTR& bstrText) const; /* NOTE(review): the out-of-class definition later in this file sits inside #ifdef _OLEAUTO_H_, this declaration does not - confirm that is intended */ +#ifdef __ATLSTR_H__ + BOOL GetText(ATL::CString& strText) const; +#endif // __ATLSTR_H__ + BOOL SetText(LPCTSTR lpszItem); + BOOL GetImage(int& nImage, int& nSelectedImage) const; + BOOL SetImage(int nImage, int nSelectedImage); + UINT GetState(UINT nStateMask) const; + BOOL SetState(UINT nState, UINT nStateMask); + DWORD_PTR GetData() const; + BOOL SetData(DWORD_PTR dwData); + BOOL SetItem(UINT nMask, LPCTSTR lpszItem, int nImage, int nSelectedImage, UINT nState, UINT nStateMask, LPARAM lParam); + +// Operations + CTreeItemT<TBase> InsertAfter(LPCTSTR lpstrItem, HTREEITEM hItemAfter, int nImageIndex) /* the three helpers here differ only in the insert position handed to _Insert */ + { + return _Insert(lpstrItem, nImageIndex, hItemAfter); + } + + CTreeItemT<TBase> AddHead(LPCTSTR lpstrItem, int nImageIndex) /* position is TVI_FIRST */ + { + return _Insert(lpstrItem, nImageIndex, TVI_FIRST); + } + + CTreeItemT<TBase> AddTail(LPCTSTR lpstrItem, int nImageIndex) /* position is TVI_LAST */ + { + return _Insert(lpstrItem, nImageIndex, TVI_LAST); + } + + CTreeItemT<TBase> GetChild() const; + CTreeItemT<TBase> GetNext(UINT nCode) const; + CTreeItemT<TBase> GetNextSibling() const; + CTreeItemT<TBase> GetPrevSibling() const; + CTreeItemT<TBase> GetParent() const; + CTreeItemT<TBase> GetFirstVisible() const; + CTreeItemT<TBase> GetNextVisible() const; + CTreeItemT<TBase> GetPrevVisible() const; + CTreeItemT<TBase> GetSelected() const; + CTreeItemT<TBase> GetDropHilight() const; + CTreeItemT<TBase> GetRoot() const; + CTreeItemT<TBase> GetLastVisible() const; + CTreeItemT<TBase> GetNextSelected() const; + BOOL HasChildren() const; + BOOL Delete(); + BOOL Expand(UINT nCode = TVE_EXPAND); + BOOL Select(UINT nCode); + BOOL Select(); + BOOL SelectDropTarget(); + BOOL SelectSetFirstVisible(); + HWND EditLabel(); + HIMAGELIST CreateDragImage(); + BOOL SortChildren(BOOL bRecurse = FALSE); + BOOL EnsureVisible(); + 
CTreeItemT<TBase> _Insert(LPCTSTR lpstrItem, int nImageIndex, HTREEITEM hItemAfter); + int GetImageIndex() const; + BOOL SetInsertMark(BOOL bAfter); + UINT MapHTREEITEMToAccID() const; +#if (_WIN32_WINNT >= 0x0600) + void ShowInfoTip(); + BOOL GetPartRect(TVITEMPART partID, LPRECT lpRect) const; +#endif // (_WIN32_WINNT >= 0x0600) +}; + +typedef CTreeItemT<ATL::CWindow> CTreeItem; + + +template <class TBase> +class CTreeViewCtrlExT : public CTreeViewCtrlT< TBase > +{ +public: +// Constructors + CTreeViewCtrlExT(HWND hWnd = NULL) : CTreeViewCtrlT< TBase >(hWnd) + { } + + CTreeViewCtrlExT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + +// Operations (overrides that return CTreeItem) + CTreeItemT<TBase> InsertItem(LPTVINSERTSTRUCT lpInsertStruct) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_INSERTITEM, 0, (LPARAM)lpInsertStruct); + return CTreeItemT<TBase>(hTreeItem, this); + } + + CTreeItemT<TBase> InsertItem(LPCTSTR lpszItem, int nImage, + int nSelectedImage, HTREEITEM hParent, HTREEITEM hInsertAfter) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return InsertItem(TVIF_TEXT | TVIF_IMAGE | TVIF_SELECTEDIMAGE, lpszItem, nImage, nSelectedImage, 0, 0, 0, hParent, hInsertAfter); + } + + CTreeItemT<TBase> InsertItem(LPCTSTR lpszItem, HTREEITEM hParent, HTREEITEM hInsertAfter) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return InsertItem(TVIF_TEXT, lpszItem, 0, 0, 0, 0, 0, hParent, hInsertAfter); + } + + CTreeItemT<TBase> GetNextItem(HTREEITEM hItem, UINT nCode) const /* const method: the C-style cast on this strips const so the returned item can hold a non-const tree pointer */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, nCode, (LPARAM)hItem); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetChildItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_CHILD, 
(LPARAM)hItem); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetNextSiblingItem(HTREEITEM hItem) const /* like the other getters here: sends TVM_GETNEXTITEM, then pairs the handle with this tree view */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXT, (LPARAM)hItem); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetPrevSiblingItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PREVIOUS, (LPARAM)hItem); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetParentItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PARENT, (LPARAM)hItem); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetFirstVisibleItem() const /* takes no item: lParam is sent as 0 */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_FIRSTVISIBLE, 0L); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetNextVisibleItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXTVISIBLE, (LPARAM)hItem); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetPrevVisibleItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PREVIOUSVISIBLE, (LPARAM)hItem); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetSelectedItem() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, 
TVGN_CARET, 0L); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetDropHilightItem() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_DROPHILITE, 0L); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetRootItem() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_ROOT, 0L); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetLastVisibleItem() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_LASTVISIBLE, 0L); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> GetNextSelectedItem(HTREEITEM hItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXTSELECTED, (LPARAM)hItem); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> HitTest(TVHITTESTINFO* pHitTestInfo) const /* the handle returned by TVM_HITTEST is wrapped together with this tree view */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_HITTEST, 0, (LPARAM)pHitTestInfo); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> InsertItem(UINT nMask, LPCTSTR lpszItem, int nImage, + int nSelectedImage, UINT nState, UINT nStateMask, LPARAM lParam, + HTREEITEM hParent, HTREEITEM hInsertAfter) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TVINSERTSTRUCT tvis = {}; /* zero-initialised; only the fields assigned below are filled in */ + tvis.hParent = hParent; + tvis.hInsertAfter = hInsertAfter; + tvis.item.mask = nMask; + tvis.item.pszText = (LPTSTR) lpszItem; /* cast drops const from the LPCTSTR argument */ + tvis.item.iImage = nImage; + tvis.item.iSelectedImage = nSelectedImage; + tvis.item.state = nState; + tvis.item.stateMask = nStateMask; + 
tvis.item.lParam = lParam; + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_INSERTITEM, 0, (LPARAM)&tvis); + return CTreeItemT<TBase>(hTreeItem, this); + } + + CTreeItemT<TBase> HitTest(POINT pt, UINT* pFlags) const /* convenience form: pFlags may be NULL, otherwise it receives hti.flags */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TVHITTESTINFO hti = {}; + hti.pt = pt; + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_HITTEST, 0, (LPARAM)&hti); + if (pFlags != NULL) + *pFlags = hti.flags; + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } + + CTreeItemT<TBase> MapAccIDToHTREEITEM(UINT uID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_MAPACCIDTOHTREEITEM, uID, 0L); + return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this); + } +}; + +typedef CTreeViewCtrlExT<ATL::CWindow> CTreeViewCtrlEx; + + +// CTreeItem inline methods (each asserts m_pTreeView is non-NULL, then forwards to the tree view) +template <class TBase> +inline BOOL CTreeItemT<TBase>::GetRect(LPRECT lpRect, BOOL bTextOnly) const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetItemRect(m_hTreeItem,lpRect,bTextOnly); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetNext(UINT nCode) const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetNextItem(m_hTreeItem,nCode); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetChild() const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetChildItem(m_hTreeItem); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetNextSibling() const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetNextSiblingItem(m_hTreeItem); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetPrevSibling() const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetPrevSiblingItem(m_hTreeItem); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetParent() const +{ + ATLASSERT(m_pTreeView != NULL); + return 
m_pTreeView->GetParentItem(m_hTreeItem); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetFirstVisible() const /* m_hTreeItem is not used: the query goes to the tree view as a whole */ +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetFirstVisibleItem(); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetNextVisible() const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetNextVisibleItem(m_hTreeItem); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetPrevVisible() const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetPrevVisibleItem(m_hTreeItem); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetSelected() const /* m_hTreeItem is not used */ +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetSelectedItem(); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetDropHilight() const /* m_hTreeItem is not used */ +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetDropHilightItem(); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetRoot() const /* m_hTreeItem is not used */ +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetRootItem(); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetLastVisible() const /* m_hTreeItem is not used */ +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetLastVisibleItem(); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::GetNextSelected() const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetNextSelectedItem(m_hTreeItem); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::GetText(LPTSTR lpstrText, int nLen) const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetItemText(m_hTreeItem, lpstrText, nLen); +} + +#ifdef _OLEAUTO_H_ +template <class TBase> +inline BOOL CTreeItemT<TBase>::GetText(BSTR& bstrText) const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetItemText(m_hTreeItem, bstrText); +} +#endif // _OLEAUTO_H_ + +#ifdef __ATLSTR_H__ +template <class TBase> +inline BOOL 
CTreeItemT<TBase>::GetText(ATL::CString& strText) const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetItemText(m_hTreeItem, strText); +} +#endif // __ATLSTR_H__ + +template <class TBase> +inline BOOL CTreeItemT<TBase>::GetImage(int& nImage, int& nSelectedImage) const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetItemImage(m_hTreeItem,nImage,nSelectedImage); +} + +template <class TBase> +inline UINT CTreeItemT<TBase>::GetState(UINT nStateMask) const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetItemState(m_hTreeItem,nStateMask); +} + +template <class TBase> +inline DWORD_PTR CTreeItemT<TBase>::GetData() const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetItemData(m_hTreeItem); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::SetItem(UINT nMask, LPCTSTR lpszItem, int nImage, + int nSelectedImage, UINT nState, UINT nStateMask, LPARAM lParam) /* forwards every argument unchanged, with m_hTreeItem prepended */ +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->SetItem(m_hTreeItem, nMask, lpszItem, nImage, nSelectedImage, nState, nStateMask, lParam); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::SetText(LPCTSTR lpszItem) +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->SetItemText(m_hTreeItem,lpszItem); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::SetImage(int nImage, int nSelectedImage) +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->SetItemImage(m_hTreeItem,nImage,nSelectedImage); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::SetState(UINT nState, UINT nStateMask) +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->SetItemState(m_hTreeItem,nState,nStateMask); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::SetData(DWORD_PTR dwData) +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->SetItemData(m_hTreeItem,dwData); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::HasChildren() const +{ + ATLASSERT(m_pTreeView != NULL); + return 
m_pTreeView->ItemHasChildren(m_hTreeItem); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::Delete() /* m_hTreeItem is left unchanged after the delete request */ +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->DeleteItem(m_hTreeItem); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::Expand(UINT nCode /*= TVE_EXPAND*/) +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->Expand(m_hTreeItem,nCode); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::Select(UINT nCode) +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->Select(m_hTreeItem,nCode); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::Select() +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->SelectItem(m_hTreeItem); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::SelectDropTarget() +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->SelectDropTarget(m_hTreeItem); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::SelectSetFirstVisible() +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->SelectSetFirstVisible(m_hTreeItem); +} + +template <class TBase> +inline HWND CTreeItemT<TBase>::EditLabel() /* the tree view's EditLabel() yields a CEdit; it is returned here as a plain HWND */ +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->EditLabel(m_hTreeItem); +} + +template <class TBase> +inline HIMAGELIST CTreeItemT<TBase>::CreateDragImage() /* the tree view's CreateDragImage() yields a CImageList; it is returned here as a plain HIMAGELIST */ +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->CreateDragImage(m_hTreeItem); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::SortChildren(BOOL bRecurse /*= FALSE*/) +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->SortChildren(m_hTreeItem, bRecurse); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::EnsureVisible() +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->EnsureVisible(m_hTreeItem); +} + +template <class TBase> +inline CTreeItemT<TBase> CTreeItemT<TBase>::_Insert(LPCTSTR lpstrItem, int nImageIndex, HTREEITEM hItemAfter) +{ + ATLASSERT(m_pTreeView != NULL); + TVINSERTSTRUCT ins = {}; + ins.hParent = m_hTreeItem; /* the new item becomes a child of this item */ + ins.hInsertAfter = 
hItemAfter; + ins.item.mask = TVIF_TEXT; + ins.item.pszText = (LPTSTR)lpstrItem; + if(nImageIndex != -1) /* -1 leaves the mask at TVIF_TEXT only; any other value is used for both image fields */ + { + ins.item.mask |= TVIF_IMAGE | TVIF_SELECTEDIMAGE; + ins.item.iImage = nImageIndex; + ins.item.iSelectedImage = nImageIndex; + } + return CTreeItemT<TBase>(m_pTreeView->InsertItem(&ins), m_pTreeView); +} + +template <class TBase> +inline int CTreeItemT<TBase>::GetImageIndex() const +{ + ATLASSERT(m_pTreeView != NULL); + TVITEM item = {}; + item.mask = TVIF_HANDLE | TVIF_IMAGE; + item.hItem = m_hTreeItem; + m_pTreeView->GetItem(&item); /* result is not checked; iImage is still 0 from the zero-initialisation if nothing was written */ + return item.iImage; +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::SetInsertMark(BOOL bAfter) +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->SetInsertMark(m_hTreeItem, bAfter); +} + +template <class TBase> +inline UINT CTreeItemT<TBase>::MapHTREEITEMToAccID() const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->MapHTREEITEMToAccID(m_hTreeItem); +} + +#if (_WIN32_WINNT >= 0x0600) +template <class TBase> +inline void CTreeItemT<TBase>::ShowInfoTip() +{ + ATLASSERT(m_pTreeView != NULL); + m_pTreeView->ShowInfoTip(m_hTreeItem); +} + +template <class TBase> +inline BOOL CTreeItemT<TBase>::GetPartRect(TVITEMPART partID, LPRECT lpRect) const +{ + ATLASSERT(m_pTreeView != NULL); + return m_pTreeView->GetItemPartRect(m_hTreeItem, partID, lpRect); +} +#endif // (_WIN32_WINNT >= 0x0600) + + +/////////////////////////////////////////////////////////////////////////////// +// CToolBarCtrl + +template <class TBase> +class CToolBarCtrlT : public TBase +{ +public: +// Construction + CToolBarCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CToolBarCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, 
dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() /* window class handed to TBase::Create by Create() */ + { + return TOOLBARCLASSNAME; + } + + BOOL IsButtonEnabled(int nID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_ISBUTTONENABLED, nID, 0L); + } + + BOOL IsButtonChecked(int nID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_ISBUTTONCHECKED, nID, 0L); + } + + BOOL IsButtonPressed(int nID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_ISBUTTONPRESSED, nID, 0L); + } + + BOOL IsButtonHidden(int nID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return(BOOL) ::SendMessage(this->m_hWnd, TB_ISBUTTONHIDDEN, nID, 0L); + } + + BOOL IsButtonIndeterminate(int nID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_ISBUTTONINDETERMINATE, nID, 0L); + } + + int GetState(int nID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TB_GETSTATE, nID, 0L); + } + + BOOL SetState(int nID, UINT nState) /* nState is packed into the low word of lParam; the high word is 0 */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETSTATE, nID, MAKELPARAM(nState, 0)); + } + + BOOL GetButton(int nIndex, LPTBBUTTON lpButton) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_GETBUTTON, nIndex, (LPARAM)lpButton); + } + + int GetButtonCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TB_BUTTONCOUNT, 0, 0L); + } + + BOOL GetItemRect(int nIndex, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_GETITEMRECT, nIndex, (LPARAM)lpRect); + } + + void SetButtonStructSize(int nSize = sizeof(TBBUTTON)) /* defaults to the size of TBBUTTON as compiled here */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TB_BUTTONSTRUCTSIZE, nSize, 0L); + } + + BOOL SetButtonSize(SIZE size) + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETBUTTONSIZE, 0, MAKELPARAM(size.cx, size.cy)); + } + + BOOL SetButtonSize(int cx, int cy) /* width goes in the low word of lParam, height in the high word */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETBUTTONSIZE, 0, MAKELPARAM(cx, cy)); + } + + BOOL SetBitmapSize(SIZE size) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETBITMAPSIZE, 0, MAKELPARAM(size.cx, size.cy)); + } + + BOOL SetBitmapSize(int cx, int cy) /* width goes in the low word of lParam, height in the high word */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETBITMAPSIZE, 0, MAKELPARAM(cx, cy)); + } + + CToolTipCtrl GetToolTips() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, TB_GETTOOLTIPS, 0, 0L)); + } + + void SetToolTips(HWND hWndToolTip) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TB_SETTOOLTIPS, (WPARAM)hWndToolTip, 0L); + } + + void SetNotifyWnd(HWND hWnd) /* implemented with TB_SETPARENT */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TB_SETPARENT, (WPARAM)hWnd, 0L); + } + + int GetRows() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TB_GETROWS, 0, 0L); + } + + void SetRows(int nRows, BOOL bLarger, LPRECT lpRect) /* nRows and bLarger are packed into wParam (low/high word), hence MAKEWPARAM; lpRect is passed as lParam */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TB_SETROWS, MAKEWPARAM(nRows, bLarger), (LPARAM)lpRect); + } + + BOOL SetCmdID(int nIndex, UINT nID) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETCMDID, nIndex, nID); + } + + DWORD GetBitmapFlags() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, TB_GETBITMAPFLAGS, 0, 0L); + } + + int GetBitmap(int nID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TB_GETBITMAP, nID, 0L); + } + + int GetButtonText(int nID, LPTSTR lpstrText) const /* no buffer length is passed along with lpstrText */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + 
return (int)::SendMessage(this->m_hWnd, TB_GETBUTTONTEXT, nID, (LPARAM)lpstrText); + } + + // nIndex - IE5 or higher only + CImageList GetImageList(int nIndex = 0) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_GETIMAGELIST, nIndex, 0L)); + } + + // nIndex - IE5 or higher only; the handle returned by the message is wrapped in a CImageList + CImageList SetImageList(HIMAGELIST hImageList, int nIndex = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_SETIMAGELIST, nIndex, (LPARAM)hImageList)); + } + + // nIndex - IE5 or higher only + CImageList GetDisabledImageList(int nIndex = 0) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_GETDISABLEDIMAGELIST, nIndex, 0L)); + } + + // nIndex - IE5 or higher only; the handle returned by the message is wrapped in a CImageList + CImageList SetDisabledImageList(HIMAGELIST hImageList, int nIndex = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_SETDISABLEDIMAGELIST, nIndex, (LPARAM)hImageList)); + } + + // nIndex - IE5 or higher only + CImageList GetHotImageList(int nIndex = 0) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_GETHOTIMAGELIST, nIndex, 0L)); + } + + // nIndex - IE5 or higher only; the handle returned by the message is wrapped in a CImageList + CImageList SetHotImageList(HIMAGELIST hImageList, int nIndex = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_SETHOTIMAGELIST, nIndex, (LPARAM)hImageList)); + } + + DWORD GetStyle() const /* queries the control with TB_GETSTYLE */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, TB_GETSTYLE, 0, 0L); + } + + void SetStyle(DWORD dwStyle) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, TB_SETSTYLE, 0, dwStyle); + } + + DWORD GetButtonSize() const /* returns the packed DWORD as-is; the SIZE overload that follows splits it */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, TB_GETBUTTONSIZE, 0, 0L); + } + + void 
GetButtonSize(SIZE& size) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, TB_GETBUTTONSIZE, 0, 0L); + size.cx = LOWORD(dwRet); /* low word of the packed result */ + size.cy = HIWORD(dwRet); /* high word of the packed result */ + } + + BOOL GetRect(int nID, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_GETRECT, nID, (LPARAM)lpRect); + } + + int GetTextRows() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TB_GETTEXTROWS, 0, 0L); + } + + BOOL SetButtonWidth(int cxMin, int cxMax) /* cxMin goes in the low word of lParam, cxMax in the high word */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETBUTTONWIDTH, 0, MAKELPARAM(cxMin, cxMax)); + } + + BOOL SetIndent(int nIndent) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETINDENT, nIndent, 0L); + } + + BOOL SetMaxTextRows(int nMaxTextRows) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETMAXTEXTROWS, nMaxTextRows, 0L); + } + + BOOL GetAnchorHighlight() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_GETANCHORHIGHLIGHT, 0, 0L); + } + + BOOL SetAnchorHighlight(BOOL bEnable = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETANCHORHIGHLIGHT, bEnable, 0L); + } + + int GetButtonInfo(int nID, LPTBBUTTONINFO lptbbi) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, TB_GETBUTTONINFO, nID, (LPARAM)lptbbi); + } + + BOOL SetButtonInfo(int nID, LPTBBUTTONINFO lptbbi) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, TB_SETBUTTONINFO, nID, (LPARAM)lptbbi); + } + + BOOL SetButtonInfo(int nID, DWORD dwMask, BYTE Style, BYTE State, LPCTSTR lpszItem, + int iImage, WORD cx, int iCommand, DWORD_PTR lParam) /* convenience form: builds the TBBUTTONINFO from the individual arguments */ + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TBBUTTONINFO tbbi = {}; + tbbi.cbSize = sizeof(TBBUTTONINFO); + tbbi.dwMask = 
dwMask;
		tbbi.idCommand = iCommand;
		tbbi.iImage = iImage;
		tbbi.fsState = State;
		tbbi.fsStyle = Style;
		tbbi.cx = cx;
		// lpszItem is LPCTSTR in the signature; the cast drops const so it can be stored in pszText.
		tbbi.pszText = (LPTSTR) lpszItem;
		tbbi.lParam = lParam;
		return (BOOL)::SendMessage(this->m_hWnd, TB_SETBUTTONINFO, nID, (LPARAM)&tbbi);
	}

	// Wraps TB_GETHOTITEM; the message result is returned as int.
	int GetHotItem() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TB_GETHOTITEM, 0, 0L);
	}

	// Wraps TB_SETHOTITEM (wParam = nItem); the message result is returned as int.
	int SetHotItem(int nItem)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TB_SETHOTITEM, nItem, 0L);
	}

	// Wraps TB_ISBUTTONHIGHLIGHTED (wParam = nButtonID).
	BOOL IsButtonHighlighted(int nButtonID) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_ISBUTTONHIGHLIGHTED, nButtonID, 0L);
	}

	// Wraps TB_SETDRAWTEXTFLAGS (wParam = dwMask, lParam = dwFlags); returns the message result as DWORD.
	DWORD SetDrawTextFlags(DWORD dwMask, DWORD dwFlags)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (DWORD)::SendMessage(this->m_hWnd, TB_SETDRAWTEXTFLAGS, dwMask, dwFlags);
	}

	// Wraps TB_GETCOLORSCHEME (lParam = lpcs).
	BOOL GetColorScheme(LPCOLORSCHEME lpcs) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_GETCOLORSCHEME, 0, (LPARAM)lpcs);
	}

	// Wraps TB_SETCOLORSCHEME (lParam = lpcs); the message result is discarded.
	void SetColorScheme(LPCOLORSCHEME lpcs)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TB_SETCOLORSCHEME, 0, (LPARAM)lpcs);
	}

	// Wraps TB_GETEXTENDEDSTYLE; the message result is returned as DWORD.
	DWORD GetExtendedStyle() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (DWORD)::SendMessage(this->m_hWnd, TB_GETEXTENDEDSTYLE, 0, 0L);
	}

	// Wraps TB_SETEXTENDEDSTYLE (wParam = 0, lParam = dwStyle); returns the message result as DWORD.
	DWORD SetExtendedStyle(DWORD dwStyle)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (DWORD)::SendMessage(this->m_hWnd, TB_SETEXTENDEDSTYLE, 0, dwStyle);
	}

	// Wraps TB_GETINSERTMARK (lParam = lptbim); the message result is discarded.
	void GetInsertMark(LPTBINSERTMARK lptbim) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TB_GETINSERTMARK, 0, (LPARAM)lptbim);
	}

	// Wraps TB_SETINSERTMARK (lParam = lptbim); the message result is discarded.
	void SetInsertMark(LPTBINSERTMARK lptbim)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TB_SETINSERTMARK, 0, (LPARAM)lptbim);
	}

	// Wraps TB_GETINSERTMARKCOLOR; the message result is returned as COLORREF.
	COLORREF GetInsertMarkColor() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return
(COLORREF)::SendMessage(this->m_hWnd, TB_GETINSERTMARKCOLOR, 0, 0L);
	}

	// Wraps TB_SETINSERTMARKCOLOR (lParam = clr); the message result is returned as COLORREF.
	COLORREF SetInsertMarkColor(COLORREF clr)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (COLORREF)::SendMessage(this->m_hWnd, TB_SETINSERTMARKCOLOR, 0, (LPARAM)clr);
	}

	// Wraps TB_GETMAXSIZE (lParam = lpSize).
	BOOL GetMaxSize(LPSIZE lpSize) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_GETMAXSIZE, 0, (LPARAM)lpSize);
	}

	// Sends TB_GETPADDING and unpacks the result into *lpSizePadding.
	// lpSizePadding must not be NULL (asserted, then dereferenced unconditionally).
	void GetPadding(LPSIZE lpSizePadding) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(lpSizePadding != NULL);
		DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, TB_GETPADDING, 0, 0L);
		lpSizePadding->cx = GET_X_LPARAM(dwRet);
		lpSizePadding->cy = GET_Y_LPARAM(dwRet);
	}

	// Sends TB_SETPADDING with cx/cy packed into lParam. When lpSizePadding is supplied,
	// the message's return value is unpacked into it; otherwise the result is ignored.
	void SetPadding(int cx, int cy, LPSIZE lpSizePadding = NULL)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, TB_SETPADDING, 0, MAKELPARAM(cx, cy));
		if(lpSizePadding != NULL)
		{
			lpSizePadding->cx = GET_X_LPARAM(dwRet);
			lpSizePadding->cy = GET_Y_LPARAM(dwRet);
		}
	}

	// Wraps TB_GETUNICODEFORMAT; the message result is returned as BOOL.
	BOOL GetUnicodeFormat() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_GETUNICODEFORMAT, 0, 0L);
	}

	// Wraps TB_SETUNICODEFORMAT (wParam = bUnicode, TRUE by default).
	BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_SETUNICODEFORMAT, bUnicode, 0L);
	}

	// Wraps TB_GETSTRING for a caller-supplied buffer. wParam carries cchMaxLen in the
	// low word and nString in the high word.
	int GetString(int nString, LPTSTR lpstrString, int cchMaxLen) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TB_GETSTRING, MAKEWPARAM(cchMaxLen, nString), (LPARAM)lpstrString);
	}

	// Retrieves string nString as a BSTR allocated with SysAllocString.
	// bstrString is asserted to be NULL on entry. Returns the length reported by
	// TB_GETSTRING, or -1 when the query or the temporary allocation fails.
	int GetStringBSTR(int nString, BSTR& bstrString) const
	{
		USES_CONVERSION;
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(bstrString == NULL);
		// First pass: zero-length, NULL buffer. Only the low word of the result is used,
		// sign-extended so that a failure shows up as -1.
		int nLength = (int)(short)LOWORD(::SendMessage(this->m_hWnd, TB_GETSTRING, MAKEWPARAM(0, nString), NULL));
		if(nLength != -1)
		{
			// Temporary buffer with one extra element beyond the reported length.
			ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff;
			LPTSTR lpstrText = buff.Allocate(nLength + 1);
if(lpstrText != NULL)
			{
				// Second pass: fetch the text itself, then convert it into a BSTR on success.
				nLength = (int)::SendMessage(this->m_hWnd, TB_GETSTRING, MAKEWPARAM(nLength + 1, nString), (LPARAM)lpstrText);
				if(nLength != -1)
					bstrString = ::SysAllocString(T2OLE(lpstrText));
			}
			else
			{
				// Temporary buffer allocation failed; report it the same way as a message failure.
				nLength = -1;
			}
		}

		return nLength;
	}

#ifdef __ATLSTR_H__
	// CString overload of GetString: queries the length first, then reads the text into
	// the string's own buffer. Returns the TB_GETSTRING result, or -1 on failure.
	int GetString(int nString, ATL::CString& str) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		// Length query: zero-size buffer, signed low word of the result.
		int nLength = (int)(short)LOWORD(::SendMessage(this->m_hWnd, TB_GETSTRING, MAKEWPARAM(0, nString), NULL));
		if(nLength != -1)
		{
			LPTSTR lpstr = str.GetBufferSetLength(nLength + 1);
			if(lpstr != NULL)
				nLength = (int)::SendMessage(this->m_hWnd, TB_GETSTRING, MAKEWPARAM(nLength + 1, nString), (LPARAM)lpstr);
			else
				nLength = -1;
			// ReleaseBuffer runs on both paths, pairing with GetBufferSetLength above.
			str.ReleaseBuffer();
		}
		return nLength;
	}
#endif // __ATLSTR_H__

	// Wraps TB_GETMETRICS (lParam = lptbm); the message result is discarded.
	void GetMetrics(LPTBMETRICS lptbm) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TB_GETMETRICS, 0, (LPARAM)lptbm);
	}

	// Wraps TB_SETMETRICS (lParam = lptbm); the message result is discarded.
	void SetMetrics(LPTBMETRICS lptbm)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TB_SETMETRICS, 0, (LPARAM)lptbm);
	}

	// Wraps TB_SETWINDOWTHEME; the theme name is always a wide string (LPCWSTR).
	void SetWindowTheme(LPCWSTR lpstrTheme)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TB_SETWINDOWTHEME, 0, (LPARAM)lpstrTheme);
	}

#if (_WIN32_WINNT >= 0x0600)
	// Wraps TB_GETPRESSEDIMAGELIST (wParam = nIndex); the returned handle is wrapped in a CImageList.
	CImageList GetPressedImageList(int nIndex = 0) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_GETPRESSEDIMAGELIST, nIndex, 0L));
	}

	// Wraps TB_SETPRESSEDIMAGELIST (wParam = nIndex, lParam = hImageList); the handle
	// returned by the message is wrapped in a CImageList.
	CImageList SetPressedImageList(HIMAGELIST hImageList, int nIndex = 0)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_SETPRESSEDIMAGELIST, nIndex, (LPARAM)hImageList));
	}

	// Sends TB_GETITEMDROPDOWNRECT (wParam = nIndex, lParam = lpRect).
	// A FALSE result is only asserted on, not reported to the caller.
	void GetItemDropDownRect(int nIndex, LPRECT lpRect) const
	{
		// Fallback definition for SDK headers that do not provide the message constant.
#ifndef TB_GETITEMDROPDOWNRECT
		const int TB_GETITEMDROPDOWNRECT = WM_USER + 103;
#endif
		ATLASSERT(::IsWindow(this->m_hWnd));
		BOOL bRet = (BOOL)::SendMessage(this->m_hWnd,
TB_GETITEMDROPDOWNRECT, nIndex, (LPARAM)lpRect);
		(void)bRet;   // avoid level 4 warning
		ATLASSERT(bRet != FALSE);
	}
#endif // (_WIN32_WINNT >= 0x0600)

// Operations
	// The five state setters below share one shape: wParam = command id,
	// lParam = MAKELPARAM(flag, 0), message result returned as BOOL.

	// Wraps TB_ENABLEBUTTON.
	BOOL EnableButton(int nID, BOOL bEnable = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_ENABLEBUTTON, nID, MAKELPARAM(bEnable, 0));
	}

	// Wraps TB_CHECKBUTTON.
	BOOL CheckButton(int nID, BOOL bCheck = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_CHECKBUTTON, nID, MAKELPARAM(bCheck, 0));
	}

	// Wraps TB_PRESSBUTTON.
	BOOL PressButton(int nID, BOOL bPress = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_PRESSBUTTON, nID, MAKELPARAM(bPress, 0));
	}

	// Wraps TB_HIDEBUTTON.
	BOOL HideButton(int nID, BOOL bHide = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_HIDEBUTTON, nID, MAKELPARAM(bHide, 0));
	}

	// Wraps TB_INDETERMINATE.
	BOOL Indeterminate(int nID, BOOL bIndeterminate = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_INDETERMINATE, nID, MAKELPARAM(bIndeterminate, 0));
	}

	// Sends TB_ADDBITMAP for a bitmap resource: hInst is the module's resource instance
	// (asserted non-NULL) and nID is the resource id.
	int AddBitmap(int nNumButtons, UINT nBitmapID)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		TBADDBITMAP tbab = {};
		tbab.hInst = ModuleHelper::GetResourceInstance();
		ATLASSERT(tbab.hInst != NULL);
		tbab.nID = nBitmapID;
		return (int)::SendMessage(this->m_hWnd, TB_ADDBITMAP, (WPARAM)nNumButtons, (LPARAM)&tbab);
	}

	// Sends TB_ADDBITMAP for an in-memory bitmap: hInst is NULL and nID carries the HBITMAP itself.
	int AddBitmap(int nNumButtons, HBITMAP hBitmap)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		TBADDBITMAP tbab = {};
		tbab.hInst = NULL;
		tbab.nID = (UINT_PTR)hBitmap;
		return (int)::SendMessage(this->m_hWnd, TB_ADDBITMAP, (WPARAM)nNumButtons, (LPARAM)&tbab);
	}

	// Wraps TB_ADDBUTTONS (wParam = nNumButtons, lParam = lpButtons).
	BOOL AddButtons(int nNumButtons, LPCTBBUTTON lpButtons)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_ADDBUTTONS, nNumButtons, (LPARAM)lpButtons);
	}

	// Wraps TB_INSERTBUTTON with a caller-prepared TBBUTTON (wParam = nIndex).
	BOOL InsertButton(int nIndex, LPCTBBUTTON lpButton)
	{
ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_INSERTBUTTON, nIndex, (LPARAM)lpButton);
	}

	// Field-wise overload: fills a zero-initialized local TBBUTTON and sends TB_INSERTBUTTON.
	// iString is stored unchanged in tbb.iString and lParam in tbb.dwData.
	BOOL InsertButton(int nIndex, int iCommand, BYTE Style, BYTE State, int iBitmap,
	                  INT_PTR iString, DWORD_PTR lParam)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		TBBUTTON tbb = {};
		tbb.fsStyle = Style;
		tbb.fsState = State;
		tbb.idCommand = iCommand;
		tbb.iBitmap = iBitmap;
		tbb.iString = iString;
		tbb.dwData = lParam;
		return (BOOL)::SendMessage(this->m_hWnd, TB_INSERTBUTTON, nIndex, (LPARAM)&tbb);
	}

	// String-pointer overload: casts lpszItem to INT_PTR and forwards to the overload above.
	BOOL InsertButton(int nIndex, int iCommand, BYTE Style, BYTE State, int iBitmap,
	                  LPCTSTR lpszItem, DWORD_PTR lParam)
	{
		return InsertButton(nIndex, iCommand, Style, State, iBitmap, (INT_PTR)lpszItem, lParam);
	}

	// The AddButton overloads forward to the matching InsertButton overload with index -1.
	BOOL AddButton(LPTBBUTTON lpButton)
	{
		return InsertButton(-1, lpButton);
	}

	BOOL AddButton(int iCommand, BYTE Style, BYTE State, int iBitmap, INT_PTR iString, DWORD_PTR lParam)
	{
		return InsertButton(-1, iCommand, Style, State, iBitmap, iString, lParam);
	}

	BOOL AddButton(int iCommand, BYTE Style, BYTE State, int iBitmap, LPCTSTR lpszItem, DWORD_PTR lParam)
	{
		return InsertButton(-1, iCommand, Style, State, iBitmap, lpszItem, lParam);
	}

	// Wraps TB_DELETEBUTTON (wParam = nIndex).
	BOOL DeleteButton(int nIndex)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_DELETEBUTTON, nIndex, 0L);
	}

	// Inserts a BTNS_SEP button with command 0, state 0 and no string; cxWidth is passed
	// in the iBitmap position of InsertButton.
	BOOL InsertSeparator(int nIndex, int cxWidth = 8)
	{
		return InsertButton(nIndex, 0, BTNS_SEP, 0, cxWidth, (INT_PTR)0, 0);
	}

	// Same as InsertSeparator, but goes through AddButton (index -1).
	BOOL AddSeparator(int cxWidth = 8)
	{
		return AddButton(0, BTNS_SEP, 0, cxWidth, (INT_PTR)0, 0);
	}

	// Wraps TB_COMMANDTOINDEX (wParam = nID); the message result is returned as int.
	int CommandToIndex(UINT nID) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TB_COMMANDTOINDEX, nID, 0L);
	}

	// Sends TB_SAVERESTORE with wParam = TRUE, describing the registry location in a
	// local TBSAVEPARAMS. The message result is discarded.
	void SaveState(HKEY hKeyRoot, LPCTSTR lpszSubKey, LPCTSTR lpszValueName)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		TBSAVEPARAMS tbs = {};
		tbs.hkr = hKeyRoot;
		tbs.pszSubKey = lpszSubKey;
tbs.pszValueName = lpszValueName;
		::SendMessage(this->m_hWnd, TB_SAVERESTORE, (WPARAM)TRUE, (LPARAM)&tbs);
	}

	// Counterpart of SaveState: same TBSAVEPARAMS setup, but TB_SAVERESTORE is sent
	// with wParam = FALSE. The message result is discarded.
	void RestoreState(HKEY hKeyRoot, LPCTSTR lpszSubKey, LPCTSTR lpszValueName)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		TBSAVEPARAMS tbs = {};
		tbs.hkr = hKeyRoot;
		tbs.pszSubKey = lpszSubKey;
		tbs.pszValueName = lpszValueName;
		::SendMessage(this->m_hWnd, TB_SAVERESTORE, (WPARAM)FALSE, (LPARAM)&tbs);
	}

	// Wraps TB_CUSTOMIZE; the message result is discarded.
	void Customize()
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TB_CUSTOMIZE, 0, 0L);
	}

	// Sends TB_ADDSTRING for a string resource: wParam is the module's resource instance,
	// lParam the resource id.
	int AddString(UINT nStringID)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TB_ADDSTRING, (WPARAM)ModuleHelper::GetResourceInstance(), (LPARAM)nStringID);
	}

	// Sends TB_ADDSTRING with wParam = 0 and lParam pointing at lpszStrings.
	// NOTE(review): the expected layout of lpszStrings is not visible here - see the TB_ADDSTRING docs.
	int AddStrings(LPCTSTR lpszStrings)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TB_ADDSTRING, 0, (LPARAM)lpszStrings);
	}

	// Wraps TB_AUTOSIZE; the message result is discarded.
	void AutoSize()
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TB_AUTOSIZE, 0, 0L);
	}

	// Wraps TB_CHANGEBITMAP (wParam = nID, lParam = MAKELPARAM(nBitmap, 0)).
	BOOL ChangeBitmap(int nID, int nBitmap)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_CHANGEBITMAP, nID, MAKELPARAM(nBitmap, 0));
	}

	// Sends TB_LOADIMAGES using the module's resource instance as lParam.
	int LoadImages(int nBitmapID)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TB_LOADIMAGES, nBitmapID, (LPARAM)ModuleHelper::GetResourceInstance());
	}

	// Sends TB_LOADIMAGES using HINST_COMMCTRL as lParam instead of the module instance.
	int LoadStdImages(int nBitmapID)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TB_LOADIMAGES, nBitmapID, (LPARAM)HINST_COMMCTRL);
	}

	// Wraps TB_REPLACEBITMAP (lParam = ptbrb).
	BOOL ReplaceBitmap(LPTBREPLACEBITMAP ptbrb)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_REPLACEBITMAP, 0, (LPARAM)ptbrb);
	}

	// Wraps TB_HITTEST (lParam = lpPoint); the message result is returned as int.
	int HitTest(LPPOINT lpPoint) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TB_HITTEST, 0, (LPARAM)lpPoint);
	}

	// Wraps TB_INSERTMARKHITTEST (wParam = lpPoint, lParam = lptbim).
	BOOL InsertMarkHitTest(LPPOINT lpPoint, LPTBINSERTMARK lptbim) const
{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_INSERTMARKHITTEST, (WPARAM)lpPoint, (LPARAM)lptbim);
	}

	// Coordinate overload: builds a local POINT from x/y and passes its address as wParam.
	BOOL InsertMarkHitTest(int x, int y, LPTBINSERTMARK lptbim) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		POINT pt = { x, y };
		return (BOOL)::SendMessage(this->m_hWnd, TB_INSERTMARKHITTEST, (WPARAM)&pt, (LPARAM)lptbim);
	}

	// Wraps TB_MAPACCELERATOR (wParam = chAccel); the address of nID is passed as lParam.
	BOOL MapAccelerator(TCHAR chAccel, int& nID) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_MAPACCELERATOR, (WPARAM)chAccel, (LPARAM)&nID);
	}

	// Wraps TB_MARKBUTTON (wParam = nID, lParam = MAKELPARAM(bHighlight, 0)).
	BOOL MarkButton(int nID, BOOL bHighlight = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_MARKBUTTON, nID, MAKELPARAM(bHighlight, 0));
	}

	// Wraps TB_MOVEBUTTON (wParam = nOldPos, lParam = nNewPos).
	BOOL MoveButton(int nOldPos, int nNewPos)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TB_MOVEBUTTON, nOldPos, nNewPos);
	}

	// Wraps TB_GETOBJECT; the address of iid goes in wParam, ppvObject in lParam, and the
	// message result is returned as an HRESULT.
	HRESULT GetObject(REFIID iid, LPVOID* ppvObject)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (HRESULT)::SendMessage(this->m_hWnd, TB_GETOBJECT, (WPARAM)&iid, (LPARAM)ppvObject);
	}
};

typedef CToolBarCtrlT<ATL::CWindow> CToolBarCtrl;


///////////////////////////////////////////////////////////////////////////////
// CStatusBarCtrl

// Wrapper over a status bar window handle. Every method asserts the handle is a live
// window and forwards to one SB_* message; TBase supplies m_hWnd and Create.
template <class TBase>
class CStatusBarCtrlT : public TBase
{
public:
// Constructors
	CStatusBarCtrlT(HWND hWnd = NULL) : TBase(hWnd)
	{ }

	// Rebinds the wrapper to another window handle; only m_hWnd is overwritten.
	CStatusBarCtrlT< TBase >& operator =(HWND hWnd)
	{
		this->m_hWnd = hWnd;
		return *this;
	}

	// Forwards to TBase::Create using this control's window class name.
	HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
			DWORD dwStyle = 0, DWORD dwExStyle = 0,
			ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
	{
		return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
	}

// Methods
	// Window class name used by Create.
	static LPCTSTR GetWndClassName()
	{
		return STATUSCLASSNAME;
	}

	// Wraps SB_GETPARTS (wParam = nParts, lParam = pParts); the message result is returned as int.
	int GetParts(int nParts, int*
pParts) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, SB_GETPARTS, nParts, (LPARAM)pParts);
	}

	// Wraps SB_SETPARTS (wParam = nParts, lParam = pWidths).
	BOOL SetParts(int nParts, int* pWidths)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, SB_SETPARTS, nParts, (LPARAM)pWidths);
	}

	// Sends SB_GETTEXTLENGTH for nPane (asserted < 256). Returns the sign-extended low word
	// of the result; the sign-extended high word is stored in *pType when pType is given.
	int GetTextLength(int nPane, int* pType = NULL) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(nPane < 256);
		DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, SB_GETTEXTLENGTH, (WPARAM)nPane, 0L);
		if (pType != NULL)
			*pType = (int)(short)HIWORD(dwRet);
		return (int)(short)LOWORD(dwRet);
	}

	// Sends SB_GETTEXT with the caller's buffer as lParam. No buffer size is passed, so the
	// caller must size lpszText itself (e.g. from GetTextLength). The result is split the
	// same way as in GetTextLength.
	int GetText(int nPane, LPTSTR lpszText, int* pType = NULL) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(nPane < 256);
		DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, SB_GETTEXT, (WPARAM)nPane, (LPARAM)lpszText);
		if(pType != NULL)
			*pType = (int)(short)HIWORD(dwRet);
		return (int)(short)LOWORD(dwRet);
	}

	// Retrieves the pane text as a BSTR allocated with SysAllocString.
	// bstrText is asserted to be NULL on entry. Returns FALSE when the reported length is 0,
	// when the temporary buffer cannot be allocated, when GetText returns 0, or when the
	// BSTR allocation fails.
	BOOL GetTextBSTR(int nPane, BSTR& bstrText, int* pType = NULL) const
	{
		USES_CONVERSION;
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(nPane < 256);
		ATLASSERT(bstrText == NULL);
		// Length query; only the signed low word of the result is used.
		int nLength = (int)(short)LOWORD(::SendMessage(this->m_hWnd, SB_GETTEXTLENGTH, (WPARAM)nPane, 0L));
		if(nLength == 0)
			return FALSE;

		// Temporary buffer with one extra element beyond the reported length.
		ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff;
		LPTSTR lpstrText = buff.Allocate(nLength + 1);
		if(lpstrText == NULL)
			return FALSE;

		if(!GetText(nPane, lpstrText, pType))
			return FALSE;

		bstrText = ::SysAllocString(T2OLE(lpstrText));
		return (bstrText != NULL) ?
TRUE : FALSE;
	}

#ifdef __ATLSTR_H__
	// CString overload of GetText: sizes strText from SB_GETTEXTLENGTH and reads the text
	// straight into its buffer. Returns 0 without touching strText when the reported length
	// is 0, and 0 when the buffer cannot be obtained.
	int GetText(int nPane, ATL::CString& strText, int* pType = NULL) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(nPane < 256);
		int nLength = (int)(short)LOWORD(::SendMessage(this->m_hWnd, SB_GETTEXTLENGTH, (WPARAM)nPane, 0L));
		if(nLength == 0)
			return 0;

		LPTSTR lpstr = strText.GetBufferSetLength(nLength);
		if(lpstr == NULL)
			return 0;
		return GetText(nPane, lpstr, pType);
	}
#endif // __ATLSTR_H__

	// Sends SB_SETTEXT; the pane index and nType are OR-ed together into wParam.
	BOOL SetText(int nPane, LPCTSTR lpszText, int nType = 0)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(nPane < 256);
		return (BOOL)::SendMessage(this->m_hWnd, SB_SETTEXT, (nPane | nType), (LPARAM)lpszText);
	}

	// Wraps SB_GETRECT (wParam = nPane, lParam = lpRect).
	BOOL GetRect(int nPane, LPRECT lpRect) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(nPane < 256);
		return (BOOL)::SendMessage(this->m_hWnd, SB_GETRECT, nPane, (LPARAM)lpRect);
	}

	// Wraps SB_GETBORDERS with a caller-supplied array.
	// NOTE(review): the overload below uses a 3-int array, so pBorders presumably needs
	// room for three ints as well - confirm against the SB_GETBORDERS docs.
	BOOL GetBorders(int* pBorders) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, SB_GETBORDERS, 0, (LPARAM)pBorders);
	}

	// Sends SB_GETBORDERS into a local 3-int array and, only on success, copies the
	// elements to nHorz, nVert and nSpacing in that order. The outputs are left untouched
	// on failure.
	BOOL GetBorders(int& nHorz, int& nVert, int& nSpacing) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		int borders[3] = {};
		BOOL bResult = (BOOL)::SendMessage(this->m_hWnd, SB_GETBORDERS, 0, (LPARAM)&borders);
		if(bResult)
		{
			nHorz = borders[0];
			nVert = borders[1];
			nSpacing = borders[2];
		}
		return bResult;
	}

	// Wraps SB_SETMINHEIGHT (wParam = nMin); the message result is discarded.
	void SetMinHeight(int nMin)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, SB_SETMINHEIGHT, nMin, 0L);
	}

	// Wraps SB_SIMPLE (wParam = bSimple, TRUE by default).
	BOOL SetSimple(BOOL bSimple = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, SB_SIMPLE, bSimple, 0L);
	}

	// Wraps SB_ISSIMPLE; the message result is returned as BOOL.
	BOOL IsSimple() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, SB_ISSIMPLE, 0, 0L);
	}

	// Wraps SB_GETUNICODEFORMAT; the message result is returned as BOOL.
	BOOL GetUnicodeFormat() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, SB_GETUNICODEFORMAT, 0, 0L);
	}

	// Wraps SB_SETUNICODEFORMAT (wParam = bUnicode, TRUE by default).
	BOOL
SetUnicodeFormat(BOOL bUnicode = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, SB_SETUNICODEFORMAT, bUnicode, 0L);
	}

	// Sends SB_GETTIPTEXT; wParam packs nPane (low word) and the buffer size nSize (high word).
	void GetTipText(int nPane, LPTSTR lpstrText, int nSize) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(nPane < 256);
		::SendMessage(this->m_hWnd, SB_GETTIPTEXT, MAKEWPARAM(nPane, nSize), (LPARAM)lpstrText);
	}

	// Wraps SB_SETTIPTEXT (wParam = nPane, lParam = lpstrText); the message result is discarded.
	void SetTipText(int nPane, LPCTSTR lpstrText)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(nPane < 256);
		::SendMessage(this->m_hWnd, SB_SETTIPTEXT, nPane, (LPARAM)lpstrText);
	}

	// Wraps SB_SETBKCOLOR (lParam = clrBk); the message result is returned as COLORREF.
	COLORREF SetBkColor(COLORREF clrBk)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (COLORREF)::SendMessage(this->m_hWnd, SB_SETBKCOLOR, 0, (LPARAM)clrBk);
	}

	// Wraps SB_GETICON (wParam = nPane); the message result is returned as HICON.
	HICON GetIcon(int nPane) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(nPane < 256);
		return (HICON)::SendMessage(this->m_hWnd, SB_GETICON, nPane, 0L);
	}

	// Wraps SB_SETICON (wParam = nPane, lParam = hIcon).
	BOOL SetIcon(int nPane, HICON hIcon)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(nPane < 256);
		return (BOOL)::SendMessage(this->m_hWnd, SB_SETICON, nPane, (LPARAM)hIcon);
	}
};

typedef CStatusBarCtrlT<ATL::CWindow> CStatusBarCtrl;


///////////////////////////////////////////////////////////////////////////////
// CTabCtrl

// Wrapper over a tab control window handle. Every method asserts the handle is a live
// window and forwards to one TCM_* message; TBase supplies m_hWnd and Create.
template <class TBase>
class CTabCtrlT : public TBase
{
public:
// Constructors
	CTabCtrlT(HWND hWnd = NULL) : TBase(hWnd)
	{ }

	// Rebinds the wrapper to another window handle; only m_hWnd is overwritten.
	CTabCtrlT< TBase >& operator =(HWND hWnd)
	{
		this->m_hWnd = hWnd;
		return *this;
	}

	// Forwards to TBase::Create using this control's window class name.
	HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
			DWORD dwStyle = 0, DWORD dwExStyle = 0,
			ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
	{
		return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
	}

// Attributes
	// Window class name used by Create.
	static LPCTSTR GetWndClassName()
	{
		return WC_TABCONTROL;
	}

	// Wraps TCM_GETIMAGELIST; the returned handle is wrapped in a CImageList.
	CImageList GetImageList() const
	{
ATLASSERT(::IsWindow(this->m_hWnd));
		return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TCM_GETIMAGELIST, 0, 0L));
	}

	// Wraps TCM_SETIMAGELIST (lParam = hImageList); the handle returned by the message is
	// wrapped in a CImageList.
	CImageList SetImageList(HIMAGELIST hImageList)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TCM_SETIMAGELIST, 0, (LPARAM)hImageList));
	}

	// Wraps TCM_GETITEMCOUNT; the message result is returned as int.
	int GetItemCount() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TCM_GETITEMCOUNT, 0, 0L);
	}

	// Wraps TCM_GETITEM (wParam = nItem, lParam = pTabCtrlItem).
	BOOL GetItem(int nItem, LPTCITEM pTabCtrlItem) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TCM_GETITEM, nItem, (LPARAM)pTabCtrlItem);
	}

	// Wraps TCM_SETITEM with a caller-prepared TCITEM.
	BOOL SetItem(int nItem, LPTCITEM pTabCtrlItem)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TCM_SETITEM, nItem, (LPARAM)pTabCtrlItem);
	}

	// Field-wise overload: fills a zero-initialized local TCITEM and sends TCM_SETITEM.
	// Unlike the pointer overload above, this one is declared to return int, not BOOL.
	int SetItem(int nItem, UINT mask, LPCTSTR lpszItem, DWORD dwState, DWORD dwStateMask, int iImage, LPARAM lParam)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		TCITEM tci = {};
		tci.mask = mask;
		// The cast drops const from lpszItem so it can be stored in pszText.
		tci.pszText = (LPTSTR) lpszItem;
		tci.dwState = dwState;
		tci.dwStateMask = dwStateMask;
		tci.iImage = iImage;
		tci.lParam = lParam;
		return (int)::SendMessage(this->m_hWnd, TCM_SETITEM, nItem, (LPARAM)&tci);
	}

	// Wraps TCM_GETITEMRECT (wParam = nItem, lParam = lpRect).
	BOOL GetItemRect(int nItem, LPRECT lpRect) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TCM_GETITEMRECT, nItem, (LPARAM)lpRect);
	}

	// Wraps TCM_GETCURSEL; the message result is returned as int.
	int GetCurSel() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TCM_GETCURSEL, 0, 0L);
	}

	// Wraps TCM_SETCURSEL (wParam = nItem); the message result is returned as int.
	int SetCurSel(int nItem)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TCM_SETCURSEL, nItem, 0L);
	}

	// Sends TCM_SETITEMSIZE with size.cx/size.cy packed into lParam and returns the
	// message's packed result unpacked into a SIZE.
	SIZE SetItemSize(SIZE size)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		DWORD dwSize = (DWORD)::SendMessage(this->m_hWnd, TCM_SETITEMSIZE, 0, MAKELPARAM(size.cx, size.cy));
		SIZE sizeRet = { GET_X_LPARAM(dwSize), GET_Y_LPARAM(dwSize) };
		return sizeRet;
	}

	// Same TCM_SETITEMSIZE message as the SIZE overload, but the returned value is discarded.
	void SetItemSize(int cx, int cy)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TCM_SETITEMSIZE, 0, MAKELPARAM(cx, cy));
	}

	// Wraps TCM_SETPADDING; size.cx/size.cy are packed into lParam.
	void SetPadding(SIZE size)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TCM_SETPADDING, 0, MAKELPARAM(size.cx, size.cy));
	}

	// Wraps TCM_GETROWCOUNT; the message result is returned as int.
	int GetRowCount() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TCM_GETROWCOUNT, 0, 0L);
	}

	// Wraps TCM_GETTOOLTIPS; the returned HWND is wrapped in a CToolTipCtrl.
	CToolTipCtrl GetToolTips() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, TCM_GETTOOLTIPS, 0, 0L));
	}

	// this method is deprecated, please use GetToolTips
	CToolTipCtrl GetTooltips() const { return GetToolTips(); }

	// Wraps TCM_SETTOOLTIPS (wParam = hWndToolTip); the message result is discarded.
	void SetToolTips(HWND hWndToolTip)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TCM_SETTOOLTIPS, (WPARAM)hWndToolTip, 0L);
	}

	// this method is deprecated, please use SetToolTips
	void SetTooltips(HWND hWndToolTip) { SetToolTips(hWndToolTip); }

	// Wraps TCM_GETCURFOCUS; the message result is returned as int.
	int GetCurFocus() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TCM_GETCURFOCUS, 0, 0L);
	}

	// Wraps TCM_SETCURFOCUS (wParam = nItem); the message result is discarded.
	void SetCurFocus(int nItem)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TCM_SETCURFOCUS, nItem, 0L);
	}

	// Wraps TCM_SETITEMEXTRA (wParam = cbExtra). Debug builds assert that the control
	// currently has no items.
	BOOL SetItemExtra(int cbExtra)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT(GetItemCount() == 0);   // must be empty
		return (BOOL)::SendMessage(this->m_hWnd, TCM_SETITEMEXTRA, cbExtra, 0L);
	}

	// Wraps TCM_SETMINTABWIDTH (lParam = nWidth, -1 by default); returns the message result as int.
	int SetMinTabWidth(int nWidth = -1)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TCM_SETMINTABWIDTH, 0, nWidth);
	}

	// Wraps TCM_GETEXTENDEDSTYLE; the message result is returned as DWORD.
	DWORD GetExtendedStyle() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (DWORD)::SendMessage(this->m_hWnd, TCM_GETEXTENDEDSTYLE, 0, 0L);
	}

	// Wraps TCM_SETEXTENDEDSTYLE (wParam = dwExMask, lParam = dwExStyle).
	DWORD SetExtendedStyle(DWORD dwExMask, DWORD dwExStyle)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (DWORD)::SendMessage(this->m_hWnd, TCM_SETEXTENDEDSTYLE,
dwExMask, dwExStyle);
	}

	// Wraps TCM_GETUNICODEFORMAT; the message result is returned as BOOL.
	BOOL GetUnicodeFormat() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TCM_GETUNICODEFORMAT, 0, 0L);
	}

	// Wraps TCM_SETUNICODEFORMAT (wParam = bUnicode, TRUE by default).
	BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TCM_SETUNICODEFORMAT, bUnicode, 0L);
	}

// Operations
	// Wraps TCM_INSERTITEM with a caller-prepared TCITEM (wParam = nItem).
	int InsertItem(int nItem, LPTCITEM pTabCtrlItem)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TCM_INSERTITEM, nItem, (LPARAM)pTabCtrlItem);
	}

	// Field-wise overload: fills a zero-initialized local TCITEM (mask, text, image, lParam)
	// and sends TCM_INSERTITEM.
	int InsertItem(int nItem, UINT mask, LPCTSTR lpszItem, int iImage, LPARAM lParam)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		TCITEM tci = {};
		tci.mask = mask;
		tci.pszText = (LPTSTR) lpszItem;
		tci.iImage = iImage;
		tci.lParam = lParam;
		return (int)::SendMessage(this->m_hWnd, TCM_INSERTITEM, nItem, (LPARAM)&tci);
	}

	// Text-only overload: the mask is fixed to TCIF_TEXT and only pszText is set.
	int InsertItem(int nItem, LPCTSTR lpszItem)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		TCITEM tci = {};
		tci.mask = TCIF_TEXT;
		tci.pszText = (LPTSTR) lpszItem;
		return (int)::SendMessage(this->m_hWnd, TCM_INSERTITEM, nItem, (LPARAM)&tci);
	}

	// The AddItem overloads forward to the matching InsertItem overload, using the current
	// item count as the insertion index.
	int AddItem(LPTCITEM pTabCtrlItem)
	{
		return InsertItem(GetItemCount(), pTabCtrlItem);
	}

	int AddItem(UINT mask, LPCTSTR lpszItem, int iImage, LPARAM lParam)
	{
		return InsertItem(GetItemCount(), mask, lpszItem, iImage, lParam);
	}

	int AddItem(LPCTSTR lpszItem)
	{
		return InsertItem(GetItemCount(), lpszItem);
	}

	// Wraps TCM_DELETEITEM (wParam = nItem).
	BOOL DeleteItem(int nItem)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TCM_DELETEITEM, nItem, 0L);
	}

	// Wraps TCM_DELETEALLITEMS; the message result is returned as BOOL.
	BOOL DeleteAllItems()
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TCM_DELETEALLITEMS, 0, 0L);
	}

	// Wraps TCM_ADJUSTRECT (wParam = bLarger, lParam = lpRect); the message result is discarded.
	void AdjustRect(BOOL bLarger, LPRECT lpRect)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TCM_ADJUSTRECT, bLarger, (LPARAM)lpRect);
	}

	// Wraps TCM_REMOVEIMAGE (wParam = nImage); the message result is discarded.
	void RemoveImage(int nImage)
	{
ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TCM_REMOVEIMAGE, nImage, 0L);
	}

	// Wraps TCM_HITTEST (lParam = pHitTestInfo); the message result is returned as int.
	int HitTest(TC_HITTESTINFO* pHitTestInfo) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TCM_HITTEST, 0, (LPARAM)pHitTestInfo);
	}

	// Wraps TCM_DESELECTALL (wParam = bExcludeFocus, TRUE by default); the message result is discarded.
	void DeselectAll(BOOL bExcludeFocus = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TCM_DESELECTALL, bExcludeFocus, 0L);
	}

	// Wraps TCM_HIGHLIGHTITEM (wParam = nIndex, lParam = MAKELPARAM(bHighlight, 0)).
	BOOL HighlightItem(int nIndex, BOOL bHighlight = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TCM_HIGHLIGHTITEM, nIndex, MAKELPARAM(bHighlight, 0));
	}
};

typedef CTabCtrlT<ATL::CWindow> CTabCtrl;


///////////////////////////////////////////////////////////////////////////////
// CTrackBarCtrl

// Wrapper over a trackbar window handle. Every method asserts the handle is a live
// window and forwards to one TBM_* message; TBase supplies m_hWnd and Create.
template <class TBase>
class CTrackBarCtrlT : public TBase
{
public:
// Constructors
	CTrackBarCtrlT(HWND hWnd = NULL) : TBase(hWnd)
	{ }

	// Rebinds the wrapper to another window handle; only m_hWnd is overwritten.
	CTrackBarCtrlT< TBase >& operator =(HWND hWnd)
	{
		this->m_hWnd = hWnd;
		return *this;
	}

	// Forwards to TBase::Create using this control's window class name.
	HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
			DWORD dwStyle = 0, DWORD dwExStyle = 0,
			ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
	{
		return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
	}

// Attributes
	// Window class name used by Create.
	static LPCTSTR GetWndClassName()
	{
		return TRACKBAR_CLASS;
	}

	// Wraps TBM_GETLINESIZE; the message result is returned as int.
	int GetLineSize() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_GETLINESIZE, 0, 0L);
	}

	// Wraps TBM_SETLINESIZE (lParam = nSize); the message result is returned as int.
	int SetLineSize(int nSize)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_SETLINESIZE, 0, nSize);
	}

	// Wraps TBM_GETPAGESIZE; the message result is returned as int.
	int GetPageSize() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_GETPAGESIZE, 0, 0L);
	}

	// Wraps TBM_SETPAGESIZE (lParam = nSize); the message result is returned as int.
	int SetPageSize(int nSize)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return
(int)::SendMessage(this->m_hWnd, TBM_SETPAGESIZE, 0, nSize);
	}

	// Wraps TBM_GETRANGEMIN; the message result is returned as int.
	int GetRangeMin() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_GETRANGEMIN, 0, 0L);
	}

	// Wraps TBM_SETRANGEMIN (wParam = bRedraw, lParam = nMin). bRedraw defaults to FALSE here.
	void SetRangeMin(int nMin, BOOL bRedraw = FALSE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_SETRANGEMIN, bRedraw, nMin);
	}

	// Wraps TBM_GETRANGEMAX; the message result is returned as int.
	int GetRangeMax() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_GETRANGEMAX, 0, 0L);
	}

	// Wraps TBM_SETRANGEMAX (wParam = bRedraw, lParam = nMax). bRedraw defaults to FALSE here.
	void SetRangeMax(int nMax, BOOL bRedraw = FALSE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_SETRANGEMAX, bRedraw, nMax);
	}

	// Reads both range ends through GetRangeMin and GetRangeMax (two separate messages).
	void GetRange(int& nMin, int& nMax) const
	{
		nMin = GetRangeMin();
		nMax = GetRangeMax();
	}

	// Sends a single TBM_SETRANGE with nMin/nMax packed into lParam via MAKELPARAM.
	// Unlike SetRangeMin/SetRangeMax, bRedraw defaults to TRUE.
	void SetRange(int nMin, int nMax, BOOL bRedraw = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_SETRANGE, bRedraw, MAKELPARAM(nMin, nMax));
	}

	// Wraps TBM_GETSELSTART; the message result is returned as int.
	int GetSelStart() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_GETSELSTART, 0, 0L);
	}

	// Wraps TBM_SETSELSTART (wParam = bRedraw, lParam = nMin).
	void SetSelStart(int nMin, BOOL bRedraw = FALSE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_SETSELSTART, bRedraw, (LPARAM)nMin);
	}

	// Wraps TBM_GETSELEND; the message result is returned as int.
	int GetSelEnd() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_GETSELEND, 0, 0L);
	}

	// Wraps TBM_SETSELEND (wParam = bRedraw, lParam = nMax).
	void SetSelEnd(int nMax, BOOL bRedraw = FALSE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_SETSELEND, bRedraw, (LPARAM)nMax);
	}

	// Reads both selection ends through GetSelStart and GetSelEnd.
	void GetSelection(int& nMin, int& nMax) const
	{
		nMin = GetSelStart();
		nMax = GetSelEnd();
	}

	// Sets the start without redraw, then the end with the caller's bRedraw, so at most
	// one redraw is requested for the pair.
	void SetSelection(int nMin, int nMax, BOOL bRedraw = TRUE)
	{
		SetSelStart(nMin, FALSE);
		SetSelEnd(nMax, bRedraw);
	}

	// Wraps TBM_GETCHANNELRECT (lParam = lprc); the message result is discarded.
	void GetChannelRect(LPRECT lprc) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_GETCHANNELRECT, 0, (LPARAM)lprc);
	}

	// Wraps TBM_GETTHUMBRECT (lParam = lprc); the message result is discarded.
	void GetThumbRect(LPRECT
lprc) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_GETTHUMBRECT, 0, (LPARAM)lprc);
	}

	// Wraps TBM_GETPOS; the message result is returned as int.
	int GetPos() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_GETPOS, 0, 0L);
	}

	// Sends TBM_SETPOS with wParam fixed to TRUE and lParam = nPos.
	void SetPos(int nPos)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_SETPOS, TRUE, nPos);
	}

	// Wraps TBM_GETNUMTICS; the message result is returned as UINT.
	UINT GetNumTics() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (UINT)::SendMessage(this->m_hWnd, TBM_GETNUMTICS, 0, 0L);
	}

	// Wraps TBM_GETPTICS; the message result is returned as a DWORD pointer.
	// NOTE(review): ownership/lifetime of the returned array is not visible here -
	// presumably it belongs to the control; confirm before storing the pointer.
	DWORD* GetTicArray() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (DWORD*)::SendMessage(this->m_hWnd, TBM_GETPTICS, 0, 0L);
	}

	// Wraps TBM_GETTIC (wParam = nTic); the message result is returned as int.
	int GetTic(int nTic) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_GETTIC, nTic, 0L);
	}

	// Wraps TBM_SETTIC; note nTic travels in lParam here (wParam is 0).
	BOOL SetTic(int nTic)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TBM_SETTIC, 0, nTic);
	}

	// Wraps TBM_GETTICPOS (wParam = nTic); the message result is returned as int.
	int GetTicPos(int nTic) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_GETTICPOS, nTic, 0L);
	}

	// Wraps TBM_SETTICFREQ (wParam = nFreq); the message result is discarded.
	void SetTicFreq(int nFreq)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_SETTICFREQ, nFreq, 0L);
	}

	// Wraps TBM_GETTHUMBLENGTH; the message result is returned as int.
	int GetThumbLength() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_GETTHUMBLENGTH, 0, 0L);
	}

	// Wraps TBM_SETTHUMBLENGTH (wParam = nLength); the message result is discarded.
	void SetThumbLength(int nLength)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_SETTHUMBLENGTH, nLength, 0L);
	}

	// Sends TBM_SETSEL with nStart/nEnd packed into lParam via MAKELPARAM. Debug builds
	// assert that the window has the TBS_ENABLESELRANGE style.
	void SetSel(int nStart, int nEnd, BOOL bRedraw = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		ATLASSERT((this->GetStyle() & TBS_ENABLESELRANGE) != 0);
		::SendMessage(this->m_hWnd, TBM_SETSEL, bRedraw, MAKELPARAM(nStart, nEnd));
	}

	// Wraps TBM_GETBUDDY (wParam = bLeft); the returned HWND is wrapped in an ATL::CWindow.
	ATL::CWindow GetBuddy(BOOL bLeft = TRUE) const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return ATL::CWindow((HWND)::SendMessage(this->m_hWnd, TBM_GETBUDDY, bLeft, 0L));
	}

	// Wraps TBM_SETBUDDY (wParam = bLeft, lParam = hWndBuddy); the HWND returned by the
	// message is wrapped in an ATL::CWindow.
	ATL::CWindow SetBuddy(HWND hWndBuddy,
BOOL bLeft = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return ATL::CWindow((HWND)::SendMessage(this->m_hWnd, TBM_SETBUDDY, bLeft, (LPARAM)hWndBuddy));
	}

	// Wraps TBM_GETTOOLTIPS; the returned HWND is wrapped in a CToolTipCtrl.
	CToolTipCtrl GetToolTips() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, TBM_GETTOOLTIPS, 0, 0L));
	}

	// Wraps TBM_SETTOOLTIPS (wParam = hWndTT); the message result is discarded.
	void SetToolTips(HWND hWndTT)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_SETTOOLTIPS, (WPARAM)hWndTT, 0L);
	}

	// Wraps TBM_SETTIPSIDE (wParam = nSide); the message result is returned as int.
	int SetTipSide(int nSide)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (int)::SendMessage(this->m_hWnd, TBM_SETTIPSIDE, nSide, 0L);
	}

	// Wraps TBM_GETUNICODEFORMAT; the message result is returned as BOOL.
	BOOL GetUnicodeFormat() const
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TBM_GETUNICODEFORMAT, 0, 0L);
	}

	// Wraps TBM_SETUNICODEFORMAT (wParam = bUnicode, TRUE by default).
	BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		return (BOOL)::SendMessage(this->m_hWnd, TBM_SETUNICODEFORMAT, bUnicode, 0L);
	}

// Operations
	// Wraps TBM_CLEARSEL (wParam = bRedraw); the message result is discarded.
	void ClearSel(BOOL bRedraw = FALSE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_CLEARSEL, bRedraw, 0L);
	}

	// Sends TBM_SETPOS with wParam = FALSE and lParam = 0.
	// NOTE(review): the intended effect is not visible from this code - presumably the
	// control re-validates its position without moving the thumb; confirm against the
	// TBM_SETPOS documentation.
	void VerifyPos()
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_SETPOS, FALSE, 0L);
	}

	// Wraps TBM_CLEARTICS (wParam = bRedraw); the message result is discarded.
	void ClearTics(BOOL bRedraw = FALSE)
	{
		ATLASSERT(::IsWindow(this->m_hWnd));
		::SendMessage(this->m_hWnd, TBM_CLEARTICS, bRedraw, 0L);
	}
};

typedef CTrackBarCtrlT<ATL::CWindow> CTrackBarCtrl;


///////////////////////////////////////////////////////////////////////////////
// CUpDownCtrl

// Wrapper over an up-down window handle. Every method asserts the handle is a live
// window and forwards to one UDM_* message; TBase supplies m_hWnd and Create.
template <class TBase>
class CUpDownCtrlT : public TBase
{
public:
// Constructors
	CUpDownCtrlT(HWND hWnd = NULL) : TBase(hWnd)
	{ }

	// Rebinds the wrapper to another window handle; only m_hWnd is overwritten.
	CUpDownCtrlT< TBase >& operator =(HWND hWnd)
	{
		this->m_hWnd = hWnd;
		return *this;
	}

	// Forwards to TBase::Create using this control's window class name.
	HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
			DWORD dwStyle = 0, DWORD dwExStyle = 0,
			ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
	{
		return
TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return UPDOWN_CLASS; + } + + UINT GetAccel(int nAccel, UDACCEL* pAccel) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)LOWORD(::SendMessage(this->m_hWnd, UDM_GETACCEL, nAccel, (LPARAM)pAccel)); + } + + BOOL SetAccel(int nAccel, UDACCEL* pAccel) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)LOWORD(::SendMessage(this->m_hWnd, UDM_SETACCEL, nAccel, (LPARAM)pAccel)); + } + + UINT GetBase() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)LOWORD(::SendMessage(this->m_hWnd, UDM_GETBASE, 0, 0L)); + } + + int SetBase(int nBase) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, UDM_SETBASE, nBase, 0L); + } + + ATL::CWindow GetBuddy() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ATL::CWindow((HWND)::SendMessage(this->m_hWnd, UDM_GETBUDDY, 0, 0L)); + } + + ATL::CWindow SetBuddy(HWND hWndBuddy) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ATL::CWindow((HWND)::SendMessage(this->m_hWnd, UDM_SETBUDDY, (WPARAM)hWndBuddy, 0L)); + } + + int GetPos(LPBOOL lpbError = NULL) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, UDM_GETPOS, 0, 0L); + // Note: Seems that Windows always sets error to TRUE if + // UDS_SETBUDDYINT style is not used + if(lpbError != NULL) + *lpbError = (HIWORD(dwRet) != 0) ? 
TRUE : FALSE; + return (int)(short)LOWORD(dwRet); + } + + int SetPos(int nPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)(short)LOWORD(::SendMessage(this->m_hWnd, UDM_SETPOS, 0, MAKELPARAM(nPos, 0))); + } + + DWORD GetRange() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, UDM_GETRANGE, 0, 0L); + } + + void GetRange(int& nLower, int& nUpper) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, UDM_GETRANGE, 0, 0L); + nLower = (int)(short)HIWORD(dwRet); + nUpper = (int)(short)LOWORD(dwRet); + } + + void SetRange(int nLower, int nUpper) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, UDM_SETRANGE, 0, MAKELPARAM(nUpper, nLower)); + } + + void SetRange32(int nLower, int nUpper) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, UDM_SETRANGE32, nLower, nUpper); + } + + void GetRange32(int& nLower, int& nUpper) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, UDM_GETRANGE32, (WPARAM)&nLower, (LPARAM)&nUpper); + } + + BOOL GetUnicodeFormat() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, UDM_GETUNICODEFORMAT, 0, 0L); + } + + BOOL SetUnicodeFormat(BOOL bUnicode = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, UDM_SETUNICODEFORMAT, bUnicode, 0L); + } + + int GetPos32(LPBOOL lpbError = NULL) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + // Note: Seems that Windows always sets error to TRUE if + // UDS_SETBUDDYINT style is not used + return (int)::SendMessage(this->m_hWnd, UDM_GETPOS32, 0, (LPARAM)lpbError); + } + + int SetPos32(int nPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, UDM_SETPOS32, 0, (LPARAM)nPos); + } +}; + +typedef CUpDownCtrlT<ATL::CWindow> CUpDownCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// 
CProgressBarCtrl + +template <class TBase> +class CProgressBarCtrlT : public TBase +{ +public: +// Constructors + CProgressBarCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CProgressBarCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return PROGRESS_CLASS; + } + + DWORD SetRange(int nLower, int nUpper) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, PBM_SETRANGE, 0, MAKELPARAM(nLower, nUpper)); + } + + int SetPos(int nPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)(short)LOWORD(::SendMessage(this->m_hWnd, PBM_SETPOS, nPos, 0L)); + } + + int OffsetPos(int nPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)(short)LOWORD(::SendMessage(this->m_hWnd, PBM_DELTAPOS, nPos, 0L)); + } + + int SetStep(int nStep) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)(short)LOWORD(::SendMessage(this->m_hWnd, PBM_SETSTEP, nStep, 0L)); + } + + UINT GetPos() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, PBM_GETPOS, 0, 0L); + } + + void GetRange(PPBRANGE pPBRange) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(pPBRange != NULL); + ::SendMessage(this->m_hWnd, PBM_GETRANGE, TRUE, (LPARAM)pPBRange); + } + + void GetRange(int& nLower, int& nUpper) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + PBRANGE range = {}; + ::SendMessage(this->m_hWnd, PBM_GETRANGE, TRUE, (LPARAM)&range); + nLower = range.iLow; + nUpper = range.iHigh; + } + + int GetRangeLimit(BOOL bLowLimit) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return 
(int)::SendMessage(this->m_hWnd, PBM_GETRANGE, bLowLimit, (LPARAM)NULL); + } + + DWORD SetRange32(int nMin, int nMax) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, PBM_SETRANGE32, nMin, nMax); + } + + COLORREF SetBarColor(COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, PBM_SETBARCOLOR, 0, (LPARAM)clr); + } + + COLORREF SetBkColor(COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, PBM_SETBKCOLOR, 0, (LPARAM)clr); + } + +#ifdef PBM_SETMARQUEE + BOOL SetMarquee(BOOL bMarquee, UINT uUpdateTime = 0U) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, PBM_SETMARQUEE, (WPARAM)bMarquee, (LPARAM)uUpdateTime); + } +#endif + +#if (_WIN32_WINNT >= 0x0600) + int GetStep() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, PBM_GETSTEP, 0, 0L); + } + + COLORREF GetBkColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, PBM_GETBKCOLOR, 0, 0L); + } + + COLORREF GetBarColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, PBM_GETBARCOLOR, 0, 0L); + } + + int GetState() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, PBM_GETSTATE, 0, 0L); + } + + int SetState(int nState) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, PBM_SETSTATE, nState, 0L); + } +#endif // (_WIN32_WINNT >= 0x0600) + +// Operations + int StepIt() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)(short)LOWORD(::SendMessage(this->m_hWnd, PBM_STEPIT, 0, 0L)); + } +}; + +typedef CProgressBarCtrlT<ATL::CWindow> CProgressBarCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CHotKeyCtrl + +template <class TBase> +class CHotKeyCtrlT : public TBase +{ +public: +// Constructors 
+ CHotKeyCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CHotKeyCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return HOTKEY_CLASS; + } + + DWORD GetHotKey() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, HKM_GETHOTKEY, 0, 0L); + } + + void GetHotKey(WORD &wVirtualKeyCode, WORD &wModifiers) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dw = (DWORD)::SendMessage(this->m_hWnd, HKM_GETHOTKEY, 0, 0L); + wVirtualKeyCode = LOBYTE(LOWORD(dw)); + wModifiers = HIBYTE(LOWORD(dw)); + } + + void SetHotKey(WORD wVirtualKeyCode, WORD wModifiers) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, HKM_SETHOTKEY, MAKEWORD(wVirtualKeyCode, wModifiers), 0L); + } + + void SetRules(WORD wInvalidComb, WORD wModifiers) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, HKM_SETRULES, wInvalidComb, MAKELPARAM(wModifiers, 0)); + } +}; + +typedef CHotKeyCtrlT<ATL::CWindow> CHotKeyCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CAnimateCtrl + +template <class TBase> +class CAnimateCtrlT : public TBase +{ +public: +// Constructors + CAnimateCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CAnimateCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), 
hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return ANIMATE_CLASS; + } + +// Operations + BOOL Open(ATL::_U_STRINGorID FileName) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, ACM_OPEN, 0, (LPARAM)FileName.m_lpstr); + } + + BOOL Play(UINT nFrom, UINT nTo, UINT nRep) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, ACM_PLAY, nRep, MAKELPARAM(nFrom, nTo)); + } + + BOOL Stop() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, ACM_STOP, 0, 0L); + } + + BOOL Close() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, ACM_OPEN, 0, 0L); + } + + BOOL Seek(UINT nTo) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, ACM_PLAY, 0, MAKELPARAM(nTo, nTo)); + } + + // Vista only + BOOL IsPlaying() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, ACM_ISPLAYING, 0, 0L); + } +}; + +typedef CAnimateCtrlT<ATL::CWindow> CAnimateCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CRichEditCtrl + +#if !defined(_UNICODE) && (_RICHEDIT_VER >= 0x0500) + #undef MSFTEDIT_CLASS + #define MSFTEDIT_CLASS "RICHEDIT50W" +#endif + +template <class TBase> +class CRichEditCtrlT : public TBase +{ +public: +// Constructors + CRichEditCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CRichEditCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static 
LPCTSTR GetWndClassName() + { +#if (_RICHEDIT_VER >= 0x0500) + return MSFTEDIT_CLASS; +#else + return RICHEDIT_CLASS; +#endif + } + + static LPCTSTR GetLibraryName() + { +#if (_RICHEDIT_VER >= 0x0500) + return _T("MSFTEDIT.DLL"); +#else + return _T("RICHED20.DLL"); +#endif + } + + int GetLineCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_GETLINECOUNT, 0, 0L); + } + + BOOL GetModify() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_GETMODIFY, 0, 0L); + } + + void SetModify(BOOL bModified = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETMODIFY, bModified, 0L); + } + + void GetRect(LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_GETRECT, 0, (LPARAM)lpRect); + } + + DWORD GetOptions() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETOPTIONS, 0, 0L); + } + + DWORD SetOptions(WORD wOperation, DWORD dwOptions) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, EM_SETOPTIONS, wOperation, dwOptions); + } + + // NOTE: first word in lpszBuffer must contain the size of the buffer! 
+ int GetLine(int nIndex, LPTSTR lpszBuffer) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_GETLINE, nIndex, (LPARAM)lpszBuffer); + } + + int GetLine(int nIndex, LPTSTR lpszBuffer, int nMaxLength) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + *(LPWORD)lpszBuffer = (WORD)nMaxLength; + return (int)::SendMessage(this->m_hWnd, EM_GETLINE, nIndex, (LPARAM)lpszBuffer); + } + + BOOL CanUndo() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_CANUNDO, 0, 0L); + } + + BOOL CanPaste(UINT nFormat = 0) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_CANPASTE, nFormat, 0L); + } + + void GetSel(LONG& nStartChar, LONG& nEndChar) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + CHARRANGE cr = {}; + ::SendMessage(this->m_hWnd, EM_EXGETSEL, 0, (LPARAM)&cr); + nStartChar = cr.cpMin; + nEndChar = cr.cpMax; + } + + void GetSel(CHARRANGE &cr) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_EXGETSEL, 0, (LPARAM)&cr); + } + + int SetSel(LONG nStartChar, LONG nEndChar) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + CHARRANGE cr = { nStartChar, nEndChar }; + return (int)::SendMessage(this->m_hWnd, EM_EXSETSEL, 0, (LPARAM)&cr); + } + + int SetSel(CHARRANGE &cr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_EXSETSEL, 0, (LPARAM)&cr); + } + + int SetSelAll() + { + return SetSel(0, -1); + } + + int SetSelNone() + { + return SetSel(-1, 0); + } + + DWORD GetDefaultCharFormat(CHARFORMAT& cf) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETCHARFORMAT, 0, (LPARAM)&cf); + } + + DWORD GetSelectionCharFormat(CHARFORMAT& cf) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETCHARFORMAT, 1, (LPARAM)&cf); + } + + DWORD 
GetEventMask() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETEVENTMASK, 0, 0L); + } + + LONG GetLimitText() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (LONG)::SendMessage(this->m_hWnd, EM_GETLIMITTEXT, 0, 0L); + } + + DWORD GetParaFormat(PARAFORMAT& pf) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + pf.cbSize = sizeof(PARAFORMAT); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETPARAFORMAT, 0, (LPARAM)&pf); + } + + LONG GetSelText(LPTSTR lpstrBuff) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (LONG)::SendMessage(this->m_hWnd, EM_GETSELTEXT, 0, (LPARAM)lpstrBuff); + } + + BOOL GetSelTextBSTR(BSTR& bstrText) const + { + USES_CONVERSION; + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(bstrText == NULL); + + CHARRANGE cr = {}; + ::SendMessage(this->m_hWnd, EM_EXGETSEL, 0, (LPARAM)&cr); + + ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff; + LPTSTR lpstrText = buff.Allocate(cr.cpMax - cr.cpMin + 1); + if(lpstrText == NULL) + return FALSE; + if(::SendMessage(this->m_hWnd, EM_GETSELTEXT, 0, (LPARAM)lpstrText) == 0) + return FALSE; + + bstrText = ::SysAllocString(T2W(lpstrText)); + + return (bstrText != NULL) ? 
TRUE : FALSE; + } + +#ifdef __ATLSTR_H__ + LONG GetSelText(ATL::CString& strText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + + CHARRANGE cr = {}; + ::SendMessage(this->m_hWnd, EM_EXGETSEL, 0, (LPARAM)&cr); + + LONG lLen = 0; + LPTSTR lpstrText = strText.GetBufferSetLength(cr.cpMax - cr.cpMin); + if(lpstrText != NULL) + { + lLen = (LONG)::SendMessage(this->m_hWnd, EM_GETSELTEXT, 0, (LPARAM)lpstrText); + strText.ReleaseBuffer(); + } + + return lLen; + } +#endif // __ATLSTR_H__ + + WORD GetSelectionType() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (WORD)::SendMessage(this->m_hWnd, EM_SELECTIONTYPE, 0, 0L); + } + + COLORREF SetBackgroundColor(COLORREF cr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, EM_SETBKGNDCOLOR, 0, cr); + } + + COLORREF SetBackgroundColor() // sets to system background + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, EM_SETBKGNDCOLOR, 1, 0); + } + + BOOL SetCharFormat(CHARFORMAT& cf, WORD wFlags) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, (WPARAM)wFlags, (LPARAM)&cf); + } + + BOOL SetDefaultCharFormat(CHARFORMAT& cf) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, 0, (LPARAM)&cf); + } + + BOOL SetSelectionCharFormat(CHARFORMAT& cf) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, SCF_SELECTION, (LPARAM)&cf); + } + + BOOL SetWordCharFormat(CHARFORMAT& cf) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, SCF_SELECTION | SCF_WORD, (LPARAM)&cf); + } + + DWORD SetEventMask(DWORD dwEventMask) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, 
EM_SETEVENTMASK, 0, dwEventMask); + } + + BOOL SetParaFormat(PARAFORMAT& pf) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + pf.cbSize = sizeof(PARAFORMAT); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETPARAFORMAT, 0, (LPARAM)&pf); + } + + BOOL SetTargetDevice(HDC hDC, int cxLineWidth) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETTARGETDEVICE, (WPARAM)hDC, cxLineWidth); + } + + int GetTextLength() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, WM_GETTEXTLENGTH, 0, 0L); + } + + BOOL SetReadOnly(BOOL bReadOnly = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETREADONLY, bReadOnly, 0L); + } + + int GetFirstVisibleLine() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_GETFIRSTVISIBLELINE, 0, 0L); + } + + int GetTextRange(TEXTRANGE* pTextRange) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_GETTEXTRANGE, 0, (LPARAM)pTextRange); + } + + int GetTextRange(LONG nStartChar, LONG nEndChar, LPTSTR lpstrText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + TEXTRANGE tr = {}; + tr.chrg.cpMin = nStartChar; + tr.chrg.cpMax = nEndChar; + tr.lpstrText = lpstrText; + return (int)::SendMessage(this->m_hWnd, EM_GETTEXTRANGE, 0, (LPARAM)&tr); + } + + DWORD GetDefaultCharFormat(CHARFORMAT2& cf) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT2); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETCHARFORMAT, 0, (LPARAM)&cf); + } + + BOOL SetCharFormat(CHARFORMAT2& cf, WORD wFlags) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT2); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, (WPARAM)wFlags, (LPARAM)&cf); + } + + BOOL SetDefaultCharFormat(CHARFORMAT2& cf) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT2); + return (BOOL)::SendMessage(this->m_hWnd, 
EM_SETCHARFORMAT, 0, (LPARAM)&cf); + } + + DWORD GetSelectionCharFormat(CHARFORMAT2& cf) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT2); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETCHARFORMAT, 1, (LPARAM)&cf); + } + + BOOL SetSelectionCharFormat(CHARFORMAT2& cf) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT2); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, SCF_SELECTION, (LPARAM)&cf); + } + + BOOL SetWordCharFormat(CHARFORMAT2& cf) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + cf.cbSize = sizeof(CHARFORMAT2); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, SCF_SELECTION | SCF_WORD, (LPARAM)&cf); + } + + DWORD GetParaFormat(PARAFORMAT2& pf) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + pf.cbSize = sizeof(PARAFORMAT2); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETPARAFORMAT, 0, (LPARAM)&pf); + } + + BOOL SetParaFormat(PARAFORMAT2& pf) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + pf.cbSize = sizeof(PARAFORMAT2); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETPARAFORMAT, 0, (LPARAM)&pf); + } + + TEXTMODE GetTextMode() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (TEXTMODE)::SendMessage(this->m_hWnd, EM_GETTEXTMODE, 0, 0L); + } + + BOOL SetTextMode(TEXTMODE enumTextMode) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return !(BOOL)::SendMessage(this->m_hWnd, EM_SETTEXTMODE, enumTextMode, 0L); + } + + UNDONAMEID GetUndoName() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UNDONAMEID)::SendMessage(this->m_hWnd, EM_GETUNDONAME, 0, 0L); + } + + UNDONAMEID GetRedoName() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UNDONAMEID)::SendMessage(this->m_hWnd, EM_GETREDONAME, 0, 0L); + } + + BOOL CanRedo() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_CANREDO, 0, 0L); + } + + BOOL GetAutoURLDetect() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, 
EM_GETAUTOURLDETECT, 0, 0L); + } + + BOOL SetAutoURLDetect(BOOL bAutoDetect = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return !(BOOL)::SendMessage(this->m_hWnd, EM_AUTOURLDETECT, bAutoDetect, 0L); + } + + // this method is deprecated, please use SetAutoURLDetect + BOOL EnableAutoURLDetect(BOOL bEnable = TRUE) { return SetAutoURLDetect(bEnable); } + + UINT SetUndoLimit(UINT uUndoLimit) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, EM_SETUNDOLIMIT, uUndoLimit, 0L); + } + + void SetPalette(HPALETTE hPalette) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETPALETTE, (WPARAM)hPalette, 0L); + } + + int GetTextEx(GETTEXTEX* pGetTextEx, LPTSTR lpstrText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_GETTEXTEX, (WPARAM)pGetTextEx, (LPARAM)lpstrText); + } + + int GetTextEx(LPTSTR lpstrText, int nTextLen, DWORD dwFlags = GT_DEFAULT, UINT uCodePage = CP_ACP, LPCSTR lpDefaultChar = NULL, LPBOOL lpUsedDefChar = NULL) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + GETTEXTEX gte = {}; + gte.cb = nTextLen * sizeof(TCHAR); + gte.codepage = uCodePage; + gte.flags = dwFlags; + gte.lpDefaultChar = lpDefaultChar; + gte.lpUsedDefChar = lpUsedDefChar; + return (int)::SendMessage(this->m_hWnd, EM_GETTEXTEX, (WPARAM)>e, (LPARAM)lpstrText); + } + + int GetTextLengthEx(GETTEXTLENGTHEX* pGetTextLengthEx) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_GETTEXTLENGTHEX, (WPARAM)pGetTextLengthEx, 0L); + } + + int GetTextLengthEx(DWORD dwFlags = GTL_DEFAULT, UINT uCodePage = CP_ACP) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + GETTEXTLENGTHEX gtle = {}; + gtle.codepage = uCodePage; + gtle.flags = dwFlags; + return (int)::SendMessage(this->m_hWnd, EM_GETTEXTLENGTHEX, (WPARAM)>le, 0L); + } + + EDITWORDBREAKPROC GetWordBreakProc() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return 
(EDITWORDBREAKPROC)::SendMessage(this->m_hWnd, EM_GETWORDBREAKPROC, 0, 0L); + } + + void SetWordBreakProc(EDITWORDBREAKPROC ewbprc) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETWORDBREAKPROC, 0, (LPARAM)ewbprc); + } + + int SetTextEx(SETTEXTEX* pSetTextEx, LPCTSTR lpstrText) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_SETTEXTEX, (WPARAM)pSetTextEx, (LPARAM)lpstrText); + } + + int SetTextEx(LPCTSTR lpstrText, DWORD dwFlags = ST_DEFAULT, UINT uCodePage = CP_ACP) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + SETTEXTEX ste = {}; + ste.flags = dwFlags; + ste.codepage = uCodePage; + return (int)::SendMessage(this->m_hWnd, EM_SETTEXTEX, (WPARAM)&ste, (LPARAM)lpstrText); + } + + int GetEditStyle() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_GETEDITSTYLE, 0, 0L); + } + + int SetEditStyle(int nStyle, int nMask = -1) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + if(nMask == -1) + nMask = nStyle; // set everything specified + return (int)::SendMessage(this->m_hWnd, EM_SETEDITSTYLE, nStyle, nMask); + } + + BOOL SetFontSize(int nFontSizeDelta) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((nFontSizeDelta >= -1637) && (nFontSizeDelta <= 1638)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETFONTSIZE, nFontSizeDelta, 0L); + } + + void GetScrollPos(LPPOINT lpPoint) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(lpPoint != NULL); + ::SendMessage(this->m_hWnd, EM_GETSCROLLPOS, 0, (LPARAM)lpPoint); + } + + void SetScrollPos(LPPOINT lpPoint) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(lpPoint != NULL); + ::SendMessage(this->m_hWnd, EM_SETSCROLLPOS, 0, (LPARAM)lpPoint); + } + + BOOL GetZoom(int& nNum, int& nDen) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_GETZOOM, (WPARAM)&nNum, (LPARAM)&nDen); + } + + BOOL SetZoom(int nNum, int nDen) + { + ATLASSERT(::IsWindow(this->m_hWnd)); 
+ ATLASSERT((nNum >= 0) && (nNum <= 64)); + ATLASSERT((nDen >= 0) && (nDen <= 64)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETZOOM, nNum, nDen); + } + + BOOL SetZoomOff() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETZOOM, 0, 0L); + } + + void SetMargins(UINT nLeft, UINT nRight, WORD wFlags = EC_LEFTMARGIN | EC_RIGHTMARGIN) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETMARGINS, wFlags, MAKELONG(nLeft, nRight)); + } + + WORD GetTypographyOptions() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (WORD)::SendMessage(this->m_hWnd, EM_GETTYPOGRAPHYOPTIONS, 0, 0L); + } + + BOOL SetTypographyOptions(WORD wOptions, WORD wMask) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETTYPOGRAPHYOPTIONS, wOptions, wMask); + } + +// Operations + void LimitText(LONG nChars = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_EXLIMITTEXT, 0, nChars); + } + + int LineFromChar(LONG nIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_EXLINEFROMCHAR, 0, nIndex); + } + + POINT PosFromChar(LONG nChar) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + POINT point = {}; + ::SendMessage(this->m_hWnd, EM_POSFROMCHAR, (WPARAM)&point, nChar); + return point; + } + + int CharFromPos(POINT pt) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + POINTL ptl = { pt.x, pt.y }; + return (int)::SendMessage(this->m_hWnd, EM_CHARFROMPOS, 0, (LPARAM)&ptl); + } + + void EmptyUndoBuffer() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_EMPTYUNDOBUFFER, 0, 0L); + } + + int LineIndex(int nLine = -1) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_LINEINDEX, nLine, 0L); + } + + int LineLength(int nLine = -1) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, EM_LINELENGTH, nLine, 
0L); + } + + BOOL LineScroll(int nLines) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_LINESCROLL, 0, nLines); + } + + void ReplaceSel(LPCTSTR lpszNewText, BOOL bCanUndo = FALSE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_REPLACESEL, (WPARAM) bCanUndo, (LPARAM)lpszNewText); + } + + void SetRect(LPCRECT lpRect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETRECT, 0, (LPARAM)lpRect); + } + + BOOL DisplayBand(LPRECT pDisplayRect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_DISPLAYBAND, 0, (LPARAM)pDisplayRect); + } + + LONG FindText(DWORD dwFlags, FINDTEXT& ft) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); +#ifdef _UNICODE + return (LONG)::SendMessage(this->m_hWnd, EM_FINDTEXTW, dwFlags, (LPARAM)&ft); +#else + return (LONG)::SendMessage(this->m_hWnd, EM_FINDTEXT, dwFlags, (LPARAM)&ft); +#endif + } + + LONG FindText(DWORD dwFlags, FINDTEXTEX& ft) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); +#ifdef _UNICODE + return (LONG)::SendMessage(this->m_hWnd, EM_FINDTEXTEXW, dwFlags, (LPARAM)&ft); +#else + return (LONG)::SendMessage(this->m_hWnd, EM_FINDTEXTEX, dwFlags, (LPARAM)&ft); +#endif + } + + LONG FormatRange(FORMATRANGE& fr, BOOL bDisplay = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (LONG)::SendMessage(this->m_hWnd, EM_FORMATRANGE, bDisplay, (LPARAM)&fr); + } + + LONG FormatRange(FORMATRANGE* pFormatRange, BOOL bDisplay = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (LONG)::SendMessage(this->m_hWnd, EM_FORMATRANGE, bDisplay, (LPARAM)pFormatRange); + } + + void HideSelection(BOOL bHide = TRUE, BOOL bChangeStyle = FALSE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_HIDESELECTION, bHide, bChangeStyle); + } + + void PasteSpecial(UINT uClipFormat, DWORD dwAspect = 0, HMETAFILE hMF = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + REPASTESPECIAL reps = { 
dwAspect, (DWORD_PTR)hMF }; + ::SendMessage(this->m_hWnd, EM_PASTESPECIAL, uClipFormat, (LPARAM)&reps); + } + + void RequestResize() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_REQUESTRESIZE, 0, 0L); + } + + LONG StreamIn(UINT uFormat, EDITSTREAM& es) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (LONG)::SendMessage(this->m_hWnd, EM_STREAMIN, uFormat, (LPARAM)&es); + } + + LONG StreamOut(UINT uFormat, EDITSTREAM& es) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (LONG)::SendMessage(this->m_hWnd, EM_STREAMOUT, uFormat, (LPARAM)&es); + } + + DWORD FindWordBreak(int nCode, LONG nStartChar) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, EM_FINDWORDBREAK, nCode, nStartChar); + } + + // Additional operations + void ScrollCaret() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SCROLLCARET, 0, 0L); + } + + int InsertText(long nInsertAfterChar, LPCTSTR lpstrText, BOOL bCanUndo = FALSE) + { + int nRet = SetSel(nInsertAfterChar, nInsertAfterChar); + ReplaceSel(lpstrText, bCanUndo); + return nRet; + } + + int AppendText(LPCTSTR lpstrText, BOOL bCanUndo = FALSE) + { + return InsertText(this->GetWindowTextLength(), lpstrText, bCanUndo); + } + + // Clipboard operations + BOOL Undo() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_UNDO, 0, 0L); + } + + void Clear() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_CLEAR, 0, 0L); + } + + void Copy() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_COPY, 0, 0L); + } + + void Cut() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_CUT, 0, 0L); + } + + void Paste() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, WM_PASTE, 0, 0L); + } + + // OLE support + IRichEditOle* GetOleInterface() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + IRichEditOle *pRichEditOle = NULL; + 
::SendMessage(this->m_hWnd, EM_GETOLEINTERFACE, 0, (LPARAM)&pRichEditOle); + return pRichEditOle; + } + + BOOL SetOleCallback(IRichEditOleCallback* pCallback) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETOLECALLBACK, 0, (LPARAM)pCallback); + } + + BOOL Redo() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_REDO, 0, 0L); + } + + void StopGroupTyping() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_STOPGROUPTYPING, 0, 0L); + } + + void ShowScrollBar(int nBarType, BOOL bVisible = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SHOWSCROLLBAR, nBarType, bVisible); + } + + BOOL SetTabStops(int nTabStops, LPINT rgTabStops) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, nTabStops, (LPARAM)rgTabStops); + } + + BOOL SetTabStops() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, 0, 0L); + } + + BOOL SetTabStops(const int& cxEachStop) // takes an 'int' + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, 1, (LPARAM)(LPINT)&cxEachStop); + } + +#if (_RICHEDIT_VER >= 0x0800) + AutoCorrectProc GetAutoCorrectProc() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (AutoCorrectProc)::SendMessage(this->m_hWnd, EM_GETAUTOCORRECTPROC, 0, 0L); + } + + BOOL SetAutoCorrectProc(AutoCorrectProc pfn) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETAUTOCORRECTPROC, (WPARAM)pfn, 0L); + } + + BOOL CallAutoCorrectProc(WCHAR ch) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_CALLAUTOCORRECTPROC, (WPARAM)ch, 0L); + } + + DWORD GetEditStyleEx() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETEDITSTYLEEX, 0, 0L); + } + + DWORD SetEditStyleEx(DWORD 
dwStyleEx, DWORD dwMask) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, EM_SETEDITSTYLEEX, dwStyleEx, dwMask); + } + + DWORD GetStoryType(int nStoryIndex) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, EM_GETSTORYTYPE, nStoryIndex, 0L); + } + + DWORD SetStoryType(int nStoryIndex, DWORD dwStoryType) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, EM_SETSTORYTYPE, nStoryIndex, dwStoryType); + } + + DWORD GetEllipsisMode() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + + DWORD dwMode = 0; + BOOL bRet = (BOOL)::SendMessage(this->m_hWnd, EM_GETELLIPSISMODE, 0, (LPARAM)&dwMode); + (void)bRet; // avoid level 4 warning + ATLASSERT(bRet != FALSE); + + return dwMode; + } + + BOOL SetEllipsisMode(DWORD dwEllipsisMode) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETELLIPSISMODE, 0, dwEllipsisMode); + } + + BOOL GetEllipsisState() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_GETELLIPSISSTATE, 0, 0L); + } + + BOOL GetTouchOptions(int nTouchOptions) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_GETTOUCHOPTIONS, nTouchOptions, 0L); + } + + void SetTouchOptions(int nTouchOptions, BOOL bEnable) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, EM_SETTOUCHOPTIONS, nTouchOptions, bEnable); + } + + HRESULT InsertTable(TABLEROWPARMS* pRowParams, TABLECELLPARMS* pCellParams) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HRESULT)::SendMessage(this->m_hWnd, EM_INSERTTABLE, (WPARAM)pRowParams, (LPARAM)pCellParams); + } + + HRESULT GetTableParams(TABLEROWPARMS* pRowParams, TABLECELLPARMS* pCellParams) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HRESULT)::SendMessage(this->m_hWnd, EM_GETTABLEPARMS, (WPARAM)pRowParams, (LPARAM)pCellParams); + } + + HRESULT 
SetTableParams(TABLEROWPARMS* pRowParams, TABLECELLPARMS* pCellParams) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HRESULT)::SendMessage(this->m_hWnd, EM_SETTABLEPARMS, (WPARAM)pRowParams, (LPARAM)pCellParams); + } + + HRESULT InsertImage(RICHEDIT_IMAGE_PARAMETERS* pParams) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HRESULT)::SendMessage(this->m_hWnd, EM_INSERTIMAGE, 0, (LPARAM)pParams); + } + + BOOL SetUiaName(LPCTSTR lpstrName) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, EM_SETUIANAME, 0, (LPARAM)lpstrName); + } +#endif // (_RICHEDIT_VER >= 0x0800) +}; + +typedef CRichEditCtrlT<ATL::CWindow> CRichEditCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CRichEditCommands - message handlers for standard EDIT commands + +// Chain to CRichEditCommands message map. Your class must also derive from CRichEditCtrl. +// Example: +// class CMyRichEdit : public CWindowImpl<CMyRichEdit, CRichEditCtrl>, +// public CRichEditCommands<CMyRichEdit> +// { +// public: +// BEGIN_MSG_MAP(CMyRichEdit) +// // your handlers... +// CHAIN_MSG_MAP_ALT(CRichEditCommands<CMyRichEdit>, 1) +// END_MSG_MAP() +// // other stuff... 
+// }; + +template <class T> +class CRichEditCommands : public CEditCommands< T > +{ +public: + BEGIN_MSG_MAP(CRichEditCommands< T >) + ALT_MSG_MAP(1) + COMMAND_ID_HANDLER(ID_EDIT_CLEAR, CEditCommands< T >::OnEditClear) + COMMAND_ID_HANDLER(ID_EDIT_CLEAR_ALL, CEditCommands< T >::OnEditClearAll) + COMMAND_ID_HANDLER(ID_EDIT_COPY, CEditCommands< T >::OnEditCopy) + COMMAND_ID_HANDLER(ID_EDIT_CUT, CEditCommands< T >::OnEditCut) + COMMAND_ID_HANDLER(ID_EDIT_PASTE, CEditCommands< T >::OnEditPaste) + COMMAND_ID_HANDLER(ID_EDIT_SELECT_ALL, CEditCommands< T >::OnEditSelectAll) + COMMAND_ID_HANDLER(ID_EDIT_UNDO, CEditCommands< T >::OnEditUndo) + COMMAND_ID_HANDLER(ID_EDIT_REDO, OnEditRedo) + END_MSG_MAP() + + LRESULT OnEditRedo(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/) + { + T* pT = static_cast<T*>(this); + pT->Redo(); + return 0; + } + +// State (update UI) helpers + BOOL CanCut() const + { return HasSelection(); } + + BOOL CanCopy() const + { return HasSelection(); } + + BOOL CanClear() const + { return HasSelection(); } + +// Implementation + BOOL HasSelection() const + { + const T* pT = static_cast<const T*>(this); + return (pT->GetSelectionType() != SEL_EMPTY); + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CDragListBox + +template <class TBase> +class CDragListBoxT : public CListBoxT< TBase > +{ +public: +// Constructors + CDragListBoxT(HWND hWnd = NULL) : CListBoxT< TBase >(hWnd) + { } + + CDragListBoxT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + HWND hWnd = TBase::Create(TBase::GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + if(hWnd != NULL) + MakeDragList(); + return hWnd; + } + +// 
Operations + BOOL MakeDragList() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) == 0); + return ::MakeDragList(this->m_hWnd); + } + + int LBItemFromPt(POINT pt, BOOL bAutoScroll = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ::LBItemFromPt(this->m_hWnd, pt, bAutoScroll); + } + + void DrawInsert(int nItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::DrawInsert(this->GetParent(), this->m_hWnd, nItem); + } + + static UINT GetDragListMessage() + { + static UINT uDragListMessage = 0; + if(uDragListMessage == 0) + { + CStaticDataInitCriticalSectionLock lock; + if(FAILED(lock.Lock())) + { + ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CDragListBox::GetDragListMessage.\n")); + ATLASSERT(FALSE); + return 0; + } + + if(uDragListMessage == 0) + uDragListMessage = ::RegisterWindowMessage(DRAGLISTMSGSTRING); + + lock.Unlock(); + } + ATLASSERT(uDragListMessage != 0); + return uDragListMessage; + } +}; + +typedef CDragListBoxT<ATL::CWindow> CDragListBox; + +template <class T> +class CDragListNotifyImpl +{ +public: + BEGIN_MSG_MAP(CDragListNotifyImpl< T >) + MESSAGE_HANDLER(CDragListBox::GetDragListMessage(), OnDragListNotify) + END_MSG_MAP() + + LRESULT OnDragListNotify(UINT uMsg, WPARAM wParam, LPARAM lParam, BOOL& bHandled) + { + (void)uMsg; // avoid level 4 warning + ATLASSERT(uMsg == CDragListBox::GetDragListMessage()); + T* pT = static_cast<T*>(this); + LPDRAGLISTINFO lpDragListInfo = (LPDRAGLISTINFO)lParam; + LRESULT lRet = 0; + switch(lpDragListInfo->uNotification) + { + case DL_BEGINDRAG: + lRet = (LPARAM)pT->OnBeginDrag((int)wParam, lpDragListInfo->hWnd, lpDragListInfo->ptCursor); + break; + case DL_CANCELDRAG: + pT->OnCancelDrag((int)wParam, lpDragListInfo->hWnd, lpDragListInfo->ptCursor); + break; + case DL_DRAGGING: + lRet = (LPARAM)pT->OnDragging((int)wParam, lpDragListInfo->hWnd, lpDragListInfo->ptCursor); + break; + case DL_DROPPED: + 
pT->OnDropped((int)wParam, lpDragListInfo->hWnd, lpDragListInfo->ptCursor); + break; + default: + ATLTRACE2(atlTraceUI, 0, _T("Unknown DragListBox notification\n")); + bHandled = FALSE; // don't handle it + break; + } + return lRet; + } + +// Overrideables + BOOL OnBeginDrag(int /*nCtlID*/, HWND /*hWndDragList*/, POINT /*ptCursor*/) + { + return TRUE; // allow dragging + } + + void OnCancelDrag(int /*nCtlID*/, HWND /*hWndDragList*/, POINT /*ptCursor*/) + { + // nothing to do + } + + int OnDragging(int /*nCtlID*/, HWND /*hWndDragList*/, POINT /*ptCursor*/) + { + return 0; // don't change cursor + } + + void OnDropped(int /*nCtlID*/, HWND /*hWndDragList*/, POINT /*ptCursor*/) + { + // nothing to do + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CReBarCtrl + +template <class TBase> +class CReBarCtrlT : public TBase +{ +public: +// Constructors + CReBarCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CReBarCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return REBARCLASSNAME; + } + + UINT GetBandCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, RB_GETBANDCOUNT, 0, 0L); + } + + BOOL GetBandInfo(int nBand, LPREBARBANDINFO lprbbi) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_GETBANDINFO, nBand, (LPARAM)lprbbi); + } + + BOOL SetBandInfo(int nBand, LPREBARBANDINFO lprbbi) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_SETBANDINFO, nBand, (LPARAM)lprbbi); + } + + 
BOOL GetBarInfo(LPREBARINFO lprbi) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_GETBARINFO, 0, (LPARAM)lprbi); + } + + BOOL SetBarInfo(LPREBARINFO lprbi) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_SETBARINFO, 0, (LPARAM)lprbi); + } + + CImageList GetImageList() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + REBARINFO rbi = {}; + rbi.cbSize = sizeof(REBARINFO); + rbi.fMask = RBIM_IMAGELIST; + BOOL bRet = (BOOL)::SendMessage(this->m_hWnd, RB_GETBARINFO, 0, (LPARAM)&rbi); + return CImageList((bRet != FALSE) ? rbi.himl : NULL); + } + + BOOL SetImageList(HIMAGELIST hImageList) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + REBARINFO rbi = {}; + rbi.cbSize = sizeof(REBARINFO); + rbi.fMask = RBIM_IMAGELIST; + rbi.himl = hImageList; + return (BOOL)::SendMessage(this->m_hWnd, RB_SETBARINFO, 0, (LPARAM)&rbi); + } + + UINT GetRowCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, RB_GETROWCOUNT, 0, 0L); + } + + UINT GetRowHeight(int nBand) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, RB_GETROWHEIGHT, nBand, 0L); + } + + COLORREF GetTextColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, RB_GETTEXTCOLOR, 0, 0L); + } + + COLORREF SetTextColor(COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, RB_SETTEXTCOLOR, 0, (LPARAM)clr); + } + + COLORREF GetBkColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, RB_GETBKCOLOR, 0, 0L); + } + + COLORREF SetBkColor(COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, RB_SETBKCOLOR, 0, (LPARAM)clr); + } + + UINT GetBarHeight() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (UINT)::SendMessage(this->m_hWnd, RB_GETBARHEIGHT, 0, 0L); + } + + 
BOOL GetRect(int nBand, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_GETRECT, nBand, (LPARAM)lpRect); + } + + CToolTipCtrl GetToolTips() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, RB_GETTOOLTIPS, 0, 0L)); + } + + void SetToolTips(HWND hwndToolTip) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_SETTOOLTIPS, (WPARAM)hwndToolTip, 0L); + } + + void GetBandBorders(int nBand, LPRECT lpRect) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(lpRect != NULL); + ::SendMessage(this->m_hWnd, RB_GETBANDBORDERS, nBand, (LPARAM)lpRect); + } + + BOOL GetColorScheme(LPCOLORSCHEME lpColorScheme) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(lpColorScheme != NULL); + return (BOOL)::SendMessage(this->m_hWnd, RB_GETCOLORSCHEME, 0, (LPARAM)lpColorScheme); + } + + void SetColorScheme(LPCOLORSCHEME lpColorScheme) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(lpColorScheme != NULL); + ::SendMessage(this->m_hWnd, RB_SETCOLORSCHEME, 0, (LPARAM)lpColorScheme); + } + + HPALETTE GetPalette() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HPALETTE)::SendMessage(this->m_hWnd, RB_GETPALETTE, 0, 0L); + } + + HPALETTE SetPalette(HPALETTE hPalette) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (HPALETTE)::SendMessage(this->m_hWnd, RB_SETPALETTE, 0, (LPARAM)hPalette); + } + + BOOL GetUnicodeFormat() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_GETUNICODEFORMAT, 0, 0L); + } + + BOOL SetUnicodeFormat(BOOL bUnicode = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_SETUNICODEFORMAT, bUnicode, 0L); + } + + // requires uxtheme.h to be included to use MARGINS struct +#ifndef _UXTHEME_H_ + typedef struct _MARGINS* PMARGINS; +#endif // !_UXTHEME_H_ + void GetBandMargins(PMARGINS pMargins) const + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_GETBANDMARGINS, 0, (LPARAM)pMargins); + } + + void SetWindowTheme(LPCWSTR lpstrTheme) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_SETWINDOWTHEME, 0, (LPARAM)lpstrTheme); + } + + DWORD GetExtendedStyle() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, RB_GETEXTENDEDSTYLE, 0, 0L); + } + + DWORD SetExtendedStyle(DWORD dwStyle, DWORD dwMask) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, RB_SETEXTENDEDSTYLE, dwMask, dwStyle); + } + +// Operations + BOOL InsertBand(int nBand, LPREBARBANDINFO lprbbi) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_INSERTBAND, nBand, (LPARAM)lprbbi); + } + + BOOL AddBand(LPREBARBANDINFO lprbbi) + { + return InsertBand(-1, lprbbi); + } + + BOOL DeleteBand(int nBand) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_DELETEBAND, nBand, 0L); + } + + ATL::CWindow SetNotifyWnd(HWND hWnd) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return ATL::CWindow((HWND)::SendMessage(this->m_hWnd, RB_SETPARENT, (WPARAM)hWnd, 0L)); + } + + void BeginDrag(int nBand, DWORD dwPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_BEGINDRAG, nBand, dwPos); + } + + void BeginDrag(int nBand, int xPos, int yPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_BEGINDRAG, nBand, MAKELPARAM(xPos, yPos)); + } + + void EndDrag() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_ENDDRAG, 0, 0L); + } + + void DragMove(DWORD dwPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_DRAGMOVE, 0, dwPos); + } + + void DragMove(int xPos, int yPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_DRAGMOVE, 0, MAKELPARAM(xPos, yPos)); + } + + void GetDropTarget(IDropTarget** 
ppDropTarget) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_GETDROPTARGET, 0, (LPARAM)ppDropTarget); + } + + void MaximizeBand(int nBand, BOOL bIdeal = FALSE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_MAXIMIZEBAND, nBand, bIdeal); + } + + void MinimizeBand(int nBand) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_MINIMIZEBAND, nBand, 0L); + } + + BOOL SizeToRect(LPRECT lpRect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_SIZETORECT, 0, (LPARAM)lpRect); + } + + int IdToIndex(UINT uBandID) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, RB_IDTOINDEX, uBandID, 0L); + } + + int HitTest(LPRBHITTESTINFO lprbht) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, RB_HITTEST, 0, (LPARAM)lprbht); + } + + BOOL ShowBand(int nBand, BOOL bShow) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_SHOWBAND, nBand, bShow); + } + + BOOL MoveBand(int nBand, int nNewPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((nNewPos >= 0) && (nNewPos <= ((int)GetBandCount() - 1))); + return (BOOL)::SendMessage(this->m_hWnd, RB_MOVEBAND, nBand, nNewPos); + } + + void PushChevron(int nBand, LPARAM lAppValue) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, RB_PUSHCHEVRON, nBand, lAppValue); + } + +// Extra operations + void LockBands(bool bLock) + { + int nBandCount = GetBandCount(); + for(int i =0; i < nBandCount; i++) + { + REBARBANDINFO rbbi = { RunTimeHelper::SizeOf_REBARBANDINFO() }; + rbbi.fMask = RBBIM_STYLE; + BOOL bRet = GetBandInfo(i, &rbbi); + ATLASSERT(bRet); + + if((rbbi.fStyle & RBBS_GRIPPERALWAYS) == 0) + { + rbbi.fStyle |= RBBS_GRIPPERALWAYS; + bRet = SetBandInfo(i, &rbbi); + ATLASSERT(bRet); + rbbi.fStyle &= ~RBBS_GRIPPERALWAYS; + } + + if(bLock) + rbbi.fStyle |= RBBS_NOGRIPPER; + else + 
rbbi.fStyle &= ~RBBS_NOGRIPPER; + + bRet = SetBandInfo(i, &rbbi); + ATLASSERT(bRet); + } + } + +#if (_WIN32_WINNT >= 0x0600) + BOOL SetBandWidth(int nBand, int cxWidth) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, RB_SETBANDWIDTH, nBand, cxWidth); + } +#endif // (_WIN32_WINNT >= 0x0600) +}; + +typedef CReBarCtrlT<ATL::CWindow> CReBarCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CComboBoxEx + +template <class TBase> +class CComboBoxExT : public CComboBoxT< TBase > +{ +public: +// Constructors + CComboBoxExT(HWND hWnd = NULL) : CComboBoxT< TBase >(hWnd) + { } + + CComboBoxExT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return WC_COMBOBOXEX; + } + + CImageList GetImageList() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, CBEM_GETIMAGELIST, 0, 0L)); + } + + CImageList SetImageList(HIMAGELIST hImageList) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, CBEM_SETIMAGELIST, 0, (LPARAM)hImageList)); + } + + DWORD GetExtendedStyle() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, CBEM_GETEXTENDEDSTYLE, 0, 0L); + } + + DWORD SetExtendedStyle(DWORD dwExMask, DWORD dwExStyle) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, CBEM_SETEXTENDEDSTYLE, dwExMask, dwExStyle); + } + + BOOL GetUnicodeFormat() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return 
(BOOL)::SendMessage(this->m_hWnd, CBEM_GETUNICODEFORMAT, 0, 0L); + } + + BOOL SetUnicodeFormat(BOOL bUnicode = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CBEM_SETUNICODEFORMAT, bUnicode, 0L); + } + + void SetWindowTheme(LPCWSTR lpstrTheme) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, CBEM_SETWINDOWTHEME, 0, (LPARAM)lpstrTheme); + } + +// Operations + int InsertItem(const COMBOBOXEXITEM* lpcCBItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CBEM_INSERTITEM, 0, (LPARAM)lpcCBItem); + } + + int InsertItem(UINT nMask, int nIndex, LPCTSTR lpszItem, int nImage, int nSelImage, + int iIndent, int iOverlay, LPARAM lParam) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + COMBOBOXEXITEM cbex = {}; + cbex.mask = nMask; + cbex.iItem = nIndex; + cbex.pszText = (LPTSTR) lpszItem; + cbex.iImage = nImage; + cbex.iSelectedImage = nSelImage; + cbex.iIndent = iIndent; + cbex.iOverlay = iOverlay; + cbex.lParam = lParam; + return (int)::SendMessage(this->m_hWnd, CBEM_INSERTITEM, 0, (LPARAM)&cbex); + } + + int InsertItem(int nIndex, LPCTSTR lpszItem, int nImage, int nSelImage, int iIndent, LPARAM lParam = 0) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + COMBOBOXEXITEM cbex = {}; + cbex.mask = CBEIF_TEXT | CBEIF_IMAGE | CBEIF_SELECTEDIMAGE | CBEIF_INDENT | CBEIF_LPARAM; + cbex.iItem = nIndex; + cbex.pszText = (LPTSTR) lpszItem; + cbex.iImage = nImage; + cbex.iSelectedImage = nSelImage; + cbex.iIndent = iIndent; + cbex.lParam = lParam; + return (int)::SendMessage(this->m_hWnd, CBEM_INSERTITEM, 0, (LPARAM)&cbex); + } + + int AddItem(UINT nMask, LPCTSTR lpszItem, int nImage, int nSelImage, int iIndent, int iOverlay, LPARAM lParam) + { + return InsertItem(nMask, -1, lpszItem, nImage, nSelImage, iIndent, iOverlay, lParam); + } + + int AddItem(LPCTSTR lpszItem, int nImage, int nSelImage, int iIndent, LPARAM lParam = 0) + { + return InsertItem(-1, lpszItem, nImage, nSelImage, iIndent, 
lParam); + } + + int DeleteItem(int nIndex) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, CBEM_DELETEITEM, nIndex, 0L); + } + + BOOL GetItem(PCOMBOBOXEXITEM pCBItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CBEM_GETITEM, 0, (LPARAM)pCBItem); + } + + BOOL SetItem(const COMBOBOXEXITEM* lpcCBItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CBEM_SETITEM, 0, (LPARAM)lpcCBItem); + } + + int SetItem(int nIndex, UINT nMask, LPCTSTR lpszItem, int nImage, int nSelImage, + int iIndent, int iOverlay, LPARAM lParam) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + COMBOBOXEXITEM cbex = {}; + cbex.mask = nMask; + cbex.iItem = nIndex; + cbex.pszText = (LPTSTR) lpszItem; + cbex.iImage = nImage; + cbex.iSelectedImage = nSelImage; + cbex.iIndent = iIndent; + cbex.iOverlay = iOverlay; + cbex.lParam = lParam; + return (int)::SendMessage(this->m_hWnd, CBEM_SETITEM, 0, (LPARAM)&cbex); + } + + BOOL GetItemText(int nIndex, LPTSTR lpszItem, int nLen) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(lpszItem != NULL); + + COMBOBOXEXITEM cbex = {}; + cbex.mask = CBEIF_TEXT; + cbex.iItem = nIndex; + cbex.pszText = lpszItem; + cbex.cchTextMax = nLen; + + return (BOOL)::SendMessage(this->m_hWnd, CBEM_GETITEM, 0, (LPARAM)&cbex); + } + + BOOL GetItemText(int nIndex, BSTR& bstrText) const + { + USES_CONVERSION; + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(bstrText == NULL); + + COMBOBOXEXITEM cbex = {}; + cbex.mask = CBEIF_TEXT; + cbex.iItem = nIndex; + + LPTSTR lpstrText = NULL; + BOOL bRet = FALSE; + for(int nLen = 256; ; nLen *= 2) + { + ATLTRY(lpstrText = new TCHAR[nLen]); + if(lpstrText == NULL) + break; + lpstrText[0] = NULL; + cbex.pszText = lpstrText; + cbex.cchTextMax = nLen; + bRet = (BOOL)::SendMessage(this->m_hWnd, CBEM_GETITEM, 0, (LPARAM)&cbex); + if(!bRet || (lstrlen(cbex.pszText) < (nLen - 1))) + break; + delete [] lpstrText; + lpstrText 
= NULL; + } + + if(lpstrText != NULL) + { + if(bRet) + bstrText = ::SysAllocString(T2OLE(lpstrText)); + delete [] lpstrText; + } + + return (bstrText != NULL) ? TRUE : FALSE; + } + +#ifdef __ATLSTR_H__ + BOOL GetItemText(int nIndex, ATL::CString& strText) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + + COMBOBOXEXITEM cbex = {}; + cbex.mask = CBEIF_TEXT; + cbex.iItem = nIndex; + + strText.Empty(); + BOOL bRet = FALSE; + for(int nLen = 256; ; nLen *= 2) + { + cbex.pszText = strText.GetBufferSetLength(nLen); + if(cbex.pszText == NULL) + { + bRet = FALSE; + break; + } + cbex.cchTextMax = nLen; + bRet = (BOOL)::SendMessage(this->m_hWnd, CBEM_GETITEM, 0, (LPARAM)&cbex); + if(!bRet || (lstrlen(cbex.pszText) < (nLen - 1))) + break; + } + strText.ReleaseBuffer(); + return bRet; + } +#endif // __ATLSTR_H__ + + BOOL SetItemText(int nIndex, LPCTSTR lpszItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return SetItem(nIndex, CBEIF_TEXT, lpszItem, 0, 0, 0, 0, 0); + } + + CComboBox GetComboCtrl() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CComboBox((HWND)::SendMessage(this->m_hWnd, CBEM_GETCOMBOCONTROL, 0, 0L)); + } + + CEdit GetEditCtrl() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CEdit((HWND)::SendMessage(this->m_hWnd, CBEM_GETEDITCONTROL, 0, 0L)); + } + + BOOL HasEditChanged() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, CBEM_HASEDITCHANGED, 0, 0L); + } + +// Non-functional + int AddString(LPCTSTR /*lpszItem*/) + { + ATLASSERT(FALSE); // Not available in CComboBoxEx; use InsertItem + return 0; + } + + int InsertString(int /*nIndex*/, LPCTSTR /*lpszString*/) + { + ATLASSERT(FALSE); // Not available in CComboBoxEx; use InsertItem + return 0; + } + + int Dir(UINT /*attr*/, LPCTSTR /*lpszWildCard*/) + { + ATLASSERT(FALSE); // Not available in CComboBoxEx + return 0; + } + + int FindString(int /*nStartAfter*/, LPCTSTR /*lpszString*/) const + { + ATLASSERT(FALSE); // Not available in CComboBoxEx; try 
FindStringExact + return 0; + } +}; + +typedef CComboBoxExT<ATL::CWindow> CComboBoxEx; + + +/////////////////////////////////////////////////////////////////////////////// +// CMonthCalendarCtrl + +template <class TBase> +class CMonthCalendarCtrlT : public TBase +{ +public: +// Constructors + CMonthCalendarCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CMonthCalendarCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return MONTHCAL_CLASS; + } + + COLORREF GetColor(int nColorType) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, MCM_GETCOLOR, nColorType, 0L); + } + + COLORREF SetColor(int nColorType, COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, MCM_SETCOLOR, nColorType, clr); + } + + BOOL GetCurSel(LPSYSTEMTIME lpSysTime) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_GETCURSEL, 0, (LPARAM)lpSysTime); + } + + BOOL SetCurSel(LPSYSTEMTIME lpSysTime) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_SETCURSEL, 0, (LPARAM)lpSysTime); + } + + int GetFirstDayOfWeek(BOOL* pbLocaleVal = NULL) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, MCM_GETFIRSTDAYOFWEEK, 0, 0L); + if(pbLocaleVal != NULL) + *pbLocaleVal = (BOOL)HIWORD(dwRet); + return (int)(short)LOWORD(dwRet); + } + + int SetFirstDayOfWeek(int nDay, BOOL* pbLocaleVal = NULL) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + DWORD dwRet = 
(DWORD)::SendMessage(this->m_hWnd, MCM_SETFIRSTDAYOFWEEK, 0, nDay); + if(pbLocaleVal != NULL) + *pbLocaleVal = (BOOL)HIWORD(dwRet); + return (int)(short)LOWORD(dwRet); + } + + int GetMaxSelCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, MCM_GETMAXSELCOUNT, 0, 0L); + } + + BOOL SetMaxSelCount(int nMax) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_SETMAXSELCOUNT, nMax, 0L); + } + + int GetMonthDelta() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, MCM_GETMONTHDELTA, 0, 0L); + } + + int SetMonthDelta(int nDelta) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, MCM_SETMONTHDELTA, nDelta, 0L); + } + + DWORD GetRange(LPSYSTEMTIME lprgSysTimeArray) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, MCM_GETRANGE, 0, (LPARAM)lprgSysTimeArray); + } + + BOOL SetRange(DWORD dwFlags, LPSYSTEMTIME lprgSysTimeArray) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_SETRANGE, dwFlags, (LPARAM)lprgSysTimeArray); + } + + BOOL GetSelRange(LPSYSTEMTIME lprgSysTimeArray) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_GETSELRANGE, 0, (LPARAM)lprgSysTimeArray); + } + + BOOL SetSelRange(LPSYSTEMTIME lprgSysTimeArray) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_SETSELRANGE, 0, (LPARAM)lprgSysTimeArray); + } + + BOOL GetToday(LPSYSTEMTIME lpSysTime) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_GETTODAY, 0, (LPARAM)lpSysTime); + } + + void SetToday(LPSYSTEMTIME lpSysTime) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, MCM_SETTODAY, 0, (LPARAM)lpSysTime); + } + + BOOL GetMinReqRect(LPRECT lpRectInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return 
(BOOL)::SendMessage(this->m_hWnd, MCM_GETMINREQRECT, 0, (LPARAM)lpRectInfo); + } + + int GetMaxTodayWidth() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, MCM_GETMAXTODAYWIDTH, 0, 0L); + } + + BOOL GetUnicodeFormat() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_GETUNICODEFORMAT, 0, 0L); + } + + BOOL SetUnicodeFormat(BOOL bUnicode = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_SETUNICODEFORMAT, bUnicode, 0L); + } + +#if defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) + DWORD GetCurrentView() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, MCM_GETCURRENTVIEW, 0, 0L); + } + + BOOL SetCurrentView(DWORD dwView) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_SETCURRENTVIEW, 0, dwView); + } + + DWORD GetCalendarCount() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, MCM_GETCALENDARCOUNT, 0, 0L); + } + + BOOL GetCalendarGridInfo(PMCGRIDINFO pGridInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_GETCALENDARGRIDINFO, 0, (LPARAM)pGridInfo); + } + + CALID GetCALID() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (CALID)::SendMessage(this->m_hWnd, MCM_GETCALID, 0, 0L); + } + + void SetCALID(CALID calid) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, MCM_SETCALID, (LPARAM)calid, 0L); + } + + int GetCalendarBorder() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, MCM_GETCALENDARBORDER, 0, 0L); + } + + void SetCalendarBorder(int cxyBorder, BOOL bSet = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, MCM_SETCALENDARBORDER, (WPARAM)bSet, (LPARAM)cxyBorder); + } +#endif // defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) + +// 
Operations + int GetMonthRange(DWORD dwFlags, LPSYSTEMTIME lprgSysTimeArray) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, MCM_GETMONTHRANGE, dwFlags, (LPARAM)lprgSysTimeArray); + } + + BOOL SetDayState(int nMonths, LPMONTHDAYSTATE lpDayStateArray) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, MCM_SETDAYSTATE, nMonths, (LPARAM)lpDayStateArray); + } + + DWORD HitTest(PMCHITTESTINFO pMCHitTest) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, MCM_HITTEST, 0, (LPARAM)pMCHitTest); + } + +#if defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) + void SizeRectToMin(LPRECT lpRect) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, MCM_SIZERECTTOMIN, 0, (LPARAM)lpRect); + } +#endif // defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) +}; + +typedef CMonthCalendarCtrlT<ATL::CWindow> CMonthCalendarCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CDateTimePickerCtrl + +template <class TBase> +class CDateTimePickerCtrlT : public TBase +{ +public: +// Constructors + CDateTimePickerCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CDateTimePickerCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Operations + static LPCTSTR GetWndClassName() + { + return DATETIMEPICK_CLASS; + } + + BOOL SetFormat(LPCTSTR lpszFormat) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, DTM_SETFORMAT, 0, (LPARAM)lpszFormat); + } + + COLORREF GetMonthCalColor(int nColorType) const + { + 
ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, DTM_GETMCCOLOR, nColorType, 0L); + } + + COLORREF SetMonthCalColor(int nColorType, COLORREF clr) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, DTM_SETMCCOLOR, nColorType, clr); + } + + DWORD GetRange(LPSYSTEMTIME lpSysTimeArray) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, DTM_GETRANGE, 0, (LPARAM)lpSysTimeArray); + } + + BOOL SetRange(DWORD dwFlags, LPSYSTEMTIME lpSysTimeArray) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, DTM_SETRANGE, dwFlags, (LPARAM)lpSysTimeArray); + } + + DWORD GetSystemTime(LPSYSTEMTIME lpSysTime) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, DTM_GETSYSTEMTIME, 0, (LPARAM)lpSysTime); + } + + BOOL SetSystemTime(DWORD dwFlags, LPSYSTEMTIME lpSysTime) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, DTM_SETSYSTEMTIME, dwFlags, (LPARAM)lpSysTime); + } + + CMonthCalendarCtrl GetMonthCal() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CMonthCalendarCtrl((HWND)::SendMessage(this->m_hWnd, DTM_GETMONTHCAL, 0, 0L)); + } + + CFontHandle GetMonthCalFont() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return CFontHandle((HFONT)::SendMessage(this->m_hWnd, DTM_GETMCFONT, 0, 0L)); + } + + void SetMonthCalFont(HFONT hFont, BOOL bRedraw = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, DTM_SETMCFONT, (WPARAM)hFont, MAKELPARAM(bRedraw, 0)); + } + +#if defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) + DWORD GetMonthCalStyle() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, DTM_GETMCSTYLE, 0, 0L); + } + + DWORD SetMonthCalStyle(DWORD dwStyle) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (DWORD)::SendMessage(this->m_hWnd, DTM_SETMCSTYLE, 0, (LPARAM)dwStyle); 
+ } + + void GetDateTimePickerInfo(LPDATETIMEPICKERINFO lpPickerInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, DTM_GETDATETIMEPICKERINFO, 0, (LPARAM)lpPickerInfo); + } + + BOOL GetIdealSize(LPSIZE lpSize) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, DTM_GETIDEALSIZE, 0, (LPARAM)lpSize); + } + + void CloseMonthCal() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, DTM_CLOSEMONTHCAL, 0, 0L); + } +#endif // defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) +}; + +typedef CDateTimePickerCtrlT<ATL::CWindow> CDateTimePickerCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CFlatScrollBarImpl - support for flat scroll bars + +template <class T> +class CFlatScrollBarImpl +{ +public: +// Initialization + BOOL FlatSB_Initialize() + { + T* pT = static_cast<T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::InitializeFlatSB(pT->m_hWnd); + } + + HRESULT FlatSB_Uninitialize() + { + T* pT = static_cast<T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::UninitializeFlatSB(pT->m_hWnd); + } + +// Flat scroll bar properties + BOOL FlatSB_GetScrollProp(UINT uIndex, LPINT lpnValue) const + { + const T* pT = static_cast<const T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::FlatSB_GetScrollProp(pT->m_hWnd, uIndex, lpnValue); + } + + BOOL FlatSB_SetScrollProp(UINT uIndex, int nValue, BOOL bRedraw = TRUE) + { + T* pT = static_cast<T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::FlatSB_SetScrollProp(pT->m_hWnd, uIndex, nValue, bRedraw); + } + +// Attributes + int FlatSB_GetScrollPos(int nBar) const + { + const T* pT = static_cast<const T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::FlatSB_GetScrollPos(pT->m_hWnd, nBar); + } + + int FlatSB_SetScrollPos(int nBar, int nPos, BOOL bRedraw = TRUE) + { + T* pT = static_cast<T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return 
::FlatSB_SetScrollPos(pT->m_hWnd, nBar, nPos, bRedraw); + } + + BOOL FlatSB_GetScrollRange(int nBar, LPINT lpMinPos, LPINT lpMaxPos) const + { + const T* pT = static_cast<const T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::FlatSB_GetScrollRange(pT->m_hWnd, nBar, lpMinPos, lpMaxPos); + } + + BOOL FlatSB_SetScrollRange(int nBar, int nMinPos, int nMaxPos, BOOL bRedraw = TRUE) + { + T* pT = static_cast<T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::FlatSB_SetScrollRange(pT->m_hWnd, nBar, nMinPos, nMaxPos, bRedraw); + } + + BOOL FlatSB_GetScrollInfo(int nBar, LPSCROLLINFO lpScrollInfo) const + { + const T* pT = static_cast<const T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::FlatSB_GetScrollInfo(pT->m_hWnd, nBar, lpScrollInfo); + } + + int FlatSB_SetScrollInfo(int nBar, LPSCROLLINFO lpScrollInfo, BOOL bRedraw = TRUE) + { + T* pT = static_cast<T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::FlatSB_SetScrollInfo(pT->m_hWnd, nBar, lpScrollInfo, bRedraw); + } + +// Operations + BOOL FlatSB_ShowScrollBar(UINT nBar, BOOL bShow = TRUE) + { + T* pT = static_cast<T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::FlatSB_ShowScrollBar(pT->m_hWnd, nBar, bShow); + } + + BOOL FlatSB_EnableScrollBar(UINT uSBFlags, UINT uArrowFlags = ESB_ENABLE_BOTH) + { + T* pT = static_cast<T*>(this); + ATLASSERT(::IsWindow(pT->m_hWnd)); + return ::FlatSB_EnableScrollBar(pT->m_hWnd, uSBFlags, uArrowFlags); + } +}; + +template <class TBase> +class CFlatScrollBarT : public TBase, public CFlatScrollBarImpl<CFlatScrollBarT< TBase > > +{ +public: + CFlatScrollBarT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CFlatScrollBarT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } +}; + +typedef CFlatScrollBarT<ATL::CWindow> CFlatScrollBar; + + +/////////////////////////////////////////////////////////////////////////////// +// CIPAddressCtrl + +template <class TBase> +class CIPAddressCtrlT : public TBase +{ +public: +// 
Constructors + CIPAddressCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CIPAddressCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Atteributes + static LPCTSTR GetWndClassName() + { + return WC_IPADDRESS; + } + + BOOL IsBlank() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, IPM_ISBLANK, 0, 0L); + } + + int GetAddress(LPDWORD lpdwAddress) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, IPM_GETADDRESS, 0, (LPARAM)lpdwAddress); + } + + void SetAddress(DWORD dwAddress) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, IPM_SETADDRESS, 0, dwAddress); + } + + void ClearAddress() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, IPM_CLEARADDRESS, 0, 0L); + } + + void SetRange(int nField, WORD wRange) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, IPM_SETRANGE, nField, wRange); + } + + void SetRange(int nField, BYTE nMin, BYTE nMax) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, IPM_SETRANGE, nField, MAKEIPRANGE(nMin, nMax)); + } + + void SetFocus(int nField) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, IPM_SETFOCUS, nField, 0L); + } +}; + +typedef CIPAddressCtrlT<ATL::CWindow> CIPAddressCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CPagerCtrl + +template <class TBase> +class CPagerCtrlT : public TBase +{ +public: +// Constructors + CPagerCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CPagerCtrlT< TBase >& operator =(HWND hWnd) + { + 
this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { + return WC_PAGESCROLLER; + } + + int GetButtonSize() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, PGM_GETBUTTONSIZE, 0, 0L); + } + + int SetButtonSize(int nButtonSize) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, PGM_SETBUTTONSIZE, 0, nButtonSize); + } + + DWORD GetButtonState(int nButton) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT((nButton == PGB_TOPORLEFT) || (nButton == PGB_BOTTOMORRIGHT)); + return (DWORD)::SendMessage(this->m_hWnd, PGM_GETBUTTONSTATE, 0, nButton); + } + + COLORREF GetBkColor() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, PGM_GETBKCOLOR, 0, 0L); + } + + COLORREF SetBkColor(COLORREF clrBk) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (COLORREF)::SendMessage(this->m_hWnd, PGM_SETBKCOLOR, 0, (LPARAM)clrBk); + } + + int GetBorder() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, PGM_GETBORDER, 0, 0L); + } + + int SetBorder(int nBorderSize) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, PGM_SETBORDER, 0, nBorderSize); + } + + int GetPos() const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, PGM_GETPOS, 0, 0L); + } + + int SetPos(int nPos) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, PGM_SETPOS, 0, nPos); + } + +// Operations + void SetChild(HWND hWndChild) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + 
::SendMessage(this->m_hWnd, PGM_SETCHILD, 0, (LPARAM)hWndChild); + } + + void ForwardMouse(BOOL bForward = TRUE) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, PGM_FORWARDMOUSE, bForward, 0L); + } + + void RecalcSize() + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ::SendMessage(this->m_hWnd, PGM_RECALCSIZE, 0, 0L); + } + + void GetDropTarget(IDropTarget** ppDropTarget) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + ATLASSERT(ppDropTarget != NULL); + ::SendMessage(this->m_hWnd, PGM_GETDROPTARGET, 0, (LPARAM)ppDropTarget); + } +}; + +typedef CPagerCtrlT<ATL::CWindow> CPagerCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CLinkCtrl - Windows SYSLINK control + +template <class TBase> +class CLinkCtrlT : public TBase +{ +public: +// Constructors + CLinkCtrlT(HWND hWnd = NULL) : TBase(hWnd) + { } + + CLinkCtrlT< TBase >& operator =(HWND hWnd) + { + this->m_hWnd = hWnd; + return *this; + } + + HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL, + DWORD dwStyle = 0, DWORD dwExStyle = 0, + ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL) + { + return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam); + } + +// Attributes + static LPCTSTR GetWndClassName() + { +#ifdef _UNICODE + return WC_LINK; +#else // !_UNICODE + return "SysLink"; +#endif // !_UNICODE + } + + int GetIdealHeight(int cxMaxWidth = 0) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LM_GETIDEALHEIGHT, cxMaxWidth, 0L); + } + + BOOL GetItem(PLITEM pLItem) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LM_GETITEM, 0, (LPARAM)pLItem); + } + + BOOL SetItem(PLITEM pLItem) + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LM_SETITEM, 0, (LPARAM)pLItem); + } + + // Vista only + int GetIdealSize(SIZE& size, int 
cxMaxWidth = 0) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (int)::SendMessage(this->m_hWnd, LM_GETIDEALSIZE, cxMaxWidth, (LPARAM)&size); + } + +// Operations + BOOL HitTest(PLHITTESTINFO pLHitTestInfo) const + { + ATLASSERT(::IsWindow(this->m_hWnd)); + return (BOOL)::SendMessage(this->m_hWnd, LM_HITTEST, 0, (LPARAM)pLHitTestInfo); + } +}; + +typedef CLinkCtrlT<ATL::CWindow> CLinkCtrl; + + +/////////////////////////////////////////////////////////////////////////////// +// CCustomDraw - MI class for custom-draw support + +template <class T> +class CCustomDraw +{ +public: +// Message map and handlers + BEGIN_MSG_MAP(CCustomDraw< T >) + NOTIFY_CODE_HANDLER(NM_CUSTOMDRAW, OnCustomDraw) + ALT_MSG_MAP(1) + REFLECTED_NOTIFY_CODE_HANDLER(NM_CUSTOMDRAW, OnCustomDraw) + END_MSG_MAP() + +// message handler + LRESULT OnCustomDraw(int idCtrl, LPNMHDR pnmh, BOOL& bHandled) + { + T* pT = static_cast<T*>(this); + pT->SetMsgHandled(TRUE); + LPNMCUSTOMDRAW lpNMCustomDraw = (LPNMCUSTOMDRAW)pnmh; + DWORD dwRet = 0; + switch(lpNMCustomDraw->dwDrawStage) + { + case CDDS_PREPAINT: + dwRet = pT->OnPrePaint(idCtrl, lpNMCustomDraw); + break; + case CDDS_POSTPAINT: + dwRet = pT->OnPostPaint(idCtrl, lpNMCustomDraw); + break; + case CDDS_PREERASE: + dwRet = pT->OnPreErase(idCtrl, lpNMCustomDraw); + break; + case CDDS_POSTERASE: + dwRet = pT->OnPostErase(idCtrl, lpNMCustomDraw); + break; + case CDDS_ITEMPREPAINT: + dwRet = pT->OnItemPrePaint(idCtrl, lpNMCustomDraw); + break; + case CDDS_ITEMPOSTPAINT: + dwRet = pT->OnItemPostPaint(idCtrl, lpNMCustomDraw); + break; + case CDDS_ITEMPREERASE: + dwRet = pT->OnItemPreErase(idCtrl, lpNMCustomDraw); + break; + case CDDS_ITEMPOSTERASE: + dwRet = pT->OnItemPostErase(idCtrl, lpNMCustomDraw); + break; + case (CDDS_ITEMPREPAINT | CDDS_SUBITEM): + dwRet = pT->OnSubItemPrePaint(idCtrl, lpNMCustomDraw); + break; + default: + pT->SetMsgHandled(FALSE); + break; + } + bHandled = pT->IsMsgHandled(); + return dwRet; + } + +// Overrideables + DWORD 
OnPrePaint(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/) + { + return CDRF_DODEFAULT; + } + + DWORD OnPostPaint(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/) + { + return CDRF_DODEFAULT; + } + + DWORD OnPreErase(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/) + { + return CDRF_DODEFAULT; + } + + DWORD OnPostErase(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/) + { + return CDRF_DODEFAULT; + } + + DWORD OnItemPrePaint(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/) + { + return CDRF_DODEFAULT; + } + + DWORD OnItemPostPaint(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/) + { + return CDRF_DODEFAULT; + } + + DWORD OnItemPreErase(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/) + { + return CDRF_DODEFAULT; + } + + DWORD OnItemPostErase(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/) + { + return CDRF_DODEFAULT; + } + + DWORD OnSubItemPrePaint(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/) + { + return CDRF_DODEFAULT; + } +}; + +} // namespace WTL + +#endif // __ATLCTRLS_H__ diff --git a/Examples/WhisperDesktop/Utils/WTL/atlddx.h b/Examples/WhisperDesktop/Utils/WTL/atlddx.h new file mode 100644 index 0000000..7723b99 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/WTL/atlddx.h @@ -0,0 +1,667 @@ +// Windows Template Library - WTL version 10.0 +// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved. +// +// This file is a part of the Windows Template Library. +// The use and distribution terms for this software are covered by the +// Microsoft Public License (http://opensource.org/licenses/MS-PL) +// which can be found in the file MS-PL.txt at the root folder. 
+ +#ifndef __ATLDDX_H__ +#define __ATLDDX_H__ + +#pragma once + +#ifndef __ATLAPP_H__ + #error atlddx.h requires atlapp.h to be included first +#endif + +#include <float.h> + + +/////////////////////////////////////////////////////////////////////////////// +// Classes in this file: +// +// CWinDataExchange<T> + + +namespace WTL +{ + +// Constants +#define DDX_LOAD FALSE +#define DDX_SAVE TRUE + +// DDX map macros +#define BEGIN_DDX_MAP(thisClass) \ + BOOL DoDataExchange(BOOL bSaveAndValidate = FALSE, UINT nCtlID = (UINT)-1) \ + { \ + (bSaveAndValidate); \ + (nCtlID); + +#define DDX_TEXT(nID, var) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + { \ + if(!DDX_Text(nID, var, sizeof(var), bSaveAndValidate)) \ + return FALSE; \ + } + +#define DDX_TEXT_LEN(nID, var, len) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + { \ + if(!DDX_Text(nID, var, sizeof(var), bSaveAndValidate, TRUE, len)) \ + return FALSE; \ + } + +#define DDX_INT(nID, var) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + { \ + if(!DDX_Int(nID, var, TRUE, bSaveAndValidate)) \ + return FALSE; \ + } + +#define DDX_INT_RANGE(nID, var, min, max) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + { \ + if(!DDX_Int(nID, var, TRUE, bSaveAndValidate, TRUE, min, max)) \ + return FALSE; \ + } + +#define DDX_UINT(nID, var) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + { \ + if(!DDX_Int(nID, var, FALSE, bSaveAndValidate)) \ + return FALSE; \ + } + +#define DDX_UINT_RANGE(nID, var, min, max) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + { \ + if(!DDX_Int(nID, var, FALSE, bSaveAndValidate, TRUE, min, max)) \ + return FALSE; \ + } + +#define DDX_FLOAT(nID, var) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + { \ + if(!DDX_Float(nID, var, bSaveAndValidate)) \ + return FALSE; \ + } + +#define DDX_FLOAT_RANGE(nID, var, min, max) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + { \ + if(!DDX_Float(nID, var, bSaveAndValidate, TRUE, min, max)) \ + return FALSE; \ + } +#define DDX_FLOAT_P(nID, 
var, precision) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + { \ + if(!DDX_Float(nID, var, bSaveAndValidate, FALSE, 0, 0, precision)) \ + return FALSE; \ + } + +#define DDX_FLOAT_P_RANGE(nID, var, min, max, precision) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + { \ + if(!DDX_Float(nID, var, bSaveAndValidate, TRUE, min, max, precision)) \ + return FALSE; \ + } + +#define DDX_CONTROL(nID, obj) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + DDX_Control(nID, obj, bSaveAndValidate); + +#define DDX_CONTROL_HANDLE(nID, obj) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + DDX_Control_Handle(nID, obj, bSaveAndValidate); + +#define DDX_CHECK(nID, var) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + DDX_Check(nID, var, bSaveAndValidate); + +#define DDX_RADIO(nID, var) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + DDX_Radio(nID, var, bSaveAndValidate); + +#define END_DDX_MAP() \ + return TRUE; \ + } + +// DDX support for Tab, Combo, ListBox and ListView selection index +// Note: Specialized versions require atlctrls.h to be included first + +#define DDX_INDEX(CtrlClass, nID, var) \ + if((nCtlID == (UINT)-1) || (nCtlID == nID)) \ + DDX_Index<CtrlClass>(nID, var, bSaveAndValidate); + +#ifdef __ATLCTRLS_H__ + #define DDX_TAB_INDEX(nID, var) DDX_INDEX(WTL::CTabCtrl, nID, var) + #define DDX_COMBO_INDEX(nID, var) DDX_INDEX(WTL::CComboBox, nID, var) + #define DDX_LISTBOX_INDEX(nID, var) DDX_INDEX(WTL::CListBox, nID, var) + #define DDX_LISTVIEW_INDEX(nID, var) DDX_INDEX(WTL::CListViewCtrl, nID, var) +#endif // __ATLCTRLS_H__ + + +/////////////////////////////////////////////////////////////////////////////// +// CWinDataExchange - provides support for DDX + +template <class T> +class CWinDataExchange +{ +public: +// Data exchange method - override in your derived class + BOOL DoDataExchange(BOOL /*bSaveAndValidate*/ = FALSE, UINT /*nCtlID*/ = (UINT)-1) + { + // this one should never be called, override it in + // your derived class by implementing DDX 
map + ATLASSERT(FALSE); + return FALSE; + } + +// Helpers for validation error reporting + enum _XDataType + { + ddxDataNull = 0, + ddxDataText = 1, + ddxDataInt = 2, + ddxDataFloat = 3, + ddxDataDouble = 4 + }; + + struct _XTextData + { + int nLength; + int nMaxLength; + }; + + struct _XIntData + { + long nVal; + long nMin; + long nMax; + }; + + struct _XFloatData + { + double nVal; + double nMin; + double nMax; + }; + + struct _XData + { + _XDataType nDataType; + union + { + _XTextData textData; + _XIntData intData; + _XFloatData floatData; + }; + }; + +// Text exchange + BOOL DDX_Text(UINT nID, LPTSTR lpstrText, int cbSize, BOOL bSave, BOOL bValidate = FALSE, int nLength = 0) + { + T* pT = static_cast<T*>(this); + BOOL bSuccess = TRUE; + + if(bSave) + { + HWND hWndCtrl = pT->GetDlgItem(nID); + int nRetLen = ::GetWindowText(hWndCtrl, lpstrText, cbSize / sizeof(TCHAR)); + if(nRetLen < ::GetWindowTextLength(hWndCtrl)) + bSuccess = FALSE; + } + else + { + ATLASSERT(!bValidate || (lstrlen(lpstrText) <= nLength)); + bSuccess = pT->SetDlgItemText(nID, lpstrText); + } + + if(!bSuccess) + { + pT->OnDataExchangeError(nID, bSave); + } + else if(bSave && bValidate) // validation + { + ATLASSERT(nLength > 0); + if(lstrlen(lpstrText) > nLength) + { + _XData data = { ddxDataText }; + data.textData.nLength = lstrlen(lpstrText); + data.textData.nMaxLength = nLength; + pT->OnDataValidateError(nID, bSave, data); + bSuccess = FALSE; + } + } + return bSuccess; + } + + BOOL DDX_Text(UINT nID, BSTR& bstrText, int /*cbSize*/, BOOL bSave, BOOL bValidate = FALSE, int nLength = 0) + { + T* pT = static_cast<T*>(this); + BOOL bSuccess = TRUE; + + if(bSave) + { + bSuccess = pT->GetDlgItemText(nID, bstrText); + } + else + { + USES_CONVERSION; + LPTSTR lpstrText = OLE2T(bstrText); + ATLASSERT(!bValidate || (lstrlen(lpstrText) <= nLength)); + bSuccess = pT->SetDlgItemText(nID, lpstrText); + } + + if(!bSuccess) + { + pT->OnDataExchangeError(nID, bSave); + } + else if(bSave && bValidate) // 
validation + { + ATLASSERT(nLength > 0); + if((int)::SysStringLen(bstrText) > nLength) + { + _XData data = { ddxDataText }; + data.textData.nLength = (int)::SysStringLen(bstrText); + data.textData.nMaxLength = nLength; + pT->OnDataValidateError(nID, bSave, data); + bSuccess = FALSE; + } + } + return bSuccess; + } + + BOOL DDX_Text(UINT nID, ATL::CComBSTR& bstrText, int /*cbSize*/, BOOL bSave, BOOL bValidate = FALSE, int nLength = 0) + { + T* pT = static_cast<T*>(this); + BOOL bSuccess = TRUE; + + if(bSave) + { + bSuccess = pT->GetDlgItemText(nID, (BSTR&)bstrText); + } + else + { + USES_CONVERSION; + LPTSTR lpstrText = OLE2T(bstrText); + ATLASSERT(!bValidate || (lstrlen(lpstrText) <= nLength)); + bSuccess = pT->SetDlgItemText(nID, lpstrText); + } + + if(!bSuccess) + { + pT->OnDataExchangeError(nID, bSave); + } + else if(bSave && bValidate) // validation + { + ATLASSERT(nLength > 0); + if((int)bstrText.Length() > nLength) + { + _XData data = { ddxDataText }; + data.textData.nLength = (int)bstrText.Length(); + data.textData.nMaxLength = nLength; + pT->OnDataValidateError(nID, bSave, data); + bSuccess = FALSE; + } + } + return bSuccess; + } + +#ifdef __ATLSTR_H__ + BOOL DDX_Text(UINT nID, ATL::CString& strText, int /*cbSize*/, BOOL bSave, BOOL bValidate = FALSE, int nLength = 0) + { + T* pT = static_cast<T*>(this); + BOOL bSuccess = TRUE; + + if(bSave) + { + HWND hWndCtrl = pT->GetDlgItem(nID); + int nLen = ::GetWindowTextLength(hWndCtrl); + int nRetLen = -1; + LPTSTR lpstr = strText.GetBufferSetLength(nLen); + if(lpstr != NULL) + { + nRetLen = ::GetWindowText(hWndCtrl, lpstr, nLen + 1); + strText.ReleaseBuffer(); + } + if(nRetLen < nLen) + bSuccess = FALSE; + } + else + { + bSuccess = pT->SetDlgItemText(nID, strText); + } + + if(!bSuccess) + { + pT->OnDataExchangeError(nID, bSave); + } + else if(bSave && bValidate) // validation + { + ATLASSERT(nLength > 0); + if(strText.GetLength() > nLength) + { + _XData data = { ddxDataText }; + data.textData.nLength = 
strText.GetLength(); + data.textData.nMaxLength = nLength; + pT->OnDataValidateError(nID, bSave, data); + bSuccess = FALSE; + } + } + return bSuccess; + } +#endif // __ATLSTR_H__ + +// Numeric exchange + template <class Type> + BOOL DDX_Int(UINT nID, Type& nVal, BOOL bSigned, BOOL bSave, BOOL bValidate = FALSE, Type nMin = 0, Type nMax = 0) + { + T* pT = static_cast<T*>(this); + BOOL bSuccess = TRUE; + + if(bSave) + { + nVal = (Type)pT->GetDlgItemInt(nID, &bSuccess, bSigned); + } + else + { + ATLASSERT(!bValidate || ((nVal >= nMin) && (nVal <= nMax))); + bSuccess = pT->SetDlgItemInt(nID, nVal, bSigned); + } + + if(!bSuccess) + { + pT->OnDataExchangeError(nID, bSave); + } + else if(bSave && bValidate) // validation + { + ATLASSERT(nMin != nMax); + if((nVal < nMin) || (nVal > nMax)) + { + _XData data = { ddxDataInt }; + data.intData.nVal = (long)nVal; + data.intData.nMin = (long)nMin; + data.intData.nMax = (long)nMax; + pT->OnDataValidateError(nID, bSave, data); + bSuccess = FALSE; + } + } + return bSuccess; + } + +// Float exchange + static BOOL _AtlSimpleFloatParse(LPCTSTR lpszText, double& d) + { + ATLASSERT(lpszText != NULL); + while ((*lpszText == _T(' ')) || (*lpszText == _T('\t'))) + lpszText++; + + TCHAR chFirst = lpszText[0]; + d = _tcstod(lpszText, (LPTSTR*)&lpszText); + if ((d == 0.0) && (chFirst != _T('0'))) + return FALSE; // could not convert + while ((*lpszText == _T(' ')) || (*lpszText == _T('\t'))) + lpszText++; + + if (*lpszText != _T('\0')) + return FALSE; // not terminated properly + + return TRUE; + } + + BOOL DDX_Float(UINT nID, float& nVal, BOOL bSave, BOOL bValidate = FALSE, float nMin = 0.F, float nMax = 0.F, int nPrecision = FLT_DIG) + { + T* pT = static_cast<T*>(this); + BOOL bSuccess = TRUE; + const int cchBuff = 32; + TCHAR szBuff[cchBuff] = {}; + + if(bSave) + { + pT->GetDlgItemText(nID, szBuff, cchBuff); + double d = 0; + if(_AtlSimpleFloatParse(szBuff, d)) + nVal = (float)d; + else + bSuccess = FALSE; + } + else + { + 
ATLASSERT(!bValidate || ((nVal >= nMin) && (nVal <= nMax))); + _stprintf_s(szBuff, cchBuff, _T("%.*g"), nPrecision, nVal); + bSuccess = pT->SetDlgItemText(nID, szBuff); + } + + if(!bSuccess) + { + pT->OnDataExchangeError(nID, bSave); + } + else if(bSave && bValidate) // validation + { + ATLASSERT(nMin != nMax); + if((nVal < nMin) || (nVal > nMax)) + { + _XData data = { ddxDataFloat }; + data.floatData.nVal = (double)nVal; + data.floatData.nMin = (double)nMin; + data.floatData.nMax = (double)nMax; + pT->OnDataValidateError(nID, bSave, data); + bSuccess = FALSE; + } + } + return bSuccess; + } + + BOOL DDX_Float(UINT nID, double& nVal, BOOL bSave, BOOL bValidate = FALSE, double nMin = 0., double nMax = 0., int nPrecision = DBL_DIG) + { + T* pT = static_cast<T*>(this); + BOOL bSuccess = TRUE; + const int cchBuff = 32; + TCHAR szBuff[cchBuff] = {}; + + if(bSave) + { + pT->GetDlgItemText(nID, szBuff, cchBuff); + double d = 0; + if(_AtlSimpleFloatParse(szBuff, d)) + nVal = d; + else + bSuccess = FALSE; + } + else + { + ATLASSERT(!bValidate || ((nVal >= nMin) && (nVal <= nMax))); + _stprintf_s(szBuff, cchBuff, _T("%.*g"), nPrecision, nVal); + bSuccess = pT->SetDlgItemText(nID, szBuff); + } + + if(!bSuccess) + { + pT->OnDataExchangeError(nID, bSave); + } + else if(bSave && bValidate) // validation + { + ATLASSERT(nMin != nMax); + if((nVal < nMin) || (nVal > nMax)) + { + _XData data = { ddxDataFloat }; + data.floatData.nVal = nVal; + data.floatData.nMin = nMin; + data.floatData.nMax = nMax; + pT->OnDataValidateError(nID, bSave, data); + bSuccess = FALSE; + } + } + return bSuccess; + } + +// Full control subclassing (for CWindowImpl derived controls) + template <class TControl> + void DDX_Control(UINT nID, TControl& ctrl, BOOL bSave) + { + if(!bSave && (ctrl.m_hWnd == NULL)) + { + T* pT = static_cast<T*>(this); + ctrl.SubclassWindow(pT->GetDlgItem(nID)); + } + } + +// Simple control attaching (for HWND wrapper controls) + template <class TControl> + void 
DDX_Control_Handle(UINT nID, TControl& ctrl, BOOL bSave) + { + if(!bSave && (ctrl.m_hWnd == NULL)) + { + T* pT = static_cast<T*>(this); + ctrl = pT->GetDlgItem(nID); + } + } + +// Control state + void DDX_Check(UINT nID, int& nValue, BOOL bSave) + { + T* pT = static_cast<T*>(this); + HWND hWndCtrl = pT->GetDlgItem(nID); + if(bSave) + { + nValue = (int)::SendMessage(hWndCtrl, BM_GETCHECK, 0, 0L); + ATLASSERT((nValue >= 0) && (nValue <= 2)); + } + else + { + if((nValue < 0) || (nValue > 2)) + { + ATLTRACE2(atlTraceUI, 0, _T("ATL: Warning - dialog data checkbox value (%d) out of range.\n"), nValue); + nValue = 0; // default to off + } + ::SendMessage(hWndCtrl, BM_SETCHECK, nValue, 0L); + } + } + + // variant that supports bool (checked/not-checked, no intermediate state) + void DDX_Check(UINT nID, bool& bCheck, BOOL bSave) + { + int nValue = bCheck ? 1 : 0; + DDX_Check(nID, nValue, bSave); + + if(bSave) + { + if(nValue == 2) + ATLTRACE2(atlTraceUI, 0, _T("ATL: Warning - checkbox state (%d) out of supported range.\n"), nValue); + bCheck = (nValue == 1); + } + } + + void DDX_Radio(UINT nID, int& nValue, BOOL bSave) + { + T* pT = static_cast<T*>(this); + HWND hWndCtrl = pT->GetDlgItem(nID); + ATLASSERT(hWndCtrl != NULL); + + // must be first in a group of auto radio buttons + ATLASSERT(::GetWindowLong(hWndCtrl, GWL_STYLE) & WS_GROUP); + ATLASSERT(::SendMessage(hWndCtrl, WM_GETDLGCODE, 0, 0L) & DLGC_RADIOBUTTON); + + if(bSave) + nValue = -1; // value if none found + + // walk all children in group + int nButton = 0; + do + { + if(::SendMessage(hWndCtrl, WM_GETDLGCODE, 0, 0L) & DLGC_RADIOBUTTON) + { + // control in group is a radio button + if(bSave) + { + if(::SendMessage(hWndCtrl, BM_GETCHECK, 0, 0L) != 0) + { + ATLASSERT(nValue == -1); // only set once + nValue = nButton; + } + } + else + { + // select button + ::SendMessage(hWndCtrl, BM_SETCHECK, (nButton == nValue), 0L); + } + nButton++; + } + else + { + ATLTRACE2(atlTraceUI, 0, _T("ATL: Warning - skipping non-radio 
button in group.\n")); + } + hWndCtrl = ::GetWindow(hWndCtrl, GW_HWNDNEXT); + } + while ((hWndCtrl != NULL) && !(GetWindowLong(hWndCtrl, GWL_STYLE) & WS_GROUP)); + } + +// DDX support for Tab, Combo, ListBox and ListView selection index + template <class TCtrl> + INT _getSel(TCtrl& tCtrl) + { + return tCtrl.GetCurSel(); + } + + template <class TCtrl> + void _setSel(TCtrl& tCtrl, INT iSel) + { + if(iSel < 0) + tCtrl.SetCurSel(-1); + else + tCtrl.SetCurSel(iSel); + } + +#ifdef __ATLCTRLS_H__ + // ListViewCtrl specialization + template <> + INT _getSel(WTL::CListViewCtrl& tCtrl) + { + return tCtrl.GetSelectedIndex(); + } + + template <> + void _setSel(WTL::CListViewCtrl& tCtrl, INT iSel) + { + if(iSel < 0) + tCtrl.SelectItem(-1); + else + tCtrl.SelectItem(iSel); + } +#endif // __ATLCTRLS_H__ + + template <class TCtrl> + void DDX_Index(UINT nID, INT& nVal, BOOL bSave) + { + T* pT = static_cast<T*>(this); + TCtrl ctrl(pT->GetDlgItem(nID)); + + if(bSave) + nVal = _getSel(ctrl); + else + _setSel(ctrl, nVal); + } + +// Overrideables + void OnDataExchangeError(UINT nCtrlID, BOOL /*bSave*/) + { + // Override to display an error message + ::MessageBeep((UINT)-1); + T* pT = static_cast<T*>(this); + ::SetFocus(pT->GetDlgItem(nCtrlID)); + } + + void OnDataValidateError(UINT nCtrlID, BOOL /*bSave*/, _XData& /*data*/) + { + // Override to display an error message + ::MessageBeep((UINT)-1); + T* pT = static_cast<T*>(this); + ::SetFocus(pT->GetDlgItem(nCtrlID)); + } +}; + +} // namespace WTL + +#endif // __ATLDDX_H__ diff --git a/Examples/WhisperDesktop/Utils/WTL/atlgdi.h b/Examples/WhisperDesktop/Utils/WTL/atlgdi.h new file mode 100644 index 0000000..14e1518 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/WTL/atlgdi.h @@ -0,0 +1,3445 @@ +// Windows Template Library - WTL version 10.0 +// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved. +// +// This file is a part of the Windows Template Library. 
+// The use and distribution terms for this software are covered by the +// Microsoft Public License (http://opensource.org/licenses/MS-PL) +// which can be found in the file MS-PL.txt at the root folder. + +#ifndef __ATLGDI_H__ +#define __ATLGDI_H__ + +#pragma once + +#ifndef __ATLAPP_H__ + #error atlgdi.h requires atlapp.h to be included first +#endif + + +// protect template members from windowsx.h macros +#ifdef _INC_WINDOWSX + #undef CopyRgn + #undef CreateBrush + #undef CreatePen + #undef SelectBrush + #undef SelectPen + #undef SelectFont + #undef SelectBitmap +#endif // _INC_WINDOWSX + +// required libraries +#pragma comment(lib, "msimg32.lib") +#if !defined(_ATL_NO_OPENGL) + #pragma comment(lib, "opengl32.lib") +#endif + + +/////////////////////////////////////////////////////////////////////////////// +// Classes in this file: +// +// CPenT<t_bManaged> +// CBrushT<t_bManaged> +// CLogFont +// CFontT<t_bManaged> +// CBitmapT<t_bManaged> +// CPaletteT<t_bManaged> +// CRgnT<t_bManaged> +// CDCT<t_bManaged> +// CPaintDC +// CClientDC +// CWindowDC +// CMemoryDC +// CEnhMetaFileInfo +// CEnhMetaFileT<t_bManaged> +// CEnhMetaFileDC + + +namespace WTL +{ + +/////////////////////////////////////////////////////////////////////////////// +// Bitmap resource helpers to extract bitmap information for a bitmap resource + +inline LPBITMAPINFOHEADER AtlGetBitmapResourceInfo(HMODULE hModule, ATL::_U_STRINGorID image) +{ + HRSRC hResource = ::FindResource(hModule, image.m_lpstr, RT_BITMAP); + ATLASSERT(hResource != NULL); + HGLOBAL hGlobal = ::LoadResource(hModule, hResource); + ATLASSERT(hGlobal != NULL); + LPBITMAPINFOHEADER pBitmapInfoHeader = (LPBITMAPINFOHEADER)::LockResource(hGlobal); + ATLASSERT(pBitmapInfoHeader != NULL); + return pBitmapInfoHeader; +} + +inline WORD AtlGetBitmapResourceBitsPerPixel(HMODULE hModule, ATL::_U_STRINGorID image) +{ + LPBITMAPINFOHEADER pBitmapInfoHeader = AtlGetBitmapResourceInfo(hModule, image); + ATLASSERT(pBitmapInfoHeader != 
NULL); + return pBitmapInfoHeader->biBitCount; +} + +inline WORD AtlGetBitmapResourceBitsPerPixel(ATL::_U_STRINGorID image) +{ + return AtlGetBitmapResourceBitsPerPixel(ModuleHelper::GetResourceInstance(), image); +} + +/////////////////////////////////////////////////////////////////////////////// +// 32-bit (alpha channel) bitmap resource helper + +// Note: 32-bit (alpha channel) images work only on Windows XP with Common Controls version 6. +// If you want your app to work on older version of Windows, load non-alpha images if Common +// Controls version is less than 6. + +inline bool AtlIsAlphaBitmapResource(ATL::_U_STRINGorID image) +{ + return (AtlGetBitmapResourceBitsPerPixel(image) == 32); +} + + +/////////////////////////////////////////////////////////////////////////////// +// CPen + +template <bool t_bManaged> +class CPenT +{ +public: +// Data members + HPEN m_hPen; + +// Constructor/destructor/operators + CPenT(HPEN hPen = NULL) : m_hPen(hPen) + { } + + ~CPenT() + { + if(t_bManaged && (m_hPen != NULL)) + DeleteObject(); + } + + CPenT<t_bManaged>& operator =(HPEN hPen) + { + Attach(hPen); + return *this; + } + + void Attach(HPEN hPen) + { + if(t_bManaged && (m_hPen != NULL) && (m_hPen != hPen)) + ::DeleteObject(m_hPen); + m_hPen = hPen; + } + + HPEN Detach() + { + HPEN hPen = m_hPen; + m_hPen = NULL; + return hPen; + } + + operator HPEN() const { return m_hPen; } + + bool IsNull() const { return (m_hPen == NULL); } + +// Create methods + HPEN CreatePen(int nPenStyle, int nWidth, COLORREF crColor) + { + ATLASSERT(m_hPen == NULL); + m_hPen = ::CreatePen(nPenStyle, nWidth, crColor); + return m_hPen; + } + + HPEN CreatePen(int nPenStyle, int nWidth, const LOGBRUSH* pLogBrush, int nStyleCount = 0, const DWORD* lpStyle = NULL) + { + ATLASSERT(m_hPen == NULL); + m_hPen = ::ExtCreatePen(nPenStyle, nWidth, pLogBrush, nStyleCount, lpStyle); + return m_hPen; + } + + HPEN CreatePenIndirect(LPLOGPEN lpLogPen) + { + ATLASSERT(m_hPen == NULL); + m_hPen = 
::CreatePenIndirect(lpLogPen); + return m_hPen; + } + + BOOL DeleteObject() + { + ATLASSERT(m_hPen != NULL); + BOOL bRet = ::DeleteObject(m_hPen); + if(bRet) + m_hPen = NULL; + return bRet; + } + +// Attributes + int GetLogPen(LOGPEN* pLogPen) const + { + ATLASSERT(m_hPen != NULL); + return ::GetObject(m_hPen, sizeof(LOGPEN), pLogPen); + } + + bool GetLogPen(LOGPEN& LogPen) const + { + ATLASSERT(m_hPen != NULL); + return (::GetObject(m_hPen, sizeof(LOGPEN), &LogPen) == sizeof(LOGPEN)); + } + + int GetExtLogPen(EXTLOGPEN* pLogPen, int nSize = sizeof(EXTLOGPEN)) const + { + ATLASSERT(m_hPen != NULL); + return ::GetObject(m_hPen, nSize, pLogPen); + } + + bool GetExtLogPen(EXTLOGPEN& ExtLogPen, int nSize = sizeof(EXTLOGPEN)) const + { + ATLASSERT(m_hPen != NULL); + int nRet = ::GetObject(m_hPen, nSize, &ExtLogPen); + return ((nRet > 0) && (nRet <= nSize)); + } +}; + +typedef CPenT<false> CPenHandle; +typedef CPenT<true> CPen; + + +/////////////////////////////////////////////////////////////////////////////// +// CBrush + +template <bool t_bManaged> +class CBrushT +{ +public: +// Data members + HBRUSH m_hBrush; + +// Constructor/destructor/operators + CBrushT(HBRUSH hBrush = NULL) : m_hBrush(hBrush) + { } + + ~CBrushT() + { + if(t_bManaged && (m_hBrush != NULL)) + DeleteObject(); + } + + CBrushT<t_bManaged>& operator =(HBRUSH hBrush) + { + Attach(hBrush); + return *this; + } + + void Attach(HBRUSH hBrush) + { + if(t_bManaged && (m_hBrush != NULL) && (m_hBrush != hBrush)) + ::DeleteObject(m_hBrush); + m_hBrush = hBrush; + } + + HBRUSH Detach() + { + HBRUSH hBrush = m_hBrush; + m_hBrush = NULL; + return hBrush; + } + + operator HBRUSH() const { return m_hBrush; } + + bool IsNull() const { return (m_hBrush == NULL); } + +// Create methods + HBRUSH CreateSolidBrush(COLORREF crColor) + { + ATLASSERT(m_hBrush == NULL); + m_hBrush = ::CreateSolidBrush(crColor); + return m_hBrush; + } + + HBRUSH CreateHatchBrush(int nIndex, COLORREF crColor) + { + ATLASSERT(m_hBrush == NULL); 
+ m_hBrush = ::CreateHatchBrush(nIndex, crColor); + return m_hBrush; + } + + HBRUSH CreateBrushIndirect(const LOGBRUSH* lpLogBrush) + { + ATLASSERT(m_hBrush == NULL); + m_hBrush = ::CreateBrushIndirect(lpLogBrush); + return m_hBrush; + } + + HBRUSH CreatePatternBrush(HBITMAP hBitmap) + { + ATLASSERT(m_hBrush == NULL); + m_hBrush = ::CreatePatternBrush(hBitmap); + return m_hBrush; + } + + HBRUSH CreateDIBPatternBrush(HGLOBAL hPackedDIB, UINT nUsage) + { + ATLASSERT(hPackedDIB != NULL); + const void* lpPackedDIB = GlobalLock(hPackedDIB); + ATLASSERT(lpPackedDIB != NULL); + m_hBrush = ::CreateDIBPatternBrushPt(lpPackedDIB, nUsage); + GlobalUnlock(hPackedDIB); + return m_hBrush; + } + + HBRUSH CreateDIBPatternBrush(const void* lpPackedDIB, UINT nUsage) + { + ATLASSERT(m_hBrush == NULL); + m_hBrush = ::CreateDIBPatternBrushPt(lpPackedDIB, nUsage); + return m_hBrush; + } + + HBRUSH CreateSysColorBrush(int nIndex) + { + ATLASSERT(m_hBrush == NULL); + m_hBrush = ::GetSysColorBrush(nIndex); + return m_hBrush; + } + + BOOL DeleteObject() + { + ATLASSERT(m_hBrush != NULL); + BOOL bRet = ::DeleteObject(m_hBrush); + if(bRet) + m_hBrush = NULL; + return bRet; + } + +// Attributes + int GetLogBrush(LOGBRUSH* pLogBrush) const + { + ATLASSERT(m_hBrush != NULL); + return ::GetObject(m_hBrush, sizeof(LOGBRUSH), pLogBrush); + } + + bool GetLogBrush(LOGBRUSH& LogBrush) const + { + ATLASSERT(m_hBrush != NULL); + return (::GetObject(m_hBrush, sizeof(LOGBRUSH), &LogBrush) == sizeof(LOGBRUSH)); + } +}; + +typedef CBrushT<false> CBrushHandle; +typedef CBrushT<true> CBrush; + + +/////////////////////////////////////////////////////////////////////////////// +// CFont + +class CLogFont : public LOGFONT +{ +public: + CLogFont() + { + memset(this, 0, sizeof(LOGFONT)); + } + + CLogFont(const LOGFONT& lf) + { + Copy(&lf); + } + + CLogFont(HFONT hFont) + { + ATLASSERT(::GetObjectType(hFont) == OBJ_FONT); + ::GetObject(hFont, sizeof(LOGFONT), (LOGFONT*)this); + } + + HFONT CreateFontIndirect() + { 
+ return ::CreateFontIndirect(this); + } + + void SetBold() + { + lfWeight = FW_BOLD; + } + + bool IsBold() const + { + return (lfWeight >= FW_BOLD); + } + + void MakeBolder(int iScale = 1) + { + lfWeight += FW_BOLD * iScale; + } + + void MakeLarger(int iScale) + { + if(lfHeight > 0) + lfHeight += iScale; + else + lfHeight -= iScale; + } + + void SetHeight(LONG nPointSize, HDC hDC = NULL) + { + HDC hDC1 = (hDC != NULL) ? hDC : ::GetDC(NULL); + // For MM_TEXT mapping mode + lfHeight = -::MulDiv(nPointSize, ::GetDeviceCaps(hDC1, LOGPIXELSY), 72); + if(hDC == NULL) + ::ReleaseDC(NULL, hDC1); + } + + LONG GetHeight(HDC hDC = NULL) const + { + HDC hDC1 = (hDC != NULL) ? hDC : ::GetDC(NULL); + // For MM_TEXT mapping mode + LONG nPointSize = ::MulDiv(-lfHeight, 72, ::GetDeviceCaps(hDC1, LOGPIXELSY)); + if(hDC == NULL) + ::ReleaseDC(NULL, hDC1); + + return nPointSize; + } + + LONG GetDeciPointHeight(HDC hDC = NULL) const + { + HDC hDC1 = (hDC != NULL) ? hDC : ::GetDC(NULL); + POINT ptOrg = { 0, 0 }; + ::DPtoLP(hDC1, &ptOrg, 1); + POINT pt = { 0, 0 }; + pt.y = abs(lfHeight) + ptOrg.y; + ::LPtoDP(hDC1, &pt, 1); + LONG nDeciPoint = ::MulDiv(pt.y, 720, ::GetDeviceCaps(hDC1, LOGPIXELSY)); // 72 points/inch, 10 decipoints/point + if(hDC == NULL) + ::ReleaseDC(NULL, hDC1); + + return nDeciPoint; + } + + void SetHeightFromDeciPoint(LONG nDeciPtHeight, HDC hDC = NULL) + { + HDC hDC1 = (hDC != NULL) ? 
hDC : ::GetDC(NULL); + POINT pt = { 0, 0 }; + pt.y = ::MulDiv(::GetDeviceCaps(hDC1, LOGPIXELSY), nDeciPtHeight, 720); // 72 points/inch, 10 decipoints/point + ::DPtoLP(hDC1, &pt, 1); + POINT ptOrg = { 0, 0 }; + ::DPtoLP(hDC1, &ptOrg, 1); + lfHeight = -abs(pt.y - ptOrg.y); + if(hDC == NULL) + ::ReleaseDC(NULL, hDC1); + } + + void SetCaptionFont() + { + NONCLIENTMETRICS ncm = { RunTimeHelper::SizeOf_NONCLIENTMETRICS() }; + ATLVERIFY(::SystemParametersInfo(SPI_GETNONCLIENTMETRICS, sizeof(ncm), &ncm, 0)); + Copy(&ncm.lfCaptionFont); + } + + void SetMenuFont() + { + NONCLIENTMETRICS ncm = { RunTimeHelper::SizeOf_NONCLIENTMETRICS() }; + ATLVERIFY(::SystemParametersInfo(SPI_GETNONCLIENTMETRICS, sizeof(ncm), &ncm, 0)); + Copy(&ncm.lfMenuFont); + } + + void SetStatusFont() + { + NONCLIENTMETRICS ncm = { RunTimeHelper::SizeOf_NONCLIENTMETRICS() }; + ATLVERIFY(::SystemParametersInfo(SPI_GETNONCLIENTMETRICS, sizeof(ncm), &ncm, 0)); + Copy(&ncm.lfStatusFont); + } + + void SetMessageBoxFont() + { + NONCLIENTMETRICS ncm = { RunTimeHelper::SizeOf_NONCLIENTMETRICS() }; + ATLVERIFY(::SystemParametersInfo(SPI_GETNONCLIENTMETRICS, sizeof(ncm), &ncm, 0)); + Copy(&ncm.lfMessageFont); + } + + void Copy(const LOGFONT* pLogFont) + { + ATLASSERT(pLogFont != NULL); + *(LOGFONT*)this = *pLogFont; + } + + CLogFont& operator =(const CLogFont& src) + { + Copy(&src); + return *this; + } + + CLogFont& operator =(const LOGFONT& src) + { + Copy(&src); + return *this; + } + + CLogFont& operator =(HFONT hFont) + { + ATLASSERT(::GetObjectType(hFont) == OBJ_FONT); + ::GetObject(hFont, sizeof(LOGFONT), (LOGFONT*)this); + return *this; + } + + bool operator ==(const LOGFONT& logfont) const + { + return((logfont.lfHeight == lfHeight) && + (logfont.lfWidth == lfWidth) && + (logfont.lfEscapement == lfEscapement) && + (logfont.lfOrientation == lfOrientation) && + (logfont.lfWeight == lfWeight) && + (logfont.lfItalic == lfItalic) && + (logfont.lfUnderline == lfUnderline) && + (logfont.lfStrikeOut == 
lfStrikeOut) && + (logfont.lfCharSet == lfCharSet) && + (logfont.lfOutPrecision == lfOutPrecision) && + (logfont.lfClipPrecision == lfClipPrecision) && + (logfont.lfQuality == lfQuality) && + (logfont.lfPitchAndFamily == lfPitchAndFamily) && + (lstrcmp(logfont.lfFaceName, lfFaceName) == 0)); + } +}; + + +template <bool t_bManaged> +class CFontT +{ +public: +// Data members + HFONT m_hFont; + +// Constructor/destructor/operators + CFontT(HFONT hFont = NULL) : m_hFont(hFont) + { } + + ~CFontT() + { + if(t_bManaged && (m_hFont != NULL)) + DeleteObject(); + } + + CFontT<t_bManaged>& operator =(HFONT hFont) + { + Attach(hFont); + return *this; + } + + void Attach(HFONT hFont) + { + if(t_bManaged && (m_hFont != NULL) && (m_hFont != hFont)) + ::DeleteObject(m_hFont); + m_hFont = hFont; + } + + HFONT Detach() + { + HFONT hFont = m_hFont; + m_hFont = NULL; + return hFont; + } + + operator HFONT() const { return m_hFont; } + + bool IsNull() const { return (m_hFont == NULL); } + +// Create methods + HFONT CreateFontIndirect(const LOGFONT* lpLogFont) + { + ATLASSERT(m_hFont == NULL); + m_hFont = ::CreateFontIndirect(lpLogFont); + return m_hFont; + } + + HFONT CreateFontIndirectEx(CONST ENUMLOGFONTEXDV* penumlfex) + { + ATLASSERT(m_hFont == NULL); + m_hFont = ::CreateFontIndirectEx(penumlfex); + return m_hFont; + } + + HFONT CreateFont(int nHeight, int nWidth, int nEscapement, + int nOrientation, int nWeight, BYTE bItalic, BYTE bUnderline, + BYTE cStrikeOut, BYTE nCharSet, BYTE nOutPrecision, + BYTE nClipPrecision, BYTE nQuality, BYTE nPitchAndFamily, + LPCTSTR lpszFacename) + { + ATLASSERT(m_hFont == NULL); + m_hFont = ::CreateFont(nHeight, nWidth, nEscapement, + nOrientation, nWeight, bItalic, bUnderline, cStrikeOut, + nCharSet, nOutPrecision, nClipPrecision, nQuality, + nPitchAndFamily, lpszFacename); + return m_hFont; + } + + HFONT CreatePointFont(int nPointSize, LPCTSTR lpszFaceName, HDC hDC = NULL, bool bBold = false, bool bItalic = false) + { + LOGFONT logFont = {}; + 
logFont.lfCharSet = DEFAULT_CHARSET; + logFont.lfHeight = nPointSize; + ATL::Checked::tcsncpy_s(logFont.lfFaceName, _countof(logFont.lfFaceName), lpszFaceName, _TRUNCATE); + + if(bBold) + logFont.lfWeight = FW_BOLD; + if(bItalic) + logFont.lfItalic = (BYTE)TRUE; + + return CreatePointFontIndirect(&logFont, hDC); + } + + HFONT CreatePointFontIndirect(const LOGFONT* lpLogFont, HDC hDC = NULL) + { + HDC hDC1 = (hDC != NULL) ? hDC : ::GetDC(NULL); + + // convert nPointSize to logical units based on hDC + LOGFONT logFont = *lpLogFont; + POINT pt = { 0, 0 }; + pt.y = ::MulDiv(::GetDeviceCaps(hDC1, LOGPIXELSY), logFont.lfHeight, 720); // 72 points/inch, 10 decipoints/point + ::DPtoLP(hDC1, &pt, 1); + POINT ptOrg = { 0, 0 }; + ::DPtoLP(hDC1, &ptOrg, 1); + logFont.lfHeight = -abs(pt.y - ptOrg.y); + + if(hDC == NULL) + ::ReleaseDC(NULL, hDC1); + + return CreateFontIndirect(&logFont); + } + + BOOL DeleteObject() + { + ATLASSERT(m_hFont != NULL); + BOOL bRet = ::DeleteObject(m_hFont); + if(bRet) + m_hFont = NULL; + return bRet; + } + +// Attributes + int GetLogFont(LOGFONT* pLogFont) const + { + ATLASSERT(m_hFont != NULL); + return ::GetObject(m_hFont, sizeof(LOGFONT), pLogFont); + } + + bool GetLogFont(LOGFONT& LogFont) const + { + ATLASSERT(m_hFont != NULL); + return (::GetObject(m_hFont, sizeof(LOGFONT), &LogFont) == sizeof(LOGFONT)); + } +}; + +typedef CFontT<false> CFontHandle; +typedef CFontT<true> CFont; + + +/////////////////////////////////////////////////////////////////////////////// +// CBitmap + +template <bool t_bManaged> +class CBitmapT +{ +public: +// Data members + HBITMAP m_hBitmap; + +// Constructor/destructor/operators + CBitmapT(HBITMAP hBitmap = NULL) : m_hBitmap(hBitmap) + { } + + ~CBitmapT() + { + if(t_bManaged && (m_hBitmap != NULL)) + DeleteObject(); + } + + CBitmapT<t_bManaged>& operator =(HBITMAP hBitmap) + { + Attach(hBitmap); + return *this; + } + + void Attach(HBITMAP hBitmap) + { + if(t_bManaged && (m_hBitmap != NULL) && (m_hBitmap != hBitmap)) 
+ ::DeleteObject(m_hBitmap); + m_hBitmap = hBitmap; + } + + HBITMAP Detach() + { + HBITMAP hBitmap = m_hBitmap; + m_hBitmap = NULL; + return hBitmap; + } + + operator HBITMAP() const { return m_hBitmap; } + + bool IsNull() const { return (m_hBitmap == NULL); } + +// Create and load methods + HBITMAP LoadBitmap(ATL::_U_STRINGorID bitmap) + { + ATLASSERT(m_hBitmap == NULL); + m_hBitmap = ::LoadBitmap(ModuleHelper::GetResourceInstance(), bitmap.m_lpstr); + return m_hBitmap; + } + + HBITMAP LoadOEMBitmap(UINT nIDBitmap) // for OBM_/OCR_/OIC_ + { + ATLASSERT(m_hBitmap == NULL); + m_hBitmap = ::LoadBitmap(NULL, MAKEINTRESOURCE(nIDBitmap)); + return m_hBitmap; + } + + HBITMAP LoadMappedBitmap(UINT nIDBitmap, UINT nFlags = 0, LPCOLORMAP lpColorMap = NULL, int nMapSize = 0) + { + ATLASSERT(m_hBitmap == NULL); + m_hBitmap = ::CreateMappedBitmap(ModuleHelper::GetResourceInstance(), nIDBitmap, (WORD)nFlags, lpColorMap, nMapSize); + return m_hBitmap; + } + + HBITMAP CreateBitmap(int nWidth, int nHeight, UINT nPlanes, UINT nBitsPerPixel, const void* lpBits) + { + ATLASSERT(m_hBitmap == NULL); + m_hBitmap = ::CreateBitmap(nWidth, nHeight, nPlanes, nBitsPerPixel, lpBits); + return m_hBitmap; + } + + HBITMAP CreateBitmapIndirect(LPBITMAP lpBitmap) + { + ATLASSERT(m_hBitmap == NULL); + m_hBitmap = ::CreateBitmapIndirect(lpBitmap); + return m_hBitmap; + } + + HBITMAP CreateCompatibleBitmap(HDC hDC, int nWidth, int nHeight) + { + ATLASSERT(m_hBitmap == NULL); + m_hBitmap = ::CreateCompatibleBitmap(hDC, nWidth, nHeight); + return m_hBitmap; + } + + HBITMAP CreateDiscardableBitmap(HDC hDC, int nWidth, int nHeight) + { + ATLASSERT(m_hBitmap == NULL); + m_hBitmap = ::CreateDiscardableBitmap(hDC, nWidth, nHeight); + return m_hBitmap; + } + + BOOL DeleteObject() + { + ATLASSERT(m_hBitmap != NULL); + BOOL bRet = ::DeleteObject(m_hBitmap); + if(bRet) + m_hBitmap = NULL; + return bRet; + } + +// Attributes + int GetBitmap(BITMAP* pBitMap) const + { + ATLASSERT(m_hBitmap != NULL); + return 
::GetObject(m_hBitmap, sizeof(BITMAP), pBitMap); + } + + bool GetBitmap(BITMAP& bm) const + { + ATLASSERT(m_hBitmap != NULL); + return (::GetObject(m_hBitmap, sizeof(BITMAP), &bm) == sizeof(BITMAP)); + } + + bool GetSize(SIZE& size) const + { + ATLASSERT(m_hBitmap != NULL); + BITMAP bm = {}; + if(!GetBitmap(&bm)) + return false; + size.cx = bm.bmWidth; + size.cy = bm.bmHeight; + return true; + } + + DWORD GetBitmapBits(DWORD dwCount, LPVOID lpBits) const + { + ATLASSERT(m_hBitmap != NULL); + return ::GetBitmapBits(m_hBitmap, dwCount, lpBits); + } + + DWORD SetBitmapBits(DWORD dwCount, const void* lpBits) + { + ATLASSERT(m_hBitmap != NULL); + return ::SetBitmapBits(m_hBitmap, dwCount, lpBits); + } + + BOOL GetBitmapDimension(LPSIZE lpSize) const + { + ATLASSERT(m_hBitmap != NULL); + return ::GetBitmapDimensionEx(m_hBitmap, lpSize); + } + + BOOL SetBitmapDimension(int nWidth, int nHeight, LPSIZE lpSize = NULL) + { + ATLASSERT(m_hBitmap != NULL); + return ::SetBitmapDimensionEx(m_hBitmap, nWidth, nHeight, lpSize); + } + +// DIB support + HBITMAP CreateDIBitmap(HDC hDC, CONST BITMAPINFOHEADER* lpbmih, DWORD dwInit, CONST VOID* lpbInit, CONST BITMAPINFO* lpbmi, UINT uColorUse) + { + ATLASSERT(m_hBitmap == NULL); + m_hBitmap = ::CreateDIBitmap(hDC, lpbmih, dwInit, lpbInit, lpbmi, uColorUse); + return m_hBitmap; + } + + HBITMAP CreateDIBSection(HDC hDC, CONST BITMAPINFO* lpbmi, UINT uColorUse, VOID** ppvBits, HANDLE hSection, DWORD dwOffset) + { + ATLASSERT(m_hBitmap == NULL); + m_hBitmap = ::CreateDIBSection(hDC, lpbmi, uColorUse, ppvBits, hSection, dwOffset); + return m_hBitmap; + } + + int GetDIBits(HDC hDC, UINT uStartScan, UINT cScanLines, LPVOID lpvBits, LPBITMAPINFO lpbmi, UINT uColorUse) const + { + ATLASSERT(m_hBitmap != NULL); + return ::GetDIBits(hDC, m_hBitmap, uStartScan, cScanLines, lpvBits, lpbmi, uColorUse); + } + + int SetDIBits(HDC hDC, UINT uStartScan, UINT cScanLines, CONST VOID* lpvBits, CONST BITMAPINFO* lpbmi, UINT uColorUse) + { + 
ATLASSERT(m_hBitmap != NULL); + return ::SetDIBits(hDC, m_hBitmap, uStartScan, cScanLines, lpvBits, lpbmi, uColorUse); + } +}; + +typedef CBitmapT<false> CBitmapHandle; +typedef CBitmapT<true> CBitmap; + + +/////////////////////////////////////////////////////////////////////////////// +// CPalette + +template <bool t_bManaged> +class CPaletteT +{ +public: +// Data members + HPALETTE m_hPalette; + +// Constructor/destructor/operators + CPaletteT(HPALETTE hPalette = NULL) : m_hPalette(hPalette) + { } + + ~CPaletteT() + { + if(t_bManaged && (m_hPalette != NULL)) + DeleteObject(); + } + + CPaletteT<t_bManaged>& operator =(HPALETTE hPalette) + { + Attach(hPalette); + return *this; + } + + void Attach(HPALETTE hPalette) + { + if(t_bManaged && (m_hPalette != NULL) && (m_hPalette != hPalette)) + ::DeleteObject(m_hPalette); + m_hPalette = hPalette; + } + + HPALETTE Detach() + { + HPALETTE hPalette = m_hPalette; + m_hPalette = NULL; + return hPalette; + } + + operator HPALETTE() const { return m_hPalette; } + + bool IsNull() const { return (m_hPalette == NULL); } + +// Create methods + HPALETTE CreatePalette(LPLOGPALETTE lpLogPalette) + { + ATLASSERT(m_hPalette == NULL); + m_hPalette = ::CreatePalette(lpLogPalette); + return m_hPalette; + } + + HPALETTE CreateHalftonePalette(HDC hDC) + { + ATLASSERT(m_hPalette == NULL); + ATLASSERT(hDC != NULL); + m_hPalette = ::CreateHalftonePalette(hDC); + return m_hPalette; + } + + BOOL DeleteObject() + { + ATLASSERT(m_hPalette != NULL); + BOOL bRet = ::DeleteObject(m_hPalette); + if(bRet) + m_hPalette = NULL; + return bRet; + } + +// Attributes + int GetEntryCount() const + { + ATLASSERT(m_hPalette != NULL); + WORD nEntries = 0; + ::GetObject(m_hPalette, sizeof(WORD), &nEntries); + return (int)nEntries; + } + + UINT GetPaletteEntries(UINT nStartIndex, UINT nNumEntries, LPPALETTEENTRY lpPaletteColors) const + { + ATLASSERT(m_hPalette != NULL); + return ::GetPaletteEntries(m_hPalette, nStartIndex, nNumEntries, lpPaletteColors); + } + + 
UINT SetPaletteEntries(UINT nStartIndex, UINT nNumEntries, LPPALETTEENTRY lpPaletteColors) + { + ATLASSERT(m_hPalette != NULL); + return ::SetPaletteEntries(m_hPalette, nStartIndex, nNumEntries, lpPaletteColors); + } + +// Operations + void AnimatePalette(UINT nStartIndex, UINT nNumEntries, LPPALETTEENTRY lpPaletteColors) + { + ATLASSERT(m_hPalette != NULL); + ::AnimatePalette(m_hPalette, nStartIndex, nNumEntries, lpPaletteColors); + } + + BOOL ResizePalette(UINT nNumEntries) + { + ATLASSERT(m_hPalette != NULL); + return ::ResizePalette(m_hPalette, nNumEntries); + } + + UINT GetNearestPaletteIndex(COLORREF crColor) const + { + ATLASSERT(m_hPalette != NULL); + return ::GetNearestPaletteIndex(m_hPalette, crColor); + } +}; + +typedef CPaletteT<false> CPaletteHandle; +typedef CPaletteT<true> CPalette; + + +/////////////////////////////////////////////////////////////////////////////// +// CRgn + +template <bool t_bManaged> +class CRgnT +{ +public: +// Data members + HRGN m_hRgn; + +// Constructor/destructor/operators + CRgnT(HRGN hRgn = NULL) : m_hRgn(hRgn) + { } + + ~CRgnT() + { + if(t_bManaged && (m_hRgn != NULL)) + DeleteObject(); + } + + CRgnT<t_bManaged>& operator =(HRGN hRgn) + { + Attach(hRgn); + return *this; + } + + void Attach(HRGN hRgn) + { + if(t_bManaged && (m_hRgn != NULL) && (m_hRgn != hRgn)) + ::DeleteObject(m_hRgn); + m_hRgn = hRgn; + } + + HRGN Detach() + { + HRGN hRgn = m_hRgn; + m_hRgn = NULL; + return hRgn; + } + + operator HRGN() const { return m_hRgn; } + + bool IsNull() const { return (m_hRgn == NULL); } + +// Create methods + HRGN CreateRectRgn(int x1, int y1, int x2, int y2) + { + ATLASSERT(m_hRgn == NULL); + m_hRgn = ::CreateRectRgn(x1, y1, x2, y2); + return m_hRgn; + } + + HRGN CreateRectRgnIndirect(LPCRECT lpRect) + { + ATLASSERT(m_hRgn == NULL); + m_hRgn = ::CreateRectRgnIndirect(lpRect); + return m_hRgn; + } + + HRGN CreateEllipticRgn(int x1, int y1, int x2, int y2) + { + ATLASSERT(m_hRgn == NULL); + m_hRgn = ::CreateEllipticRgn(x1, y1, 
x2, y2); + return m_hRgn; + } + + HRGN CreateEllipticRgnIndirect(LPCRECT lpRect) + { + ATLASSERT(m_hRgn == NULL); + m_hRgn = ::CreateEllipticRgnIndirect(lpRect); + return m_hRgn; + } + + HRGN CreatePolygonRgn(const POINT* lpPoints, int nCount, int nMode) + { + ATLASSERT(m_hRgn == NULL); + m_hRgn = ::CreatePolygonRgn(lpPoints, nCount, nMode); + return m_hRgn; + } + + HRGN CreatePolyPolygonRgn(const POINT* lpPoints, const INT* lpPolyCounts, int nCount, int nPolyFillMode) + { + ATLASSERT(m_hRgn == NULL); + m_hRgn = ::CreatePolyPolygonRgn(lpPoints, lpPolyCounts, nCount, nPolyFillMode); + return m_hRgn; + } + + HRGN CreateRoundRectRgn(int x1, int y1, int x2, int y2, int x3, int y3) + { + ATLASSERT(m_hRgn == NULL); + m_hRgn = ::CreateRoundRectRgn(x1, y1, x2, y2, x3, y3); + return m_hRgn; + } + + HRGN CreateFromPath(HDC hDC) + { + ATLASSERT(m_hRgn == NULL); + ATLASSERT(hDC != NULL); + m_hRgn = ::PathToRegion(hDC); + return m_hRgn; + } + + HRGN CreateFromData(const XFORM* lpXForm, int nCount, const RGNDATA* pRgnData) + { + ATLASSERT(m_hRgn == NULL); + m_hRgn = ::ExtCreateRegion(lpXForm, nCount, pRgnData); + return m_hRgn; + } + + BOOL DeleteObject() + { + ATLASSERT(m_hRgn != NULL); + BOOL bRet = ::DeleteObject(m_hRgn); + if(bRet) + m_hRgn = NULL; + return bRet; + } + +// Operations + void SetRectRgn(int x1, int y1, int x2, int y2) + { + ATLASSERT(m_hRgn != NULL); + ::SetRectRgn(m_hRgn, x1, y1, x2, y2); + } + + void SetRectRgn(LPCRECT lpRect) + { + ATLASSERT(m_hRgn != NULL); + ::SetRectRgn(m_hRgn, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom); + } + + int CombineRgn(HRGN hRgnSrc1, HRGN hRgnSrc2, int nCombineMode) + { + ATLASSERT(m_hRgn != NULL); + return ::CombineRgn(m_hRgn, hRgnSrc1, hRgnSrc2, nCombineMode); + } + + int CombineRgn(HRGN hRgnSrc, int nCombineMode) + { + ATLASSERT(m_hRgn != NULL); + return ::CombineRgn(m_hRgn, m_hRgn, hRgnSrc, nCombineMode); + } + + int CopyRgn(HRGN hRgnSrc) + { + ATLASSERT(m_hRgn != NULL); + return ::CombineRgn(m_hRgn, hRgnSrc, 
NULL, RGN_COPY); + } + + BOOL EqualRgn(HRGN hRgn) const + { + ATLASSERT(m_hRgn != NULL); + return ::EqualRgn(m_hRgn, hRgn); + } + + int OffsetRgn(int x, int y) + { + ATLASSERT(m_hRgn != NULL); + return ::OffsetRgn(m_hRgn, x, y); + } + + int OffsetRgn(POINT point) + { + ATLASSERT(m_hRgn != NULL); + return ::OffsetRgn(m_hRgn, point.x, point.y); + } + + int GetRgnBox(LPRECT lpRect) const + { + ATLASSERT(m_hRgn != NULL); + return ::GetRgnBox(m_hRgn, lpRect); + } + + BOOL PtInRegion(int x, int y) const + { + ATLASSERT(m_hRgn != NULL); + return ::PtInRegion(m_hRgn, x, y); + } + + BOOL PtInRegion(POINT point) const + { + ATLASSERT(m_hRgn != NULL); + return ::PtInRegion(m_hRgn, point.x, point.y); + } + + BOOL RectInRegion(LPCRECT lpRect) const + { + ATLASSERT(m_hRgn != NULL); + return ::RectInRegion(m_hRgn, lpRect); + } + + int GetRegionData(LPRGNDATA lpRgnData, int nDataSize) const + { + ATLASSERT(m_hRgn != NULL); + return (int)::GetRegionData(m_hRgn, nDataSize, lpRgnData); + } +}; + +typedef CRgnT<false> CRgnHandle; +typedef CRgnT<true> CRgn; + + +/////////////////////////////////////////////////////////////////////////////// +// CDC - The device context class + +template <bool t_bManaged> +class CDCT; +typedef CDCT<false> CDCHandle; +typedef CDCT<true> CDC; + +template <bool t_bManaged> +class CDCT +{ +public: +// Data members + HDC m_hDC; + +// Constructor/destructor/operators + CDCT(HDC hDC = NULL) : m_hDC(hDC) + { + } + + ~CDCT() + { + if(t_bManaged && (m_hDC != NULL)) + ::DeleteDC(Detach()); + } + + CDCT<t_bManaged>& operator =(HDC hDC) + { + Attach(hDC); + return *this; + } + + void Attach(HDC hDC) + { + if(t_bManaged && (m_hDC != NULL) && (m_hDC != hDC)) + ::DeleteDC(m_hDC); + m_hDC = hDC; + } + + HDC Detach() + { + HDC hDC = m_hDC; + m_hDC = NULL; + return hDC; + } + + operator HDC() const { return m_hDC; } + + bool IsNull() const { return (m_hDC == NULL); } + +// Operations + HWND WindowFromDC() const + { + ATLASSERT(m_hDC != NULL); + return 
::WindowFromDC(m_hDC); + } + + CPenHandle GetCurrentPen() const + { + ATLASSERT(m_hDC != NULL); + return CPenHandle((HPEN)::GetCurrentObject(m_hDC, OBJ_PEN)); + } + + CBrushHandle GetCurrentBrush() const + { + ATLASSERT(m_hDC != NULL); + return CBrushHandle((HBRUSH)::GetCurrentObject(m_hDC, OBJ_BRUSH)); + } + + CPaletteHandle GetCurrentPalette() const + { + ATLASSERT(m_hDC != NULL); + return CPaletteHandle((HPALETTE)::GetCurrentObject(m_hDC, OBJ_PAL)); + } + + CFontHandle GetCurrentFont() const + { + ATLASSERT(m_hDC != NULL); + return CFontHandle((HFONT)::GetCurrentObject(m_hDC, OBJ_FONT)); + } + + CBitmapHandle GetCurrentBitmap() const + { + ATLASSERT(m_hDC != NULL); + return CBitmapHandle((HBITMAP)::GetCurrentObject(m_hDC, OBJ_BITMAP)); + } + + HDC CreateDC(LPCTSTR lpszDriverName, LPCTSTR lpszDeviceName, LPCTSTR lpszOutput, const DEVMODE* lpInitData) + { + ATLASSERT(m_hDC == NULL); + m_hDC = ::CreateDC(lpszDriverName, lpszDeviceName, lpszOutput, lpInitData); + return m_hDC; + } + + HDC CreateCompatibleDC(HDC hDC = NULL) + { + ATLASSERT(m_hDC == NULL); + m_hDC = ::CreateCompatibleDC(hDC); + return m_hDC; + } + + BOOL DeleteDC() + { + if(m_hDC == NULL) + return FALSE; + BOOL bRet = ::DeleteDC(m_hDC); + if(bRet) + m_hDC = NULL; + return bRet; + } + +// Device-Context Functions + int SaveDC() + { + ATLASSERT(m_hDC != NULL); + return ::SaveDC(m_hDC); + } + + BOOL RestoreDC(int nSavedDC) + { + ATLASSERT(m_hDC != NULL); + return ::RestoreDC(m_hDC, nSavedDC); + } + + int GetDeviceCaps(int nIndex) const + { + ATLASSERT(m_hDC != NULL); + return ::GetDeviceCaps(m_hDC, nIndex); + } + + UINT SetBoundsRect(LPCRECT lpRectBounds, UINT flags) + { + ATLASSERT(m_hDC != NULL); + return ::SetBoundsRect(m_hDC, lpRectBounds, flags); + } + + UINT GetBoundsRect(LPRECT lpRectBounds, UINT flags) const + { + ATLASSERT(m_hDC != NULL); + return ::GetBoundsRect(m_hDC, lpRectBounds, flags); + } + + BOOL ResetDC(const DEVMODE* lpDevMode) + { + ATLASSERT(m_hDC != NULL); + return ::ResetDC(m_hDC, 
lpDevMode) != NULL; + } + +// Drawing-Tool Functions + BOOL GetBrushOrg(LPPOINT lpPoint) const + { + ATLASSERT(m_hDC != NULL); + return ::GetBrushOrgEx(m_hDC, lpPoint); + } + + BOOL SetBrushOrg(int x, int y, LPPOINT lpPoint = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::SetBrushOrgEx(m_hDC, x, y, lpPoint); + } + + BOOL SetBrushOrg(POINT point, LPPOINT lpPointRet = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::SetBrushOrgEx(m_hDC, point.x, point.y, lpPointRet); + } + + int EnumObjects(int nObjectType, int (CALLBACK* lpfn)(LPVOID, LPARAM), LPARAM lpData) + { + ATLASSERT(m_hDC != NULL); +#ifdef STRICT + return ::EnumObjects(m_hDC, nObjectType, (GOBJENUMPROC)lpfn, lpData); +#else + return ::EnumObjects(m_hDC, nObjectType, (GOBJENUMPROC)lpfn, (LPVOID)lpData); +#endif + } + +// Type-safe selection helpers + HPEN SelectPen(HPEN hPen) + { + ATLASSERT(m_hDC != NULL); + ATLASSERT((hPen == NULL) || (::GetObjectType(hPen) == OBJ_PEN) || (::GetObjectType(hPen) == OBJ_EXTPEN)); + return (HPEN)::SelectObject(m_hDC, hPen); + } + + HBRUSH SelectBrush(HBRUSH hBrush) + { + ATLASSERT(m_hDC != NULL); + ATLASSERT((hBrush == NULL) || (::GetObjectType(hBrush) == OBJ_BRUSH)); + return (HBRUSH)::SelectObject(m_hDC, hBrush); + } + + HFONT SelectFont(HFONT hFont) + { + ATLASSERT(m_hDC != NULL); + ATLASSERT((hFont == NULL) || (::GetObjectType(hFont) == OBJ_FONT)); + return (HFONT)::SelectObject(m_hDC, hFont); + } + + HBITMAP SelectBitmap(HBITMAP hBitmap) + { + ATLASSERT(m_hDC != NULL); + ATLASSERT((hBitmap == NULL) || (::GetObjectType(hBitmap) == OBJ_BITMAP)); + return (HBITMAP)::SelectObject(m_hDC, hBitmap); + } + + int SelectRgn(HRGN hRgn) // special return for regions + { + ATLASSERT(m_hDC != NULL); + ATLASSERT((hRgn == NULL) || (::GetObjectType(hRgn) == OBJ_REGION)); + return PtrToInt(::SelectObject(m_hDC, hRgn)); + } + +// Type-safe selection helpers for stock objects + HPEN SelectStockPen(int nPen) + { + ATLASSERT(m_hDC != NULL); + ATLASSERT((nPen == WHITE_PEN) || (nPen == 
BLACK_PEN) || (nPen == NULL_PEN) || (nPen == DC_PEN)); + return SelectPen((HPEN)::GetStockObject(nPen)); + } + + HBRUSH SelectStockBrush(int nBrush) + { + ATLASSERT(((nBrush >= WHITE_BRUSH) && (nBrush <= HOLLOW_BRUSH)) || (nBrush == DC_BRUSH)); + return SelectBrush((HBRUSH)::GetStockObject(nBrush)); + } + + HFONT SelectStockFont(int nFont) + { + ATLASSERT(((nFont >= OEM_FIXED_FONT) && (nFont <= SYSTEM_FIXED_FONT)) || (nFont == DEFAULT_GUI_FONT)); + return SelectFont((HFONT)::GetStockObject(nFont)); + } + + HPALETTE SelectStockPalette(int nPalette, BOOL bForceBackground) + { + ATLASSERT(nPalette == DEFAULT_PALETTE); // the only one supported + return SelectPalette((HPALETTE)::GetStockObject(nPalette), bForceBackground); + } + +// Color and Color Palette Functions + COLORREF GetNearestColor(COLORREF crColor) const + { + ATLASSERT(m_hDC != NULL); + return ::GetNearestColor(m_hDC, crColor); + } + + HPALETTE SelectPalette(HPALETTE hPalette, BOOL bForceBackground) + { + ATLASSERT(m_hDC != NULL); + + return ::SelectPalette(m_hDC, hPalette, bForceBackground); + } + + UINT RealizePalette() + { + ATLASSERT(m_hDC != NULL); + return ::RealizePalette(m_hDC); + } + + void UpdateColors() + { + ATLASSERT(m_hDC != NULL); + ::UpdateColors(m_hDC); + } + +// Drawing-Attribute Functions + COLORREF GetBkColor() const + { + ATLASSERT(m_hDC != NULL); + return ::GetBkColor(m_hDC); + } + + int GetBkMode() const + { + ATLASSERT(m_hDC != NULL); + return ::GetBkMode(m_hDC); + } + + int GetPolyFillMode() const + { + ATLASSERT(m_hDC != NULL); + return ::GetPolyFillMode(m_hDC); + } + + int GetROP2() const + { + ATLASSERT(m_hDC != NULL); + return ::GetROP2(m_hDC); + } + + int GetStretchBltMode() const + { + ATLASSERT(m_hDC != NULL); + return ::GetStretchBltMode(m_hDC); + } + + COLORREF GetTextColor() const + { + ATLASSERT(m_hDC != NULL); + return ::GetTextColor(m_hDC); + } + + COLORREF SetBkColor(COLORREF crColor) + { + ATLASSERT(m_hDC != NULL); + return ::SetBkColor(m_hDC, crColor); + } + + int 
SetBkMode(int nBkMode) + { + ATLASSERT(m_hDC != NULL); + return ::SetBkMode(m_hDC, nBkMode); + } + + int SetPolyFillMode(int nPolyFillMode) + { + ATLASSERT(m_hDC != NULL); + return ::SetPolyFillMode(m_hDC, nPolyFillMode); + } + + int SetROP2(int nDrawMode) + { + ATLASSERT(m_hDC != NULL); + return ::SetROP2(m_hDC, nDrawMode); + } + + int SetStretchBltMode(int nStretchMode) + { + ATLASSERT(m_hDC != NULL); + return ::SetStretchBltMode(m_hDC, nStretchMode); + } + + COLORREF SetTextColor(COLORREF crColor) + { + ATLASSERT(m_hDC != NULL); + return ::SetTextColor(m_hDC, crColor); + } + + BOOL GetColorAdjustment(LPCOLORADJUSTMENT lpColorAdjust) const + { + ATLASSERT(m_hDC != NULL); + return ::GetColorAdjustment(m_hDC, lpColorAdjust); + } + + BOOL SetColorAdjustment(const COLORADJUSTMENT* lpColorAdjust) + { + ATLASSERT(m_hDC != NULL); + return ::SetColorAdjustment(m_hDC, lpColorAdjust); + } + +// Mapping Functions + int GetMapMode() const + { + ATLASSERT(m_hDC != NULL); + return ::GetMapMode(m_hDC); + } + + BOOL GetViewportOrg(LPPOINT lpPoint) const + { + ATLASSERT(m_hDC != NULL); + return ::GetViewportOrgEx(m_hDC, lpPoint); + } + + int SetMapMode(int nMapMode) + { + ATLASSERT(m_hDC != NULL); + return ::SetMapMode(m_hDC, nMapMode); + } + + // Viewport Origin + BOOL SetViewportOrg(int x, int y, LPPOINT lpPoint = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::SetViewportOrgEx(m_hDC, x, y, lpPoint); + } + + BOOL SetViewportOrg(POINT point, LPPOINT lpPointRet = NULL) + { + ATLASSERT(m_hDC != NULL); + return SetViewportOrg(point.x, point.y, lpPointRet); + } + + BOOL OffsetViewportOrg(int nWidth, int nHeight, LPPOINT lpPoint = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::OffsetViewportOrgEx(m_hDC, nWidth, nHeight, lpPoint); + } + + // Viewport Extent + BOOL GetViewportExt(LPSIZE lpSize) const + { + ATLASSERT(m_hDC != NULL); + return ::GetViewportExtEx(m_hDC, lpSize); + } + + BOOL SetViewportExt(int x, int y, LPSIZE lpSize = NULL) + { + ATLASSERT(m_hDC != NULL); + return 
::SetViewportExtEx(m_hDC, x, y, lpSize); + } + + BOOL SetViewportExt(SIZE size, LPSIZE lpSizeRet = NULL) + { + ATLASSERT(m_hDC != NULL); + return SetViewportExt(size.cx, size.cy, lpSizeRet); + } + + BOOL ScaleViewportExt(int xNum, int xDenom, int yNum, int yDenom, LPSIZE lpSize = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::ScaleViewportExtEx(m_hDC, xNum, xDenom, yNum, yDenom, lpSize); + } + + // Window Origin + BOOL GetWindowOrg(LPPOINT lpPoint) const + { + ATLASSERT(m_hDC != NULL); + return ::GetWindowOrgEx(m_hDC, lpPoint); + } + + BOOL SetWindowOrg(int x, int y, LPPOINT lpPoint = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::SetWindowOrgEx(m_hDC, x, y, lpPoint); + } + + BOOL SetWindowOrg(POINT point, LPPOINT lpPointRet = NULL) + { + ATLASSERT(m_hDC != NULL); + return SetWindowOrg(point.x, point.y, lpPointRet); + } + + BOOL OffsetWindowOrg(int nWidth, int nHeight, LPPOINT lpPoint = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::OffsetWindowOrgEx(m_hDC, nWidth, nHeight, lpPoint); + } + + // Window extent + BOOL GetWindowExt(LPSIZE lpSize) const + { + ATLASSERT(m_hDC != NULL); + return ::GetWindowExtEx(m_hDC, lpSize); + } + + BOOL SetWindowExt(int x, int y, LPSIZE lpSize = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::SetWindowExtEx(m_hDC, x, y, lpSize); + } + + BOOL SetWindowExt(SIZE size, LPSIZE lpSizeRet = NULL) + { + ATLASSERT(m_hDC != NULL); + return SetWindowExt(size.cx, size.cy, lpSizeRet); + } + + BOOL ScaleWindowExt(int xNum, int xDenom, int yNum, int yDenom, LPSIZE lpSize = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::ScaleWindowExtEx(m_hDC, xNum, xDenom, yNum, yDenom, lpSize); + } + +// Coordinate Functions + BOOL DPtoLP(LPPOINT lpPoints, int nCount = 1) const + { + ATLASSERT(m_hDC != NULL); + return ::DPtoLP(m_hDC, lpPoints, nCount); + } + + BOOL DPtoLP(LPRECT lpRect) const + { + ATLASSERT(m_hDC != NULL); + return ::DPtoLP(m_hDC, (LPPOINT)lpRect, 2); + } + + BOOL DPtoLP(LPSIZE lpSize) const + { + SIZE sizeWinExt = {}; + 
if(!GetWindowExt(&sizeWinExt)) + return FALSE; + SIZE sizeVpExt = {}; + if(!GetViewportExt(&sizeVpExt)) + return FALSE; + lpSize->cx = ::MulDiv(lpSize->cx, abs(sizeWinExt.cx), abs(sizeVpExt.cx)); + lpSize->cy = ::MulDiv(lpSize->cy, abs(sizeWinExt.cy), abs(sizeVpExt.cy)); + return TRUE; + } + + BOOL LPtoDP(LPPOINT lpPoints, int nCount = 1) const + { + ATLASSERT(m_hDC != NULL); + return ::LPtoDP(m_hDC, lpPoints, nCount); + } + + BOOL LPtoDP(LPRECT lpRect) const + { + ATLASSERT(m_hDC != NULL); + return ::LPtoDP(m_hDC, (LPPOINT)lpRect, 2); + } + + BOOL LPtoDP(LPSIZE lpSize) const + { + SIZE sizeWinExt = {}; + if(!GetWindowExt(&sizeWinExt)) + return FALSE; + SIZE sizeVpExt = {}; + if(!GetViewportExt(&sizeVpExt)) + return FALSE; + lpSize->cx = ::MulDiv(lpSize->cx, abs(sizeVpExt.cx), abs(sizeWinExt.cx)); + lpSize->cy = ::MulDiv(lpSize->cy, abs(sizeVpExt.cy), abs(sizeWinExt.cy)); + return TRUE; + } + +// Special Coordinate Functions (useful for dealing with metafiles and OLE) + #define HIMETRIC_INCH 2540 // HIMETRIC units per inch + + void DPtoHIMETRIC(LPSIZE lpSize) + { + ATLASSERT(m_hDC != NULL); + int nMapMode = GetMapMode(); + if((nMapMode < MM_ISOTROPIC) && (nMapMode != MM_TEXT)) + { + // when using a constrained map mode, map against physical inch + SetMapMode(MM_HIMETRIC); + DPtoLP(lpSize); + SetMapMode(nMapMode); + } + else + { + // map against logical inch for non-constrained mapping modes + int cxPerInch = GetDeviceCaps(LOGPIXELSX); + int cyPerInch = GetDeviceCaps(LOGPIXELSY); + ATLASSERT((cxPerInch != 0) && (cyPerInch != 0)); + lpSize->cx = ::MulDiv(lpSize->cx, HIMETRIC_INCH, cxPerInch); + lpSize->cy = ::MulDiv(lpSize->cy, HIMETRIC_INCH, cyPerInch); + } + } + + void HIMETRICtoDP(LPSIZE lpSize) + { + ATLASSERT(m_hDC != NULL); + int nMapMode = GetMapMode(); + if((nMapMode < MM_ISOTROPIC) && (nMapMode != MM_TEXT)) + { + // when using a constrained map mode, map against physical inch + SetMapMode(MM_HIMETRIC); + LPtoDP(lpSize); + SetMapMode(nMapMode); + } + else + { 
+ // map against logical inch for non-constrained mapping modes + int cxPerInch = GetDeviceCaps(LOGPIXELSX); + int cyPerInch = GetDeviceCaps(LOGPIXELSY); + ATLASSERT((cxPerInch != 0) && (cyPerInch != 0)); + lpSize->cx = ::MulDiv(lpSize->cx, cxPerInch, HIMETRIC_INCH); + lpSize->cy = ::MulDiv(lpSize->cy, cyPerInch, HIMETRIC_INCH); + } + } + + void LPtoHIMETRIC(LPSIZE lpSize) + { + LPtoDP(lpSize); + DPtoHIMETRIC(lpSize); + } + + void HIMETRICtoLP(LPSIZE lpSize) + { + HIMETRICtoDP(lpSize); + DPtoLP(lpSize); + } + +// Region Functions + BOOL FillRgn(HRGN hRgn, HBRUSH hBrush) + { + ATLASSERT(m_hDC != NULL); + return ::FillRgn(m_hDC, hRgn, hBrush); + } + + BOOL FrameRgn(HRGN hRgn, HBRUSH hBrush, int nWidth, int nHeight) + { + ATLASSERT(m_hDC != NULL); + return ::FrameRgn(m_hDC, hRgn, hBrush, nWidth, nHeight); + } + + BOOL InvertRgn(HRGN hRgn) + { + ATLASSERT(m_hDC != NULL); + return ::InvertRgn(m_hDC, hRgn); + } + + BOOL PaintRgn(HRGN hRgn) + { + ATLASSERT(m_hDC != NULL); + return ::PaintRgn(m_hDC, hRgn); + } + +// Clipping Functions + int GetClipBox(LPRECT lpRect) const + { + ATLASSERT(m_hDC != NULL); + return ::GetClipBox(m_hDC, lpRect); + } + + int GetClipRgn(CRgn& region) const + { + ATLASSERT(m_hDC != NULL); + if(region.IsNull()) + region.CreateRectRgn(0, 0, 0, 0); + + int nRet = ::GetClipRgn(m_hDC, region); + if(nRet != 1) + region.DeleteObject(); + + return nRet; + } + + BOOL PtVisible(int x, int y) const + { + ATLASSERT(m_hDC != NULL); + return ::PtVisible(m_hDC, x, y); + } + + BOOL PtVisible(POINT point) const + { + ATLASSERT(m_hDC != NULL); + return ::PtVisible(m_hDC, point.x, point.y); + } + + BOOL RectVisible(LPCRECT lpRect) const + { + ATLASSERT(m_hDC != NULL); + return ::RectVisible(m_hDC, lpRect); + } + + int SelectClipRgn(HRGN hRgn) + { + ATLASSERT(m_hDC != NULL); + return ::SelectClipRgn(m_hDC, (HRGN)hRgn); + } + + int ExcludeClipRect(int x1, int y1, int x2, int y2) + { + ATLASSERT(m_hDC != NULL); + return ::ExcludeClipRect(m_hDC, x1, y1, x2, y2); + } + + 
int ExcludeClipRect(LPCRECT lpRect) + { + ATLASSERT(m_hDC != NULL); + return ::ExcludeClipRect(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom); + } + + int ExcludeUpdateRgn(HWND hWnd) + { + ATLASSERT(m_hDC != NULL); + return ::ExcludeUpdateRgn(m_hDC, hWnd); + } + + int IntersectClipRect(int x1, int y1, int x2, int y2) + { + ATLASSERT(m_hDC != NULL); + return ::IntersectClipRect(m_hDC, x1, y1, x2, y2); + } + + int IntersectClipRect(LPCRECT lpRect) + { + ATLASSERT(m_hDC != NULL); + return ::IntersectClipRect(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom); + } + + int OffsetClipRgn(int x, int y) + { + ATLASSERT(m_hDC != NULL); + return ::OffsetClipRgn(m_hDC, x, y); + } + + int OffsetClipRgn(SIZE size) + { + ATLASSERT(m_hDC != NULL); + return ::OffsetClipRgn(m_hDC, size.cx, size.cy); + } + + int SelectClipRgn(HRGN hRgn, int nMode) + { + ATLASSERT(m_hDC != NULL); + return ::ExtSelectClipRgn(m_hDC, hRgn, nMode); + } + +// Line-Output Functions + BOOL GetCurrentPosition(LPPOINT lpPoint) const + { + ATLASSERT(m_hDC != NULL); + return ::GetCurrentPositionEx(m_hDC, lpPoint); + } + + BOOL MoveTo(int x, int y, LPPOINT lpPoint = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::MoveToEx(m_hDC, x, y, lpPoint); + } + + BOOL MoveTo(POINT point, LPPOINT lpPointRet = NULL) + { + ATLASSERT(m_hDC != NULL); + return MoveTo(point.x, point.y, lpPointRet); + } + + BOOL LineTo(int x, int y) + { + ATLASSERT(m_hDC != NULL); + return ::LineTo(m_hDC, x, y); + } + + BOOL LineTo(POINT point) + { + ATLASSERT(m_hDC != NULL); + return LineTo(point.x, point.y); + } + + BOOL Arc(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4) + { + ATLASSERT(m_hDC != NULL); + return ::Arc(m_hDC, x1, y1, x2, y2, x3, y3, x4, y4); + } + + BOOL Arc(LPCRECT lpRect, POINT ptStart, POINT ptEnd) + { + ATLASSERT(m_hDC != NULL); + return ::Arc(m_hDC, lpRect->left, lpRect->top, + lpRect->right, lpRect->bottom, ptStart.x, ptStart.y, + ptEnd.x, ptEnd.y); + } + + BOOL 
Polyline(const POINT* lpPoints, int nCount) + { + ATLASSERT(m_hDC != NULL); + return ::Polyline(m_hDC, lpPoints, nCount); + } + + BOOL AngleArc(int x, int y, int nRadius, float fStartAngle, float fSweepAngle) + { + ATLASSERT(m_hDC != NULL); + return ::AngleArc(m_hDC, x, y, nRadius, fStartAngle, fSweepAngle); + } + + BOOL ArcTo(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4) + { + ATLASSERT(m_hDC != NULL); + return ::ArcTo(m_hDC, x1, y1, x2, y2, x3, y3, x4, y4); + } + + BOOL ArcTo(LPCRECT lpRect, POINT ptStart, POINT ptEnd) + { + ATLASSERT(m_hDC != NULL); + return ArcTo(lpRect->left, lpRect->top, lpRect->right, + lpRect->bottom, ptStart.x, ptStart.y, ptEnd.x, ptEnd.y); + } + + int GetArcDirection() const + { + ATLASSERT(m_hDC != NULL); + return ::GetArcDirection(m_hDC); + } + + int SetArcDirection(int nArcDirection) + { + ATLASSERT(m_hDC != NULL); + return ::SetArcDirection(m_hDC, nArcDirection); + } + + BOOL PolyDraw(const POINT* lpPoints, const BYTE* lpTypes, int nCount) + { + ATLASSERT(m_hDC != NULL); + return ::PolyDraw(m_hDC, lpPoints, lpTypes, nCount); + } + + BOOL PolylineTo(const POINT* lpPoints, int nCount) + { + ATLASSERT(m_hDC != NULL); + return ::PolylineTo(m_hDC, lpPoints, nCount); + } + + BOOL PolyPolyline(const POINT* lpPoints, + const DWORD* lpPolyPoints, int nCount) + { + ATLASSERT(m_hDC != NULL); + return ::PolyPolyline(m_hDC, lpPoints, lpPolyPoints, nCount); + } + + BOOL PolyBezier(const POINT* lpPoints, int nCount) + { + ATLASSERT(m_hDC != NULL); + return ::PolyBezier(m_hDC, lpPoints, nCount); + } + + BOOL PolyBezierTo(const POINT* lpPoints, int nCount) + { + ATLASSERT(m_hDC != NULL); + return ::PolyBezierTo(m_hDC, lpPoints, nCount); + } + +// Simple Drawing Functions + BOOL FillRect(LPCRECT lpRect, HBRUSH hBrush) + { + ATLASSERT(m_hDC != NULL); + return ::FillRect(m_hDC, lpRect, hBrush); + } + + BOOL FillRect(LPCRECT lpRect, int nColorIndex) + { + ATLASSERT(m_hDC != NULL); + return ::FillRect(m_hDC, lpRect, 
(HBRUSH)LongToPtr(nColorIndex + 1)); + } + + BOOL FrameRect(LPCRECT lpRect, HBRUSH hBrush) + { + ATLASSERT(m_hDC != NULL); + return ::FrameRect(m_hDC, lpRect, hBrush); + } + + BOOL InvertRect(LPCRECT lpRect) + { + ATLASSERT(m_hDC != NULL); + return ::InvertRect(m_hDC, lpRect); + } + + BOOL DrawIcon(int x, int y, HICON hIcon) + { + ATLASSERT(m_hDC != NULL); + return ::DrawIcon(m_hDC, x, y, hIcon); + } + + BOOL DrawIcon(POINT point, HICON hIcon) + { + ATLASSERT(m_hDC != NULL); + return ::DrawIcon(m_hDC, point.x, point.y, hIcon); + } + + BOOL DrawIconEx(int x, int y, HICON hIcon, int cxWidth, int cyWidth, UINT uStepIfAniCur = 0, HBRUSH hbrFlickerFreeDraw = NULL, UINT uFlags = DI_NORMAL) + { + ATLASSERT(m_hDC != NULL); + return ::DrawIconEx(m_hDC, x, y, hIcon, cxWidth, cyWidth, uStepIfAniCur, hbrFlickerFreeDraw, uFlags); + } + + BOOL DrawIconEx(POINT point, HICON hIcon, SIZE size, UINT uStepIfAniCur = 0, HBRUSH hbrFlickerFreeDraw = NULL, UINT uFlags = DI_NORMAL) + { + ATLASSERT(m_hDC != NULL); + return ::DrawIconEx(m_hDC, point.x, point.y, hIcon, size.cx, size.cy, uStepIfAniCur, hbrFlickerFreeDraw, uFlags); + } + + BOOL DrawState(POINT pt, SIZE size, HBITMAP hBitmap, UINT nFlags, HBRUSH hBrush = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::DrawState(m_hDC, hBrush, NULL, (LPARAM)hBitmap, 0, pt.x, pt.y, size.cx, size.cy, nFlags | DST_BITMAP); + } + + BOOL DrawState(POINT pt, SIZE size, HICON hIcon, UINT nFlags, HBRUSH hBrush = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::DrawState(m_hDC, hBrush, NULL, (LPARAM)hIcon, 0, pt.x, pt.y, size.cx, size.cy, nFlags | DST_ICON); + } + + BOOL DrawState(POINT pt, SIZE size, LPCTSTR lpszText, UINT nFlags, BOOL bPrefixText = TRUE, int nTextLen = 0, HBRUSH hBrush = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::DrawState(m_hDC, hBrush, NULL, (LPARAM)lpszText, (WPARAM)nTextLen, pt.x, pt.y, size.cx, size.cy, nFlags | (bPrefixText ? 
DST_PREFIXTEXT : DST_TEXT)); + } + + BOOL DrawState(POINT pt, SIZE size, DRAWSTATEPROC lpDrawProc, LPARAM lData, UINT nFlags, HBRUSH hBrush = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::DrawState(m_hDC, hBrush, lpDrawProc, lData, 0, pt.x, pt.y, size.cx, size.cy, nFlags | DST_COMPLEX); + } + +// Ellipse and Polygon Functions + BOOL Chord(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4) + { + ATLASSERT(m_hDC != NULL); + return ::Chord(m_hDC, x1, y1, x2, y2, x3, y3, x4, y4); + } + + BOOL Chord(LPCRECT lpRect, POINT ptStart, POINT ptEnd) + { + ATLASSERT(m_hDC != NULL); + return ::Chord(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom, ptStart.x, ptStart.y, ptEnd.x, ptEnd.y); + } + + void DrawFocusRect(LPCRECT lpRect) + { + ATLASSERT(m_hDC != NULL); + ::DrawFocusRect(m_hDC, lpRect); + } + + BOOL Ellipse(int x1, int y1, int x2, int y2) + { + ATLASSERT(m_hDC != NULL); + return ::Ellipse(m_hDC, x1, y1, x2, y2); + } + + BOOL Ellipse(LPCRECT lpRect) + { + ATLASSERT(m_hDC != NULL); + return ::Ellipse(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom); + } + + BOOL Pie(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4) + { + ATLASSERT(m_hDC != NULL); + return ::Pie(m_hDC, x1, y1, x2, y2, x3, y3, x4, y4); + } + + BOOL Pie(LPCRECT lpRect, POINT ptStart, POINT ptEnd) + { + ATLASSERT(m_hDC != NULL); + return ::Pie(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom, ptStart.x, ptStart.y, ptEnd.x, ptEnd.y); + } + + BOOL Polygon(const POINT* lpPoints, int nCount) + { + ATLASSERT(m_hDC != NULL); + return ::Polygon(m_hDC, lpPoints, nCount); + } + + BOOL PolyPolygon(const POINT* lpPoints, const INT* lpPolyCounts, int nCount) + { + ATLASSERT(m_hDC != NULL); + return ::PolyPolygon(m_hDC, lpPoints, lpPolyCounts, nCount); + } + + BOOL Rectangle(int x1, int y1, int x2, int y2) + { + ATLASSERT(m_hDC != NULL); + return ::Rectangle(m_hDC, x1, y1, x2, y2); + } + + BOOL Rectangle(LPCRECT lpRect) + { + ATLASSERT(m_hDC != 
NULL); + return ::Rectangle(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom); + } + + BOOL RoundRect(int x1, int y1, int x2, int y2, int x3, int y3) + { + ATLASSERT(m_hDC != NULL); + return ::RoundRect(m_hDC, x1, y1, x2, y2, x3, y3); + } + + BOOL RoundRect(LPCRECT lpRect, POINT point) + { + ATLASSERT(m_hDC != NULL); + return ::RoundRect(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom, point.x, point.y); + } + +// Bitmap Functions + BOOL PatBlt(int x, int y, int nWidth, int nHeight, DWORD dwRop) + { + ATLASSERT(m_hDC != NULL); + return ::PatBlt(m_hDC, x, y, nWidth, nHeight, dwRop); + } + + BOOL BitBlt(int x, int y, int nWidth, int nHeight, HDC hSrcDC, + int xSrc, int ySrc, DWORD dwRop) + { + ATLASSERT(m_hDC != NULL); + return ::BitBlt(m_hDC, x, y, nWidth, nHeight, hSrcDC, xSrc, ySrc, dwRop); + } + + BOOL StretchBlt(int x, int y, int nWidth, int nHeight, HDC hSrcDC, int xSrc, int ySrc, int nSrcWidth, int nSrcHeight, DWORD dwRop) + { + ATLASSERT(m_hDC != NULL); + return ::StretchBlt(m_hDC, x, y, nWidth, nHeight, hSrcDC, xSrc, ySrc, nSrcWidth, nSrcHeight, dwRop); + } + + COLORREF GetPixel(int x, int y) const + { + ATLASSERT(m_hDC != NULL); + return ::GetPixel(m_hDC, x, y); + } + + COLORREF GetPixel(POINT point) const + { + ATLASSERT(m_hDC != NULL); + return ::GetPixel(m_hDC, point.x, point.y); + } + + COLORREF SetPixel(int x, int y, COLORREF crColor) + { + ATLASSERT(m_hDC != NULL); + return ::SetPixel(m_hDC, x, y, crColor); + } + + COLORREF SetPixel(POINT point, COLORREF crColor) + { + ATLASSERT(m_hDC != NULL); + return ::SetPixel(m_hDC, point.x, point.y, crColor); + } + + BOOL FloodFill(int x, int y, COLORREF crColor) + { + ATLASSERT(m_hDC != NULL); + return ::FloodFill(m_hDC, x, y, crColor); + } + + BOOL ExtFloodFill(int x, int y, COLORREF crColor, UINT nFillType) + { + ATLASSERT(m_hDC != NULL); + return ::ExtFloodFill(m_hDC, x, y, crColor, nFillType); + } + + BOOL MaskBlt(int x, int y, int nWidth, int nHeight, HDC hSrcDC, int xSrc, 
int ySrc, HBITMAP hMaskBitmap, int xMask, int yMask, DWORD dwRop) + { + ATLASSERT(m_hDC != NULL); + return ::MaskBlt(m_hDC, x, y, nWidth, nHeight, hSrcDC, xSrc, ySrc, hMaskBitmap, xMask, yMask, dwRop); + } + + BOOL PlgBlt(LPPOINT lpPoint, HDC hSrcDC, int xSrc, int ySrc, int nWidth, int nHeight, HBITMAP hMaskBitmap, int xMask, int yMask) + { + ATLASSERT(m_hDC != NULL); + return ::PlgBlt(m_hDC, lpPoint, hSrcDC, xSrc, ySrc, nWidth, nHeight, hMaskBitmap, xMask, yMask); + } + + BOOL SetPixelV(int x, int y, COLORREF crColor) + { + ATLASSERT(m_hDC != NULL); + return ::SetPixelV(m_hDC, x, y, crColor); + } + + BOOL SetPixelV(POINT point, COLORREF crColor) + { + ATLASSERT(m_hDC != NULL); + return ::SetPixelV(m_hDC, point.x, point.y, crColor); + } + + BOOL TransparentBlt(int x, int y, int nWidth, int nHeight, HDC hSrcDC, int xSrc, int ySrc, int nSrcWidth, int nSrcHeight, UINT crTransparent) + { + ATLASSERT(m_hDC != NULL); + return ::TransparentBlt(m_hDC, x, y, nWidth, nHeight, hSrcDC, xSrc, ySrc, nSrcWidth, nSrcHeight, crTransparent); + } + + BOOL GradientFill(const PTRIVERTEX pVertices, DWORD nVertices, void* pMeshElements, DWORD nMeshElements, DWORD dwMode) + { + ATLASSERT(m_hDC != NULL); + return ::GradientFill(m_hDC, pVertices, nVertices, pMeshElements, nMeshElements, dwMode); + } + + BOOL GradientFillRect(RECT& rect, COLORREF clr1, COLORREF clr2, bool bHorizontal) + { + ATLASSERT(m_hDC != NULL); + + TRIVERTEX arrTvx[2] = { { 0 }, { 0 } }; + + arrTvx[0].x = rect.left; + arrTvx[0].y = rect.top; + arrTvx[0].Red = MAKEWORD(0, GetRValue(clr1)); + arrTvx[0].Green = MAKEWORD(0, GetGValue(clr1)); + arrTvx[0].Blue = MAKEWORD(0, GetBValue(clr1)); + arrTvx[0].Alpha = 0; + + arrTvx[1].x = rect.right; + arrTvx[1].y = rect.bottom; + arrTvx[1].Red = MAKEWORD(0, GetRValue(clr2)); + arrTvx[1].Green = MAKEWORD(0, GetGValue(clr2)); + arrTvx[1].Blue = MAKEWORD(0, GetBValue(clr2)); + arrTvx[1].Alpha = 0; + + GRADIENT_RECT gr = { 0, 1 }; + + return ::GradientFill(m_hDC, arrTvx, 2, &gr, 1, 
bHorizontal ? GRADIENT_FILL_RECT_H : GRADIENT_FILL_RECT_V); + } + + BOOL AlphaBlend(int x, int y, int nWidth, int nHeight, HDC hSrcDC, int xSrc, int ySrc, int nSrcWidth, int nSrcHeight, BLENDFUNCTION bf) + { + ATLASSERT(m_hDC != NULL); + return ::AlphaBlend(m_hDC, x, y, nWidth, nHeight, hSrcDC, xSrc, ySrc, nSrcWidth, nSrcHeight, bf); + } + +// Extra bitmap functions + // Helper function for painting a disabled toolbar or menu bitmap + // This function can take either an HBITMAP (for SS) or a DC with + // the bitmap already painted (for cmdbar) + BOOL DitherBlt(int x, int y, int nWidth, int nHeight, HDC hSrcDC, HBITMAP hBitmap, int xSrc, int ySrc, + HBRUSH hBrushBackground = ::GetSysColorBrush(COLOR_3DFACE), + HBRUSH hBrush3DEffect = ::GetSysColorBrush(COLOR_3DHILIGHT), + HBRUSH hBrushDisabledImage = ::GetSysColorBrush(COLOR_3DSHADOW)) + { + ATLASSERT((m_hDC != NULL) || (hBitmap != NULL)); + ATLASSERT((nWidth > 0) && (nHeight > 0)); + + // Create a generic DC for all BitBlts + CDCT<false> dc = (hSrcDC != NULL) ? 
hSrcDC : ::CreateCompatibleDC(m_hDC); + ATLASSERT(dc.m_hDC != NULL); + if(dc.m_hDC == NULL) + return FALSE; + + // Create a DC for the monochrome DIB section + CDCT<true> dcBW = ::CreateCompatibleDC(m_hDC); + ATLASSERT(dcBW.m_hDC != NULL); + if(dcBW.m_hDC == NULL) + { + if(hSrcDC == NULL) + dc.DeleteDC(); + return FALSE; + } + + // Create the monochrome DIB section with a black and white palette + struct RGBBWBITMAPINFO + { + BITMAPINFOHEADER bmiHeader; + RGBQUAD bmiColors[2]; + }; + + RGBBWBITMAPINFO rgbBWBitmapInfo = + { + { sizeof(BITMAPINFOHEADER), nWidth, nHeight, 1, 1, BI_RGB, 0, 0, 0, 0, 0 }, + { { 0x00, 0x00, 0x00, 0x00 }, { 0xFF, 0xFF, 0xFF, 0x00 } } + }; + + VOID* pbitsBW; + CBitmap bmpBW = ::CreateDIBSection(dcBW, (LPBITMAPINFO)&rgbBWBitmapInfo, DIB_RGB_COLORS, &pbitsBW, NULL, 0); + ATLASSERT(bmpBW.m_hBitmap != NULL); + if(bmpBW.m_hBitmap == NULL) + { + if(hSrcDC == NULL) + dc.DeleteDC(); + return FALSE; + } + + // Attach the monochrome DIB section and the bitmap to the DCs + HBITMAP hbmOldBW = dcBW.SelectBitmap(bmpBW); + HBITMAP hbmOldDC = NULL; + if(hBitmap != NULL) + hbmOldDC = dc.SelectBitmap(hBitmap); + + // Block: Dark gray removal: we want (128, 128, 128) pixels to become black and not white + { + CDCT<true> dcTemp1 = ::CreateCompatibleDC(m_hDC); + CDCT<true> dcTemp2 = ::CreateCompatibleDC(m_hDC); + CBitmap bmpTemp1; + bmpTemp1.CreateCompatibleBitmap(dc, nWidth, nHeight); + CBitmap bmpTemp2; + bmpTemp2.CreateBitmap(nWidth, nHeight, 1, 1, NULL); + HBITMAP hOldBmp1 = dcTemp1.SelectBitmap(bmpTemp1); + HBITMAP hOldBmp2 = dcTemp2.SelectBitmap(bmpTemp2); + // Let's copy our image, it will be altered + dcTemp1.BitBlt(0, 0, nWidth, nHeight, dc, xSrc, ySrc, SRCCOPY); + + // All dark gray pixels will become white, the others black + dcTemp1.SetBkColor(RGB(128, 128, 128)); + dcTemp2.BitBlt(0, 0, nWidth, nHeight, dcTemp1, 0, 0, SRCCOPY); + // Do an XOR to set to black these white pixels + dcTemp1.BitBlt(0, 0, nWidth, nHeight, dcTemp2, 0, 0, SRCINVERT); + + // 
BitBlt the bitmap into the monochrome DIB section + // The DIB section will do a true monochrome conversion + // The magenta background being closer to white will become white + dcBW.BitBlt(0, 0, nWidth, nHeight, dcTemp1, 0, 0, SRCCOPY); + + // Cleanup + dcTemp1.SelectBitmap(hOldBmp1); + dcTemp2.SelectBitmap(hOldBmp2); + } + + // Paint the destination rectangle using hBrushBackground + if(hBrushBackground != NULL) + { + RECT rc = { x, y, x + nWidth, y + nHeight }; + FillRect(&rc, hBrushBackground); + } + + // BitBlt the black bits in the monochrome bitmap into hBrush3DEffect color in the destination DC + // The magic ROP comes from the Charles Petzold's book + HBRUSH hOldBrush = SelectBrush(hBrush3DEffect); + BitBlt(x + 1, y + 1, nWidth, nHeight, dcBW, 0, 0, 0xB8074A); + + // BitBlt the black bits in the monochrome bitmap into hBrushDisabledImage color in the destination DC + SelectBrush(hBrushDisabledImage); + BitBlt(x, y, nWidth, nHeight, dcBW, 0, 0, 0xB8074A); + + SelectBrush(hOldBrush); + dcBW.SelectBitmap(hbmOldBW); + dc.SelectBitmap(hbmOldDC); + + if(hSrcDC == NULL) + dc.DeleteDC(); + + return TRUE; + } + +// Text Functions + BOOL TextOut(int x, int y, LPCTSTR lpszString, int nCount = -1) + { + ATLASSERT(m_hDC != NULL); + if(nCount == -1) + nCount = lstrlen(lpszString); + return ::TextOut(m_hDC, x, y, lpszString, nCount); + } + + BOOL ExtTextOut(int x, int y, UINT nOptions, LPCRECT lpRect, LPCTSTR lpszString, int nCount = -1, LPINT lpDxWidths = NULL) + { + ATLASSERT(m_hDC != NULL); + if(nCount == -1) + nCount = lstrlen(lpszString); + ATLASSERT((nCount >= 0) && (nCount <= 8192)); + return ::ExtTextOut(m_hDC, x, y, nOptions, lpRect, lpszString, (UINT)nCount, lpDxWidths); + } + + SIZE TabbedTextOut(int x, int y, LPCTSTR lpszString, int nCount = -1, int nTabPositions = 0, LPINT lpnTabStopPositions = NULL, int nTabOrigin = 0) + { + ATLASSERT(m_hDC != NULL); + if(nCount == -1) + nCount = lstrlen(lpszString); + LONG lRes = ::TabbedTextOut(m_hDC, x, y, lpszString, 
nCount, nTabPositions, lpnTabStopPositions, nTabOrigin); + SIZE size = { GET_X_LPARAM(lRes), GET_Y_LPARAM(lRes) }; + return size; + } + + int DrawText(LPCTSTR lpstrText, int cchText, LPRECT lpRect, UINT uFormat) + { + ATLASSERT(m_hDC != NULL); + ATLASSERT((uFormat & DT_MODIFYSTRING) == 0); + return ::DrawText(m_hDC, lpstrText, cchText, lpRect, uFormat); + } + + int DrawText(LPTSTR lpstrText, int cchText, LPRECT lpRect, UINT uFormat) + { + ATLASSERT(m_hDC != NULL); + return ::DrawText(m_hDC, lpstrText, cchText, lpRect, uFormat); + } + + int DrawTextEx(LPTSTR lpstrText, int cchText, LPRECT lpRect, UINT uFormat, LPDRAWTEXTPARAMS lpDTParams = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::DrawTextEx(m_hDC, lpstrText, cchText, lpRect, uFormat, lpDTParams); + } + + // Note - ::DrawShadowText() is present only if comctl32.dll version 6 is loaded + int DrawShadowText(LPCWSTR lpstrText, int cchText, LPRECT lpRect, DWORD dwFlags, COLORREF clrText, COLORREF clrShadow, int xOffset, int yOffset) + { + ATLASSERT(m_hDC != NULL); + ATLASSERT(lpRect != NULL); + return ::DrawShadowText(m_hDC, lpstrText, cchText, lpRect, dwFlags, clrText, clrShadow, xOffset, yOffset); + } + + BOOL GetTextExtent(LPCTSTR lpszString, int nCount, LPSIZE lpSize) const + { + ATLASSERT(m_hDC != NULL); + if(nCount == -1) + nCount = lstrlen(lpszString); + return ::GetTextExtentPoint32(m_hDC, lpszString, nCount, lpSize); + } + + BOOL GetTextExtentExPoint(LPCTSTR lpszString, int cchString, LPSIZE lpSize, int nMaxExtent, LPINT lpnFit = NULL, LPINT alpDx = NULL) + { + ATLASSERT(m_hDC != NULL); + return ::GetTextExtentExPoint(m_hDC, lpszString, cchString, nMaxExtent, lpnFit, alpDx, lpSize); + } + + DWORD GetTabbedTextExtent(LPCTSTR lpszString, int nCount = -1, int nTabPositions = 0, LPINT lpnTabStopPositions = NULL) const + { + ATLASSERT(m_hDC != NULL); + if(nCount == -1) + nCount = lstrlen(lpszString); + return ::GetTabbedTextExtent(m_hDC, lpszString, nCount, nTabPositions, lpnTabStopPositions); + } + + BOOL 
GrayString(HBRUSH hBrush, BOOL (CALLBACK* lpfnOutput)(HDC, LPARAM, int), LPARAM lpData, int nCount, int x, int y, int nWidth, int nHeight) + { + ATLASSERT(m_hDC != NULL); + return ::GrayString(m_hDC, hBrush, (GRAYSTRINGPROC)lpfnOutput, lpData, nCount, x, y, nWidth, nHeight); + } + + UINT GetTextAlign() const + { + ATLASSERT(m_hDC != NULL); + return ::GetTextAlign(m_hDC); + } + + UINT SetTextAlign(UINT nFlags) + { + ATLASSERT(m_hDC != NULL); + return ::SetTextAlign(m_hDC, nFlags); + } + + int GetTextFace(LPTSTR lpszFacename, int nCount) const + { + ATLASSERT(m_hDC != NULL); + return ::GetTextFace(m_hDC, nCount, lpszFacename); + } + + int GetTextFaceLen() const + { + ATLASSERT(m_hDC != NULL); + return ::GetTextFace(m_hDC, 0, NULL); + } + +#ifdef _OLEAUTO_H_ + BOOL GetTextFace(BSTR& bstrFace) const + { + USES_CONVERSION; + ATLASSERT(m_hDC != NULL); + ATLASSERT(bstrFace == NULL); + + int nLen = GetTextFaceLen(); + if(nLen == 0) + return FALSE; + + ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff; + LPTSTR lpszText = buff.Allocate(nLen); + if(lpszText == NULL) + return FALSE; + + if(!GetTextFace(lpszText, nLen)) + return FALSE; + + bstrFace = ::SysAllocString(T2OLE(lpszText)); + return (bstrFace != NULL) ? 
TRUE : FALSE; + } +#endif + +#ifdef __ATLSTR_H__ + int GetTextFace(ATL::CString& strFace) const + { + ATLASSERT(m_hDC != NULL); + + int nLen = GetTextFaceLen(); + if(nLen == 0) + return 0; + + LPTSTR lpstr = strFace.GetBufferSetLength(nLen); + if(lpstr == NULL) + return 0; + int nRet = GetTextFace(lpstr, nLen); + strFace.ReleaseBuffer(); + return nRet; + } +#endif // __ATLSTR_H__ + + BOOL GetTextMetrics(LPTEXTMETRIC lpMetrics) const + { + ATLASSERT(m_hDC != NULL); + return ::GetTextMetrics(m_hDC, lpMetrics); + } + + int SetTextJustification(int nBreakExtra, int nBreakCount) + { + ATLASSERT(m_hDC != NULL); + return ::SetTextJustification(m_hDC, nBreakExtra, nBreakCount); + } + + int GetTextCharacterExtra() const + { + ATLASSERT(m_hDC != NULL); + return ::GetTextCharacterExtra(m_hDC); + } + + int SetTextCharacterExtra(int nCharExtra) + { + ATLASSERT(m_hDC != NULL); + return ::SetTextCharacterExtra(m_hDC, nCharExtra); + } + +// Advanced Drawing + BOOL DrawEdge(LPRECT lpRect, UINT nEdge, UINT nFlags) + { + ATLASSERT(m_hDC != NULL); + return ::DrawEdge(m_hDC, lpRect, nEdge, nFlags); + } + + BOOL DrawFrameControl(LPRECT lpRect, UINT nType, UINT nState) + { + ATLASSERT(m_hDC != NULL); + return ::DrawFrameControl(m_hDC, lpRect, nType, nState); + } + +// Scrolling Functions + BOOL ScrollDC(int dx, int dy, LPCRECT lpRectScroll, LPCRECT lpRectClip, HRGN hRgnUpdate, LPRECT lpRectUpdate) + { + ATLASSERT(m_hDC != NULL); + return ::ScrollDC(m_hDC, dx, dy, lpRectScroll, lpRectClip, hRgnUpdate, lpRectUpdate); + } + +// Font Functions + BOOL GetCharWidth(UINT nFirstChar, UINT nLastChar, LPINT lpBuffer) const + { + ATLASSERT(m_hDC != NULL); + return ::GetCharWidth(m_hDC, nFirstChar, nLastChar, lpBuffer); + } + + // GetCharWidth32 is not supported under Win9x + BOOL GetCharWidth32(UINT nFirstChar, UINT nLastChar, LPINT lpBuffer) const + { + ATLASSERT(m_hDC != NULL); + return ::GetCharWidth32(m_hDC, nFirstChar, nLastChar, lpBuffer); + } + + DWORD SetMapperFlags(DWORD dwFlag) + { + 
ATLASSERT(m_hDC != NULL); + return ::SetMapperFlags(m_hDC, dwFlag); + } + + BOOL GetAspectRatioFilter(LPSIZE lpSize) const + { + ATLASSERT(m_hDC != NULL); + return ::GetAspectRatioFilterEx(m_hDC, lpSize); + } + + BOOL GetCharABCWidths(UINT nFirstChar, UINT nLastChar, LPABC lpabc) const + { + ATLASSERT(m_hDC != NULL); + return ::GetCharABCWidths(m_hDC, nFirstChar, nLastChar, lpabc); + } + + DWORD GetFontData(DWORD dwTable, DWORD dwOffset, LPVOID lpData, DWORD cbData) const + { + ATLASSERT(m_hDC != NULL); + return ::GetFontData(m_hDC, dwTable, dwOffset, lpData, cbData); + } + + int GetKerningPairs(int nPairs, LPKERNINGPAIR lpkrnpair) const + { + ATLASSERT(m_hDC != NULL); + return ::GetKerningPairs(m_hDC, nPairs, lpkrnpair); + } + + UINT GetOutlineTextMetrics(UINT cbData, LPOUTLINETEXTMETRIC lpotm) const + { + ATLASSERT(m_hDC != NULL); + return ::GetOutlineTextMetrics(m_hDC, cbData, lpotm); + } + + DWORD GetGlyphOutline(UINT nChar, UINT nFormat, LPGLYPHMETRICS lpgm, DWORD cbBuffer, LPVOID lpBuffer, const MAT2* lpmat2) const + { + ATLASSERT(m_hDC != NULL); + return ::GetGlyphOutline(m_hDC, nChar, nFormat, lpgm, cbBuffer, lpBuffer, lpmat2); + } + + BOOL GetCharABCWidths(UINT nFirstChar, UINT nLastChar, LPABCFLOAT lpABCF) const + { + ATLASSERT(m_hDC != NULL); + return ::GetCharABCWidthsFloat(m_hDC, nFirstChar, nLastChar, lpABCF); + } + + BOOL GetCharWidth(UINT nFirstChar, UINT nLastChar, float* lpFloatBuffer) const + { + ATLASSERT(m_hDC != NULL); + return ::GetCharWidthFloat(m_hDC, nFirstChar, nLastChar, lpFloatBuffer); + } + +// Printer/Device Escape Functions + int Escape(int nEscape, int nCount, LPCSTR lpszInData, LPVOID lpOutData) + { + ATLASSERT(m_hDC != NULL); + return ::Escape(m_hDC, nEscape, nCount, lpszInData, lpOutData); + } + + int Escape(int nEscape, int nInputSize, LPCSTR lpszInputData, + int nOutputSize, LPSTR lpszOutputData) + { + ATLASSERT(m_hDC != NULL); + return ::ExtEscape(m_hDC, nEscape, nInputSize, lpszInputData, nOutputSize, lpszOutputData); + } + + 
int DrawEscape(int nEscape, int nInputSize, LPCSTR lpszInputData) + { + ATLASSERT(m_hDC != NULL); + return ::DrawEscape(m_hDC, nEscape, nInputSize, lpszInputData); + } + + // Escape helpers + int StartDoc(LPCTSTR lpszDocName) // old Win3.0 version + { + DOCINFO di = {}; + di.cbSize = sizeof(DOCINFO); + di.lpszDocName = lpszDocName; + return StartDoc(&di); + } + + int StartDoc(LPDOCINFO lpDocInfo) + { + ATLASSERT(m_hDC != NULL); + return ::StartDoc(m_hDC, lpDocInfo); + } + + int StartPage() + { + ATLASSERT(m_hDC != NULL); + return ::StartPage(m_hDC); + } + + int EndPage() + { + ATLASSERT(m_hDC != NULL); + return ::EndPage(m_hDC); + } + + int SetAbortProc(BOOL (CALLBACK* lpfn)(HDC, int)) + { + ATLASSERT(m_hDC != NULL); + return ::SetAbortProc(m_hDC, (ABORTPROC)lpfn); + } + + int AbortDoc() + { + ATLASSERT(m_hDC != NULL); + return ::AbortDoc(m_hDC); + } + + int EndDoc() + { + ATLASSERT(m_hDC != NULL); + return ::EndDoc(m_hDC); + } + +// MetaFile Functions + BOOL PlayMetaFile(HMETAFILE hMF) + { + ATLASSERT(m_hDC != NULL); + if(::GetDeviceCaps(m_hDC, TECHNOLOGY) == DT_METAFILE) + { + // playing metafile in metafile, just use core windows API + return ::PlayMetaFile(m_hDC, hMF); + } + + // for special playback, lParam == pDC + return ::EnumMetaFile(m_hDC, hMF, EnumMetaFileProc, (LPARAM)this); + } + + BOOL PlayMetaFile(HENHMETAFILE hEnhMetaFile, LPCRECT lpBounds) + { + ATLASSERT(m_hDC != NULL); + return ::PlayEnhMetaFile(m_hDC, hEnhMetaFile, lpBounds); + } + + BOOL AddMetaFileComment(UINT nDataSize, const BYTE* pCommentData) // can be used for enhanced metafiles only + { + ATLASSERT(m_hDC != NULL); + return ::GdiComment(m_hDC, nDataSize, pCommentData); + } + + // Special handling for metafile playback + static int CALLBACK EnumMetaFileProc(HDC hDC, HANDLETABLE* pHandleTable, METARECORD* pMetaRec, int nHandles, LPARAM lParam) + { + CDCT<false>* pDC = (CDCT<false>*)lParam; + + switch (pMetaRec->rdFunction) + { + case META_SETMAPMODE: + 
pDC->SetMapMode((int)(short)pMetaRec->rdParm[0]); + break; + case META_SETWINDOWEXT: + pDC->SetWindowExt((int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]); + break; + case META_SETWINDOWORG: + pDC->SetWindowOrg((int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]); + break; + case META_SETVIEWPORTEXT: + pDC->SetViewportExt((int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]); + break; + case META_SETVIEWPORTORG: + pDC->SetViewportOrg((int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]); + break; + case META_SCALEWINDOWEXT: + pDC->ScaleWindowExt((int)(short)pMetaRec->rdParm[3], (int)(short)pMetaRec->rdParm[2], + (int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]); + break; + case META_SCALEVIEWPORTEXT: + pDC->ScaleViewportExt((int)(short)pMetaRec->rdParm[3], (int)(short)pMetaRec->rdParm[2], + (int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]); + break; + case META_OFFSETVIEWPORTORG: + pDC->OffsetViewportOrg((int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]); + break; + case META_SAVEDC: + pDC->SaveDC(); + break; + case META_RESTOREDC: + pDC->RestoreDC((int)(short)pMetaRec->rdParm[0]); + break; + case META_SETBKCOLOR: + pDC->SetBkColor(*(UNALIGNED COLORREF*)&pMetaRec->rdParm[0]); + break; + case META_SETTEXTCOLOR: + pDC->SetTextColor(*(UNALIGNED COLORREF*)&pMetaRec->rdParm[0]); + break; + + // need to watch out for SelectObject(HFONT), for custom font mapping + case META_SELECTOBJECT: + { + HGDIOBJ hObject = pHandleTable->objectHandle[pMetaRec->rdParm[0]]; + UINT nObjType = ::GetObjectType(hObject); + if(nObjType == 0) + { + // object type is unknown, determine if it is a font + HFONT hStockFont = (HFONT)::GetStockObject(SYSTEM_FONT); + HFONT hFontOld = (HFONT)::SelectObject(pDC->m_hDC, hStockFont); + HGDIOBJ hObjOld = ::SelectObject(pDC->m_hDC, hObject); + if(hObjOld == hStockFont) + { + // got the stock object back, so must be selecting a font + 
pDC->SelectFont((HFONT)hObject); + break; // don't play the default record + } + else + { + // didn't get the stock object back, so restore everything + ::SelectObject(pDC->m_hDC, hFontOld); + ::SelectObject(pDC->m_hDC, hObjOld); + } + // and fall through to PlayMetaFileRecord... + } + else if(nObjType == OBJ_FONT) + { + // play back as CDCHandle::SelectFont(HFONT) + pDC->SelectFont((HFONT)hObject); + break; // don't play the default record + } + } + // fall through... + + default: + ::PlayMetaFileRecord(hDC, pHandleTable, pMetaRec, nHandles); + break; + } + + return 1; + } + +// Path Functions + BOOL AbortPath() + { + ATLASSERT(m_hDC != NULL); + return ::AbortPath(m_hDC); + } + + BOOL BeginPath() + { + ATLASSERT(m_hDC != NULL); + return ::BeginPath(m_hDC); + } + + BOOL CloseFigure() + { + ATLASSERT(m_hDC != NULL); + return ::CloseFigure(m_hDC); + } + + BOOL EndPath() + { + ATLASSERT(m_hDC != NULL); + return ::EndPath(m_hDC); + } + + BOOL FillPath() + { + ATLASSERT(m_hDC != NULL); + return ::FillPath(m_hDC); + } + + BOOL FlattenPath() + { + ATLASSERT(m_hDC != NULL); + return ::FlattenPath(m_hDC); + } + + BOOL StrokeAndFillPath() + { + ATLASSERT(m_hDC != NULL); + return ::StrokeAndFillPath(m_hDC); + } + + BOOL StrokePath() + { + ATLASSERT(m_hDC != NULL); + return ::StrokePath(m_hDC); + } + + BOOL WidenPath() + { + ATLASSERT(m_hDC != NULL); + return ::WidenPath(m_hDC); + } + + BOOL GetMiterLimit(PFLOAT pfMiterLimit) const + { + ATLASSERT(m_hDC != NULL); + return ::GetMiterLimit(m_hDC, pfMiterLimit); + } + + BOOL SetMiterLimit(float fMiterLimit) + { + ATLASSERT(m_hDC != NULL); + return ::SetMiterLimit(m_hDC, fMiterLimit, NULL); + } + + int GetPath(LPPOINT lpPoints, LPBYTE lpTypes, int nCount) const + { + ATLASSERT(m_hDC != NULL); + return ::GetPath(m_hDC, lpPoints, lpTypes, nCount); + } + + BOOL SelectClipPath(int nMode) + { + ATLASSERT(m_hDC != NULL); + return ::SelectClipPath(m_hDC, nMode); + } + +// Misc Helper Functions + static CBrushHandle PASCAL 
GetHalftoneBrush() + { + HBRUSH halftoneBrush = NULL; + WORD grayPattern[8] = {}; + for(int i = 0; i < 8; i++) + grayPattern[i] = (WORD)(0x5555 << (i & 1)); + HBITMAP grayBitmap = CreateBitmap(8, 8, 1, 1, &grayPattern); + if(grayBitmap != NULL) + { + halftoneBrush = ::CreatePatternBrush(grayBitmap); + DeleteObject(grayBitmap); + } + return CBrushHandle(halftoneBrush); + } + + void DrawDragRect(LPCRECT lpRect, SIZE size, LPCRECT lpRectLast, SIZE sizeLast, HBRUSH hBrush = NULL, HBRUSH hBrushLast = NULL) + { + // first, determine the update region and select it + CRgn rgnOutside; + rgnOutside.CreateRectRgnIndirect(lpRect); + RECT rect = *lpRect; + ::InflateRect(&rect, -size.cx, -size.cy); + ::IntersectRect(&rect, &rect, lpRect); + CRgn rgnInside; + rgnInside.CreateRectRgnIndirect(&rect); + CRgn rgnNew; + rgnNew.CreateRectRgn(0, 0, 0, 0); + rgnNew.CombineRgn(rgnOutside, rgnInside, RGN_XOR); + + HBRUSH hBrushOld = NULL; + CBrush brushHalftone; + if(hBrush == NULL) + brushHalftone = hBrush = CDCHandle::GetHalftoneBrush(); + if(hBrushLast == NULL) + hBrushLast = hBrush; + + CRgn rgnLast; + CRgn rgnUpdate; + if(lpRectLast != NULL) + { + // find difference between new region and old region + rgnLast.CreateRectRgn(0, 0, 0, 0); + rgnOutside.SetRectRgn(lpRectLast->left, lpRectLast->top, lpRectLast->right, lpRectLast->bottom); + rect = *lpRectLast; + ::InflateRect(&rect, -sizeLast.cx, -sizeLast.cy); + ::IntersectRect(&rect, &rect, lpRectLast); + rgnInside.SetRectRgn(rect.left, rect.top, rect.right, rect.bottom); + rgnLast.CombineRgn(rgnOutside, rgnInside, RGN_XOR); + + // only diff them if brushes are the same + if(hBrush == hBrushLast) + { + rgnUpdate.CreateRectRgn(0, 0, 0, 0); + rgnUpdate.CombineRgn(rgnLast, rgnNew, RGN_XOR); + } + } + if((hBrush != hBrushLast) && (lpRectLast != NULL)) + { + // brushes are different -- erase old region first + SelectClipRgn(rgnLast); + GetClipBox(&rect); + hBrushOld = SelectBrush(hBrushLast); + PatBlt(rect.left, rect.top, rect.right - 
rect.left, rect.bottom - rect.top, PATINVERT); + SelectBrush(hBrushOld); + hBrushOld = NULL; + } + + // draw into the update/new region + SelectClipRgn(rgnUpdate.IsNull() ? rgnNew : rgnUpdate); + GetClipBox(&rect); + hBrushOld = SelectBrush(hBrush); + PatBlt(rect.left, rect.top, rect.right - rect.left, rect.bottom - rect.top, PATINVERT); + + // cleanup DC + if(hBrushOld != NULL) + SelectBrush(hBrushOld); + SelectClipRgn(NULL); + } + + void FillSolidRect(LPCRECT lpRect, COLORREF clr) + { + ATLASSERT(m_hDC != NULL); + + COLORREF clrOld = ::SetBkColor(m_hDC, clr); + ATLASSERT(clrOld != CLR_INVALID); + if(clrOld != CLR_INVALID) + { + ::ExtTextOut(m_hDC, 0, 0, ETO_OPAQUE, lpRect, NULL, 0, NULL); + ::SetBkColor(m_hDC, clrOld); + } + } + + void FillSolidRect(int x, int y, int cx, int cy, COLORREF clr) + { + ATLASSERT(m_hDC != NULL); + + RECT rect = { x, y, x + cx, y + cy }; + FillSolidRect(&rect, clr); + } + + void Draw3dRect(LPCRECT lpRect, COLORREF clrTopLeft, COLORREF clrBottomRight) + { + Draw3dRect(lpRect->left, lpRect->top, lpRect->right - lpRect->left, + lpRect->bottom - lpRect->top, clrTopLeft, clrBottomRight); + } + + void Draw3dRect(int x, int y, int cx, int cy, COLORREF clrTopLeft, COLORREF clrBottomRight) + { + FillSolidRect(x, y, cx - 1, 1, clrTopLeft); + FillSolidRect(x, y, 1, cy - 1, clrTopLeft); + FillSolidRect(x + cx, y, -1, cy, clrBottomRight); + FillSolidRect(x, y + cy, cx, -1, clrBottomRight); + } + +// DIB support + int SetDIBitsToDevice(int x, int y, DWORD dwWidth, DWORD dwHeight, int xSrc, int ySrc, UINT uStartScan, UINT cScanLines, CONST VOID* lpvBits, CONST BITMAPINFO* lpbmi, UINT uColorUse) + { + ATLASSERT(m_hDC != NULL); + return ::SetDIBitsToDevice(m_hDC, x, y, dwWidth, dwHeight, xSrc, ySrc, uStartScan, cScanLines, lpvBits, lpbmi, uColorUse); + } + + int StretchDIBits(int x, int y, int nWidth, int nHeight, int xSrc, int ySrc, int nSrcWidth, int nSrcHeight, CONST VOID* lpvBits, CONST BITMAPINFO* lpbmi, UINT uColorUse, DWORD dwRop) + { + 
ATLASSERT(m_hDC != NULL); + return ::StretchDIBits(m_hDC, x, y, nWidth, nHeight, xSrc, ySrc, nSrcWidth, nSrcHeight, lpvBits, lpbmi, uColorUse, dwRop); + } + + UINT GetDIBColorTable(UINT uStartIndex, UINT cEntries, RGBQUAD* pColors) const + { + ATLASSERT(m_hDC != NULL); + return ::GetDIBColorTable(m_hDC, uStartIndex, cEntries, pColors); + } + + UINT SetDIBColorTable(UINT uStartIndex, UINT cEntries, CONST RGBQUAD* pColors) + { + ATLASSERT(m_hDC != NULL); + return ::SetDIBColorTable(m_hDC, uStartIndex, cEntries, pColors); + } + +// OpenGL support +#if !defined(_ATL_NO_OPENGL) + int ChoosePixelFormat(CONST PIXELFORMATDESCRIPTOR* ppfd) + { + ATLASSERT(m_hDC != NULL); + return ::ChoosePixelFormat(m_hDC, ppfd); + } + + int DescribePixelFormat(int iPixelFormat, UINT nBytes, LPPIXELFORMATDESCRIPTOR ppfd) + { + ATLASSERT(m_hDC != NULL); + return ::DescribePixelFormat(m_hDC, iPixelFormat, nBytes, ppfd); + } + + int GetPixelFormat() const + { + ATLASSERT(m_hDC != NULL); + return ::GetPixelFormat(m_hDC); + } + + BOOL SetPixelFormat(int iPixelFormat, CONST PIXELFORMATDESCRIPTOR* ppfd) + { + ATLASSERT(m_hDC != NULL); + return ::SetPixelFormat(m_hDC, iPixelFormat, ppfd); + } + + BOOL SwapBuffers() + { + ATLASSERT(m_hDC != NULL); + return ::SwapBuffers(m_hDC); + } + + HGLRC wglCreateContext() + { + ATLASSERT(m_hDC != NULL); + return ::wglCreateContext(m_hDC); + } + + HGLRC wglCreateLayerContext(int iLayerPlane) + { + ATLASSERT(m_hDC != NULL); + return ::wglCreateLayerContext(m_hDC, iLayerPlane); + } + + BOOL wglMakeCurrent(HGLRC hglrc) + { + ATLASSERT(m_hDC != NULL); + return ::wglMakeCurrent(m_hDC, hglrc); + } + + BOOL wglUseFontBitmaps(DWORD dwFirst, DWORD dwCount, DWORD listBase) + { + ATLASSERT(m_hDC != NULL); + return ::wglUseFontBitmaps(m_hDC, dwFirst, dwCount, listBase); + } + + BOOL wglUseFontOutlines(DWORD dwFirst, DWORD dwCount, DWORD listBase, FLOAT deviation, FLOAT extrusion, int format, LPGLYPHMETRICSFLOAT lpgmf) + { + ATLASSERT(m_hDC != NULL); + return 
::wglUseFontOutlines(m_hDC, dwFirst, dwCount, listBase, deviation, extrusion, format, lpgmf); + } + + BOOL wglDescribeLayerPlane(int iPixelFormat, int iLayerPlane, UINT nBytes, LPLAYERPLANEDESCRIPTOR plpd) + { + ATLASSERT(m_hDC != NULL); + return ::wglDescribeLayerPlane(m_hDC, iPixelFormat, iLayerPlane, nBytes, plpd); + } + + int wglSetLayerPaletteEntries(int iLayerPlane, int iStart, int cEntries, CONST COLORREF* pclr) + { + ATLASSERT(m_hDC != NULL); + return ::wglSetLayerPaletteEntries(m_hDC, iLayerPlane, iStart, cEntries, pclr); + } + + int wglGetLayerPaletteEntries(int iLayerPlane, int iStart, int cEntries, COLORREF* pclr) + { + ATLASSERT(m_hDC != NULL); + return ::wglGetLayerPaletteEntries(m_hDC, iLayerPlane, iStart, cEntries, pclr); + } + + BOOL wglRealizeLayerPalette(int iLayerPlane, BOOL bRealize) + { + ATLASSERT(m_hDC != NULL); + return ::wglRealizeLayerPalette(m_hDC, iLayerPlane, bRealize); + } + + BOOL wglSwapLayerBuffers(UINT uPlanes) + { + ATLASSERT(m_hDC != NULL); + return ::wglSwapLayerBuffers(m_hDC, uPlanes); + } +#endif // !defined(_ATL_NO_OPENGL) + + COLORREF GetDCPenColor() const + { + ATLASSERT(m_hDC != NULL); + return ::GetDCPenColor(m_hDC); + } + + COLORREF SetDCPenColor(COLORREF clr) + { + ATLASSERT(m_hDC != NULL); + return ::SetDCPenColor(m_hDC, clr); + } + + COLORREF GetDCBrushColor() const + { + ATLASSERT(m_hDC != NULL); + return ::GetDCBrushColor(m_hDC); + } + + COLORREF SetDCBrushColor(COLORREF clr) + { + ATLASSERT(m_hDC != NULL); + return ::SetDCBrushColor(m_hDC, clr); + } + + DWORD GetFontUnicodeRanges(LPGLYPHSET lpgs) const + { + ATLASSERT(m_hDC != NULL); + return ::GetFontUnicodeRanges(m_hDC, lpgs); + } + + DWORD GetGlyphIndices(LPCTSTR lpstr, int cch, LPWORD pgi, DWORD dwFlags) const + { + ATLASSERT(m_hDC != NULL); + return ::GetGlyphIndices(m_hDC, lpstr, cch, pgi, dwFlags); + } + + BOOL GetTextExtentPointI(LPWORD pgiIn, int cgi, LPSIZE lpSize) const + { + ATLASSERT(m_hDC != NULL); + return ::GetTextExtentPointI(m_hDC, pgiIn, cgi, 
lpSize); + } + + BOOL GetTextExtentExPointI(LPWORD pgiIn, int cgi, int nMaxExtent, LPINT lpnFit, LPINT alpDx, LPSIZE lpSize) const + { + ATLASSERT(m_hDC != NULL); + return ::GetTextExtentExPointI(m_hDC, pgiIn, cgi, nMaxExtent, lpnFit, alpDx, lpSize); + } + + BOOL GetCharWidthI(UINT giFirst, UINT cgi, LPWORD pgi, LPINT lpBuffer) const + { + ATLASSERT(m_hDC != NULL); + return ::GetCharWidthI(m_hDC, giFirst, cgi, pgi, lpBuffer); + } + + BOOL GetCharABCWidthsI(UINT giFirst, UINT cgi, LPWORD pgi, LPABC lpabc) const + { + ATLASSERT(m_hDC != NULL); + return ::GetCharABCWidthsI(m_hDC, giFirst, cgi, pgi, lpabc); + } + + BOOL ColorCorrectPalette(HPALETTE hPalette, DWORD dwFirstEntry, DWORD dwNumOfEntries) + { + ATLASSERT(m_hDC != NULL); + return ::ColorCorrectPalette(m_hDC, hPalette, dwFirstEntry, dwNumOfEntries); + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CDC Helpers + +class CPaintDC : public CDC +{ +public: +// Data members + HWND m_hWnd; + PAINTSTRUCT m_ps; + +// Constructor/destructor + CPaintDC(HWND hWnd) + { + ATLASSERT(::IsWindow(hWnd)); + m_hWnd = hWnd; + m_hDC = ::BeginPaint(hWnd, &m_ps); + } + + ~CPaintDC() + { + ATLASSERT(m_hDC != NULL); + ATLASSERT(::IsWindow(m_hWnd)); + ::EndPaint(m_hWnd, &m_ps); + Detach(); + } +}; + +class CClientDC : public CDC +{ +public: +// Data members + HWND m_hWnd; + +// Constructor/destructor + CClientDC(HWND hWnd) + { + ATLASSERT((hWnd == NULL) || ::IsWindow(hWnd)); + m_hWnd = hWnd; + m_hDC = ::GetDC(hWnd); + } + + ~CClientDC() + { + ATLASSERT(m_hDC != NULL); + ::ReleaseDC(m_hWnd, Detach()); + } +}; + +class CWindowDC : public CDC +{ +public: +// Data members + HWND m_hWnd; + +// Constructor/destructor + CWindowDC(HWND hWnd) + { + ATLASSERT((hWnd == NULL) || ::IsWindow(hWnd)); + m_hWnd = hWnd; + m_hDC = ::GetWindowDC(hWnd); + } + + ~CWindowDC() + { + ATLASSERT(m_hDC != NULL); + ::ReleaseDC(m_hWnd, Detach()); + } +}; + +class CMemoryDC : public CDC +{ +public: +// Data members + 
HDC m_hDCOriginal; + RECT m_rcPaint; + CBitmap m_bmp; + HBITMAP m_hBmpOld; + +// Constructor/destructor + CMemoryDC(HDC hDC, const RECT& rcPaint) : m_hDCOriginal(hDC), m_hBmpOld(NULL) + { + m_rcPaint = rcPaint; + CreateCompatibleDC(m_hDCOriginal); + ATLASSERT(m_hDC != NULL); + m_bmp.CreateCompatibleBitmap(m_hDCOriginal, m_rcPaint.right - m_rcPaint.left, m_rcPaint.bottom - m_rcPaint.top); + ATLASSERT(m_bmp.m_hBitmap != NULL); + m_hBmpOld = SelectBitmap(m_bmp); + SetViewportOrg(-m_rcPaint.left, -m_rcPaint.top); + } + + ~CMemoryDC() + { + ::BitBlt(m_hDCOriginal, m_rcPaint.left, m_rcPaint.top, m_rcPaint.right - m_rcPaint.left, m_rcPaint.bottom - m_rcPaint.top, m_hDC, m_rcPaint.left, m_rcPaint.top, SRCCOPY); + SelectBitmap(m_hBmpOld); + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// Enhanced metafile support + +class CEnhMetaFileInfo +{ +public: +// Data members + HENHMETAFILE m_hEMF; + BYTE* m_pBits; + TCHAR* m_pDesc; + ENHMETAHEADER m_header; + PIXELFORMATDESCRIPTOR m_pfd; + +// Constructor/destructor + CEnhMetaFileInfo(HENHMETAFILE hEMF) : m_hEMF(hEMF), m_pBits(NULL), m_pDesc(NULL) + { + memset(&m_header, 0, sizeof(m_header)); + memset(&m_pfd, 0, sizeof(m_pfd)); + } + + ~CEnhMetaFileInfo() + { + delete [] m_pBits; + delete [] m_pDesc; + } + +// Operations + BYTE* GetEnhMetaFileBits() + { + ATLASSERT(m_hEMF != NULL); + UINT nBytes = ::GetEnhMetaFileBits(m_hEMF, 0, NULL); + delete [] m_pBits; + m_pBits = NULL; + ATLTRY(m_pBits = new BYTE[nBytes]); + if (m_pBits != NULL) + ::GetEnhMetaFileBits(m_hEMF, nBytes, m_pBits); + return m_pBits; + } + + LPTSTR GetEnhMetaFileDescription() + { + ATLASSERT(m_hEMF != NULL); + UINT nLen = ::GetEnhMetaFileDescription(m_hEMF, 0, NULL); + delete [] m_pDesc; + m_pDesc = NULL; + ATLTRY(m_pDesc = new TCHAR[nLen]); + if (m_pDesc != NULL) + nLen = ::GetEnhMetaFileDescription(m_hEMF, nLen, m_pDesc); + return m_pDesc; + } + + ENHMETAHEADER* GetEnhMetaFileHeader() + { + ATLASSERT(m_hEMF != NULL); 
+ memset(&m_header, 0, sizeof(m_header)); + m_header.iType = EMR_HEADER; + m_header.nSize = sizeof(ENHMETAHEADER); + UINT n = ::GetEnhMetaFileHeader(m_hEMF, sizeof(ENHMETAHEADER), &m_header); + return (n != 0) ? &m_header : NULL; + } + + PIXELFORMATDESCRIPTOR* GetEnhMetaFilePixelFormat() + { + ATLASSERT(m_hEMF != NULL); + memset(&m_pfd, 0, sizeof(m_pfd)); + UINT n = ::GetEnhMetaFilePixelFormat(m_hEMF, sizeof(m_pfd), &m_pfd); + return (n != 0) ? &m_pfd : NULL; + } +}; + + +template <bool t_bManaged> +class CEnhMetaFileT +{ +public: +// Data members + HENHMETAFILE m_hEMF; + +// Constructor/destructor + CEnhMetaFileT(HENHMETAFILE hEMF = NULL) : m_hEMF(hEMF) + { + } + + ~CEnhMetaFileT() + { + if(t_bManaged && (m_hEMF != NULL)) + DeleteObject(); + } + +// Operations + CEnhMetaFileT<t_bManaged>& operator =(HENHMETAFILE hEMF) + { + Attach(hEMF); + return *this; + } + + void Attach(HENHMETAFILE hEMF) + { + if(t_bManaged && (m_hEMF != NULL) && (m_hEMF != hEMF)) + DeleteObject(); + m_hEMF = hEMF; + } + + HENHMETAFILE Detach() + { + HENHMETAFILE hEMF = m_hEMF; + m_hEMF = NULL; + return hEMF; + } + + operator HENHMETAFILE() const { return m_hEMF; } + + bool IsNull() const { return (m_hEMF == NULL); } + + BOOL DeleteObject() + { + ATLASSERT(m_hEMF != NULL); + BOOL bRet = ::DeleteEnhMetaFile(m_hEMF); + m_hEMF = NULL; + return bRet; + } + + UINT GetEnhMetaFileBits(UINT cbBuffer, LPBYTE lpbBuffer) const + { + ATLASSERT(m_hEMF != NULL); + return ::GetEnhMetaFileBits(m_hEMF, cbBuffer, lpbBuffer); + } + + UINT GetEnhMetaFileDescription(UINT cchBuffer, LPTSTR lpszDescription) const + { + ATLASSERT(m_hEMF != NULL); + return ::GetEnhMetaFileDescription(m_hEMF, cchBuffer, lpszDescription); + } + + UINT GetEnhMetaFileHeader(LPENHMETAHEADER lpemh) const + { + ATLASSERT(m_hEMF != NULL); + lpemh->iType = EMR_HEADER; + lpemh->nSize = sizeof(ENHMETAHEADER); + return ::GetEnhMetaFileHeader(m_hEMF, sizeof(ENHMETAHEADER), lpemh); + } + + UINT GetEnhMetaFilePaletteEntries(UINT cEntries, 
LPPALETTEENTRY lppe) const + { + ATLASSERT(m_hEMF != NULL); + return ::GetEnhMetaFilePaletteEntries(m_hEMF, cEntries, lppe); + } + + UINT GetEnhMetaFilePixelFormat(DWORD cbBuffer, PIXELFORMATDESCRIPTOR* ppfd) const + { + ATLASSERT(m_hEMF != NULL); + return ::GetEnhMetaFilePixelFormat(m_hEMF, cbBuffer, ppfd); + } +}; + +typedef CEnhMetaFileT<false> CEnhMetaFileHandle; +typedef CEnhMetaFileT<true> CEnhMetaFile; + + +class CEnhMetaFileDC : public CDC +{ +public: +// Constructor/destructor + CEnhMetaFileDC() + { + } + + CEnhMetaFileDC(HDC hdc, LPCRECT lpRect) + { + Create(hdc, NULL, lpRect, NULL); + ATLASSERT(m_hDC != NULL); + } + + CEnhMetaFileDC(HDC hdcRef, LPCTSTR lpFilename, LPCRECT lpRect, LPCTSTR lpDescription) + { + Create(hdcRef, lpFilename, lpRect, lpDescription); + ATLASSERT(m_hDC != NULL); + } + + ~CEnhMetaFileDC() + { + HENHMETAFILE hEMF = Close(); + if (hEMF != NULL) + ::DeleteEnhMetaFile(hEMF); + } + +// Operations + void Create(HDC hdcRef, LPCTSTR lpFilename, LPCRECT lpRect, LPCTSTR lpDescription) + { + ATLASSERT(m_hDC == NULL); + m_hDC = ::CreateEnhMetaFile(hdcRef, lpFilename, lpRect, lpDescription); + } + + HENHMETAFILE Close() + { + HENHMETAFILE hEMF = NULL; + if (m_hDC != NULL) + { + hEMF = ::CloseEnhMetaFile(m_hDC); + m_hDC = NULL; + } + return hEMF; + } +}; + +} // namespace WTL + +#endif // __ATLGDI_H__ diff --git a/Examples/WhisperDesktop/Utils/WTL/atlres.h b/Examples/WhisperDesktop/Utils/WTL/atlres.h new file mode 100644 index 0000000..aed58f1 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/WTL/atlres.h @@ -0,0 +1,259 @@ +// Windows Template Library - WTL version 10.0 +// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved. +// +// This file is a part of the Windows Template Library. +// The use and distribution terms for this software are covered by the +// Microsoft Public License (http://opensource.org/licenses/MS-PL) +// which can be found in the file MS-PL.txt at the root folder. 
+ +#ifndef __ATLRES_H__ +#define __ATLRES_H__ + +#pragma once + + +#ifdef RC_INVOKED +#ifndef _INC_WINDOWS + + #define _INC_WINDOWS + + #define VS_VERSION_INFO 1 + + #ifdef APSTUDIO_INVOKED + #define APSTUDIO_HIDDEN_SYMBOLS // Ignore following symbols + #endif // APSTUDIO_INVOKED + + #ifndef WINVER + #define WINVER 0x0500 + #endif // !WINVER + + #include <winresrc.h> + + // operation messages sent to DLGINIT + #define LB_ADDSTRING (WM_USER+1) + #define CB_ADDSTRING (WM_USER+3) + + #ifdef APSTUDIO_INVOKED + #undef APSTUDIO_HIDDEN_SYMBOLS + #endif // APSTUDIO_INVOKED + + #ifdef IDC_STATIC + #undef IDC_STATIC + #endif // IDC_STATIC + #define IDC_STATIC (-1) + +#endif // !_INC_WINDOWS +#endif // RC_INVOKED + +#ifdef APSTUDIO_INVOKED + #define APSTUDIO_HIDDEN_SYMBOLS +#endif // APSTUDIO_INVOKED + +/////////////////////////////////////////////////////////////////////////////// +// ATL resource types + +#ifndef RC_INVOKED + #define RT_DLGINIT MAKEINTRESOURCE(240) + #define RT_TOOLBAR MAKEINTRESOURCE(241) +#endif // RC_INVOKED + +/////////////////////////////////////////////////////////////////////////////// + +#ifdef APSTUDIO_INVOKED + #undef APSTUDIO_HIDDEN_SYMBOLS +#endif // APSTUDIO_INVOKED + +/////////////////////////////////////////////////////////////////////////////// +// Standard window components + +#define ID_SEPARATOR 0 // special separator value +#define ID_DEFAULT_PANE 0 // default status bar pane + +#ifndef RC_INVOKED // code only +// standard control bars (IDW = window ID) + #define ATL_IDW_TOOLBAR 0xE800 // main Toolbar for window + #define ATL_IDW_STATUS_BAR 0xE801 // Status bar window + #define ATL_IDW_COMMAND_BAR 0xE802 // Command bar window + +// parts of a frame window + #define ATL_IDW_CLIENT 0xE900 + #define ATL_IDW_PANE_FIRST 0xE900 // first pane (256 max) + #define ATL_IDW_PANE_LAST 0xE9FF + #define ATL_IDW_HSCROLL_FIRST 0xEA00 // first Horz scrollbar (16 max) + #define ATL_IDW_VSCROLL_FIRST 0xEA10 // first Vert scrollbar (16 max) + + #define 
ATL_IDW_SIZE_BOX 0xEA20 // size box for splitters + #define ATL_IDW_PANE_SAVE 0xEA21 // to shift ATL_IDW_PANE_FIRST + +// bands for a rebar + #define ATL_IDW_BAND_FIRST 0xEB00 + #define ATL_IDW_BAND_LAST 0xEBFF +#endif // !RC_INVOKED + +/////////////////////////////////////////////////////////////////////////////// +// Standard Commands + +// File commands +#define ID_FILE_NEW 0xE100 +#define ID_FILE_OPEN 0xE101 +#define ID_FILE_CLOSE 0xE102 +#define ID_FILE_SAVE 0xE103 +#define ID_FILE_SAVE_AS 0xE104 +#define ID_FILE_PAGE_SETUP 0xE105 +#define ID_FILE_PRINT_SETUP 0xE106 +#define ID_FILE_PRINT 0xE107 +#define ID_FILE_PRINT_DIRECT 0xE108 +#define ID_FILE_PRINT_PREVIEW 0xE109 +#define ID_FILE_UPDATE 0xE10A +#define ID_FILE_SAVE_COPY_AS 0xE10B +#define ID_FILE_SEND_MAIL 0xE10C + +#define ID_FILE_MRU_FIRST 0xE110 +#define ID_FILE_MRU_FILE1 0xE110 // range - 16 max +#define ID_FILE_MRU_FILE2 0xE111 +#define ID_FILE_MRU_FILE3 0xE112 +#define ID_FILE_MRU_FILE4 0xE113 +#define ID_FILE_MRU_FILE5 0xE114 +#define ID_FILE_MRU_FILE6 0xE115 +#define ID_FILE_MRU_FILE7 0xE116 +#define ID_FILE_MRU_FILE8 0xE117 +#define ID_FILE_MRU_FILE9 0xE118 +#define ID_FILE_MRU_FILE10 0xE119 +#define ID_FILE_MRU_FILE11 0xE11A +#define ID_FILE_MRU_FILE12 0xE11B +#define ID_FILE_MRU_FILE13 0xE11C +#define ID_FILE_MRU_FILE14 0xE11D +#define ID_FILE_MRU_FILE15 0xE11E +#define ID_FILE_MRU_FILE16 0xE11F +#define ID_FILE_MRU_LAST 0xE11F + +// Edit commands +#define ID_EDIT_CLEAR 0xE120 +#define ID_EDIT_CLEAR_ALL 0xE121 +#define ID_EDIT_COPY 0xE122 +#define ID_EDIT_CUT 0xE123 +#define ID_EDIT_FIND 0xE124 +#define ID_EDIT_PASTE 0xE125 +#define ID_EDIT_PASTE_LINK 0xE126 +#define ID_EDIT_PASTE_SPECIAL 0xE127 +#define ID_EDIT_REPEAT 0xE128 +#define ID_EDIT_REPLACE 0xE129 +#define ID_EDIT_SELECT_ALL 0xE12A +#define ID_EDIT_UNDO 0xE12B +#define ID_EDIT_REDO 0xE12C +#define ID_EDIT_DELETE ID_EDIT_CLEAR +#define ID_EDIT_FIND_NEXT ID_EDIT_REPEAT +#define ID_EDIT_FIND_PREVIOUS 0xE12D + +// Window commands 
+#define ID_WINDOW_NEW 0xE130 +#define ID_WINDOW_ARRANGE 0xE131 +#define ID_WINDOW_CASCADE 0xE132 +#define ID_WINDOW_TILE_HORZ 0xE133 +#define ID_WINDOW_TILE_VERT 0xE134 +#define ID_WINDOW_SPLIT 0xE135 +#ifndef RC_INVOKED // code only + #define ATL_IDM_WINDOW_FIRST 0xE130 + #define ATL_IDM_WINDOW_LAST 0xE13F + #define ATL_IDM_FIRST_MDICHILD 0xFF00 // window list starts here + #define ATL_IDM_LAST_MDICHILD 0xFFFD +#endif // !RC_INVOKED +// TabView +#define ID_WINDOW_TABFIRST 0xFF00 // = ATL_IDM_FIRST_MDICHILD +#define ID_WINDOW_TABLAST 0xFFFD +#define ID_WINDOW_SHOWTABLIST 0xFFFE + +// Help and App commands +#define ID_APP_ABOUT 0xE140 +#define ID_APP_EXIT 0xE141 +#define ID_HELP_INDEX 0xE142 +#define ID_HELP_FINDER 0xE143 +#define ID_HELP_USING 0xE144 +#define ID_CONTEXT_HELP 0xE145 // shift-F1 +// special commands for processing help +#define ID_HELP 0xE146 // first attempt for F1 +#define ID_DEFAULT_HELP 0xE147 // last attempt + +// Misc +#define ID_NEXT_PANE 0xE150 +#define ID_PREV_PANE 0xE151 +#define ID_PANE_CLOSE 0xE152 +#define ID_PANE_NEXT ID_NEXT_PANE +#define ID_PANE_PREVIOUS ID_PREV_PANE + +// Format +#define ID_FORMAT_FONT 0xE160 + +// Scroll +#define ID_SCROLL_UP 0xE170 +#define ID_SCROLL_DOWN 0xE171 +#define ID_SCROLL_PAGE_UP 0xE172 +#define ID_SCROLL_PAGE_DOWN 0xE173 +#define ID_SCROLL_TOP 0xE174 +#define ID_SCROLL_BOTTOM 0xE175 +#define ID_SCROLL_LEFT 0xE176 +#define ID_SCROLL_RIGHT 0xE177 +#define ID_SCROLL_PAGE_LEFT 0xE178 +#define ID_SCROLL_PAGE_RIGHT 0xE179 +#define ID_SCROLL_ALL_LEFT 0xE17A +#define ID_SCROLL_ALL_RIGHT 0xE17B + +// OLE commands +#define ID_OLE_INSERT_NEW 0xE200 +#define ID_OLE_EDIT_LINKS 0xE201 +#define ID_OLE_EDIT_CONVERT 0xE202 +#define ID_OLE_EDIT_CHANGE_ICON 0xE203 +#define ID_OLE_EDIT_PROPERTIES 0xE204 +#define ID_OLE_VERB_FIRST 0xE210 // range - 16 max +#ifndef RC_INVOKED // code only + #define ID_OLE_VERB_LAST 0xE21F +#endif // !RC_INVOKED + +// View commands (same number used as IDW used for toolbar and status bar) 
+#define ID_VIEW_TOOLBAR 0xE800 +#define ID_VIEW_STATUS_BAR 0xE801 +#define ID_VIEW_REFRESH 0xE803 +#define ID_VIEW_RIBBON 0xE804 + +/////////////////////////////////////////////////////////////////////////////// +// Standard control IDs + +#ifdef IDC_STATIC + #undef IDC_STATIC +#endif // IDC_STATIC +#define IDC_STATIC (-1) // all static controls + +/////////////////////////////////////////////////////////////////////////////// +// Standard string error/warnings + +// idle status bar message +#define ATL_IDS_IDLEMESSAGE 0xE001 + +#ifndef RC_INVOKED // code only + #define ATL_IDS_SCFIRST 0xEF00 +#endif // !RC_INVOKED + +#define ATL_IDS_SCSIZE 0xEF00 +#define ATL_IDS_SCMOVE 0xEF01 +#define ATL_IDS_SCMINIMIZE 0xEF02 +#define ATL_IDS_SCMAXIMIZE 0xEF03 +#define ATL_IDS_SCNEXTWINDOW 0xEF04 +#define ATL_IDS_SCPREVWINDOW 0xEF05 +#define ATL_IDS_SCCLOSE 0xEF06 +#define ATL_IDS_SCRESTORE 0xEF12 +#define ATL_IDS_SCTASKLIST 0xEF13 + +#define ATL_IDS_MDICHILD 0xEF1F +#define ATL_IDS_MRU_FILE 0xEFDA + +/////////////////////////////////////////////////////////////////////////////// +// Misc. control IDs + +// Property Sheet control id's (determined with Spy++) +#define ID_APPLY_NOW 0x3021 +#define ID_WIZBACK 0x3023 +#define ID_WIZNEXT 0x3024 +#define ID_WIZFINISH 0x3025 +#define ATL_IDC_TAB_CONTROL 0x3020 + +#endif // __ATLRES_H__ diff --git a/Examples/WhisperDesktop/Utils/WTL/atluser.h b/Examples/WhisperDesktop/Utils/WTL/atluser.h new file mode 100644 index 0000000..9f8e0f4 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/WTL/atluser.h @@ -0,0 +1,1231 @@ +// Windows Template Library - WTL version 10.0 +// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved. +// +// This file is a part of the Windows Template Library. +// The use and distribution terms for this software are covered by the +// Microsoft Public License (http://opensource.org/licenses/MS-PL) +// which can be found in the file MS-PL.txt at the root folder. 
+ +#ifndef __ATLUSER_H__ +#define __ATLUSER_H__ + +#pragma once + +#ifndef __ATLAPP_H__ + #error atluser.h requires atlapp.h to be included first +#endif + + +/////////////////////////////////////////////////////////////////////////////// +// Classes in this file: +// +// CMenuItemInfo +// CMenuT<t_bManaged> +// CAcceleratorT<t_bManaged> +// CIconT<t_bManaged> +// CCursorT<t_bManaged> +// CResource +// +// Global functions: +// AtlMessageBox() +// +// AtlLoadAccelerators() +// AtlLoadMenu() +// AtlLoadBitmap() +// AtlLoadSysBitmap() +// AtlLoadCursor() +// AtlLoadSysCursor() +// AtlLoadIcon() +// AtlLoadSysIcon() +// AtlLoadBitmapImage() +// AtlLoadCursorImage() +// AtlLoadIconImage() +// AtlLoadSysBitmapImage() +// AtlLoadSysCursorImage() +// AtlLoadSysIconImage() +// AtlLoadString() + + +namespace WTL +{ + +/////////////////////////////////////////////////////////////////////////////// +// AtlMessageBox - accepts both memory and resource based strings + +inline int AtlMessageBox(HWND hWndOwner, ATL::_U_STRINGorID message, ATL::_U_STRINGorID title = (LPCTSTR)NULL, UINT uType = MB_OK | MB_ICONINFORMATION) +{ + ATLASSERT((hWndOwner == NULL) || ::IsWindow(hWndOwner)); + + LPTSTR lpstrMessage = NULL; + if(IS_INTRESOURCE(message.m_lpstr)) + { + for(int nLen = 256; ; nLen *= 2) + { + ATLTRY(lpstrMessage = new TCHAR[nLen]); + if(lpstrMessage == NULL) + { + ATLASSERT(FALSE); + return 0; + } + int nRes = ::LoadString(ModuleHelper::GetResourceInstance(), LOWORD(message.m_lpstr), lpstrMessage, nLen); + if(nRes < nLen - 1) + break; + delete [] lpstrMessage; + lpstrMessage = NULL; + } + + message.m_lpstr = lpstrMessage; + } + + LPTSTR lpstrTitle = NULL; + if(IS_INTRESOURCE(title.m_lpstr) && (LOWORD(title.m_lpstr) != 0)) + { + for(int nLen = 256; ; nLen *= 2) + { + ATLTRY(lpstrTitle = new TCHAR[nLen]); + if(lpstrTitle == NULL) + { + ATLASSERT(FALSE); + return 0; + } + int nRes = ::LoadString(ModuleHelper::GetResourceInstance(), LOWORD(title.m_lpstr), lpstrTitle, nLen); + 
if(nRes < nLen - 1) + break; + delete [] lpstrTitle; + lpstrTitle = NULL; + } + + title.m_lpstr = lpstrTitle; + } + + int nRet = ::MessageBox(hWndOwner, message.m_lpstr, title.m_lpstr, uType); + + delete [] lpstrMessage; + delete [] lpstrTitle; + + return nRet; +} + + +/////////////////////////////////////////////////////////////////////////////// +// CMenu + +class CMenuItemInfo : public MENUITEMINFO +{ +public: + CMenuItemInfo() + { + memset(this, 0, sizeof(MENUITEMINFO)); + cbSize = sizeof(MENUITEMINFO); + } +}; + + +// forward declarations +template <bool t_bManaged> class CMenuT; +typedef CMenuT<false> CMenuHandle; +typedef CMenuT<true> CMenu; + + +template <bool t_bManaged> +class CMenuT +{ +public: +// Data members + HMENU m_hMenu; + +// Constructor/destructor/operators + CMenuT(HMENU hMenu = NULL) : m_hMenu(hMenu) + { } + + ~CMenuT() + { + if(t_bManaged && (m_hMenu != NULL)) + DestroyMenu(); + } + + CMenuT<t_bManaged>& operator =(HMENU hMenu) + { + Attach(hMenu); + return *this; + } + + void Attach(HMENU hMenuNew) + { + ATLASSERT(::IsMenu(hMenuNew)); + if(t_bManaged && (m_hMenu != NULL) && (m_hMenu != hMenuNew)) + ::DestroyMenu(m_hMenu); + m_hMenu = hMenuNew; + } + + HMENU Detach() + { + HMENU hMenu = m_hMenu; + m_hMenu = NULL; + return hMenu; + } + + operator HMENU() const { return m_hMenu; } + + bool IsNull() const { return (m_hMenu == NULL); } + + BOOL IsMenu() const + { + return ::IsMenu(m_hMenu); + } + +// Create/destroy methods + BOOL CreateMenu() + { + ATLASSERT(m_hMenu == NULL); + m_hMenu = ::CreateMenu(); + return (m_hMenu != NULL) ? TRUE : FALSE; + } + + BOOL CreatePopupMenu() + { + ATLASSERT(m_hMenu == NULL); + m_hMenu = ::CreatePopupMenu(); + return (m_hMenu != NULL) ? TRUE : FALSE; + } + + BOOL LoadMenu(ATL::_U_STRINGorID menu) + { + ATLASSERT(m_hMenu == NULL); + m_hMenu = ::LoadMenu(ModuleHelper::GetResourceInstance(), menu.m_lpstr); + return (m_hMenu != NULL) ? 
TRUE : FALSE; + } + + BOOL LoadMenuIndirect(const void* lpMenuTemplate) + { + ATLASSERT(m_hMenu == NULL); + m_hMenu = ::LoadMenuIndirect(lpMenuTemplate); + return (m_hMenu != NULL) ? TRUE : FALSE; + } + + BOOL DestroyMenu() + { + if (m_hMenu == NULL) + return FALSE; + BOOL bRet = ::DestroyMenu(m_hMenu); + if(bRet) + m_hMenu = NULL; + return bRet; + } + +// Menu Operations + BOOL DeleteMenu(UINT nPosition, UINT nFlags) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::DeleteMenu(m_hMenu, nPosition, nFlags); + } + + BOOL TrackPopupMenu(UINT nFlags, int x, int y, HWND hWnd, LPCRECT lpRect = NULL) + { + ATLASSERT(::IsMenu(m_hMenu)); + x = _FixTrackMenuPopupX(x, y); + return ::TrackPopupMenu(m_hMenu, nFlags, x, y, 0, hWnd, lpRect); + } + + BOOL TrackPopupMenuEx(UINT uFlags, int x, int y, HWND hWnd, LPTPMPARAMS lptpm = NULL) + { + ATLASSERT(::IsMenu(m_hMenu)); + x = _FixTrackMenuPopupX(x, y); + return ::TrackPopupMenuEx(m_hMenu, uFlags, x, y, hWnd, lptpm); + } + + // helper that fixes popup menu X position when it's off-screen + static int _FixTrackMenuPopupX(int x, int y) + { + POINT pt = { x, y }; + HMONITOR hMonitor = ::MonitorFromPoint(pt, MONITOR_DEFAULTTONULL); + if(hMonitor == NULL) + { + HMONITOR hMonitorNear = ::MonitorFromPoint(pt, MONITOR_DEFAULTTONEAREST); + if(hMonitorNear != NULL) + { + MONITORINFO mi = { sizeof(MONITORINFO) }; + if(::GetMonitorInfo(hMonitorNear, &mi) != FALSE) + { + if(x < mi.rcWork.left) + x = mi.rcWork.left; + else if(x > mi.rcWork.right) + x = mi.rcWork.right; + } + } + } + + return x; + } + + BOOL GetMenuInfo(LPMENUINFO lpMenuInfo) const + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::GetMenuInfo(m_hMenu, lpMenuInfo); + } + + BOOL SetMenuInfo(LPCMENUINFO lpMenuInfo) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::SetMenuInfo(m_hMenu, lpMenuInfo); + } + +// Menu Item Operations + BOOL AppendMenu(UINT nFlags, UINT_PTR nIDNewItem = 0, LPCTSTR lpszNewItem = NULL) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::AppendMenu(m_hMenu, nFlags, 
nIDNewItem, lpszNewItem); + } + + BOOL AppendMenu(UINT nFlags, HMENU hSubMenu, LPCTSTR lpszNewItem) + { + ATLASSERT(::IsMenu(m_hMenu)); + ATLASSERT(::IsMenu(hSubMenu)); + return ::AppendMenu(m_hMenu, nFlags | MF_POPUP, (UINT_PTR)hSubMenu, lpszNewItem); + } + + BOOL AppendMenu(UINT nFlags, UINT_PTR nIDNewItem, HBITMAP hBmp) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::AppendMenu(m_hMenu, nFlags | MF_BITMAP, nIDNewItem, (LPCTSTR)hBmp); + } + + BOOL AppendMenu(UINT nFlags, HMENU hSubMenu, HBITMAP hBmp) + { + ATLASSERT(::IsMenu(m_hMenu)); + ATLASSERT(::IsMenu(hSubMenu)); + return ::AppendMenu(m_hMenu, nFlags | (MF_BITMAP | MF_POPUP), (UINT_PTR)hSubMenu, (LPCTSTR)hBmp); + } + + UINT CheckMenuItem(UINT nIDCheckItem, UINT nCheck) + { + ATLASSERT(::IsMenu(m_hMenu)); + return (UINT)::CheckMenuItem(m_hMenu, nIDCheckItem, nCheck); + } + + UINT EnableMenuItem(UINT nIDEnableItem, UINT nEnable) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::EnableMenuItem(m_hMenu, nIDEnableItem, nEnable); + } + + BOOL HiliteMenuItem(HWND hWnd, UINT uIDHiliteItem, UINT uHilite) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::HiliteMenuItem(hWnd, m_hMenu, uIDHiliteItem, uHilite); + } + + int GetMenuItemCount() const + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::GetMenuItemCount(m_hMenu); + } + + UINT GetMenuItemID(int nPos) const + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::GetMenuItemID(m_hMenu, nPos); + } + + UINT GetMenuState(UINT nID, UINT nFlags) const + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::GetMenuState(m_hMenu, nID, nFlags); + } + + int GetMenuString(UINT nIDItem, LPTSTR lpString, int nMaxCount, UINT nFlags) const + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::GetMenuString(m_hMenu, nIDItem, lpString, nMaxCount, nFlags); + } + + int GetMenuStringLen(UINT nIDItem, UINT nFlags) const + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::GetMenuString(m_hMenu, nIDItem, NULL, 0, nFlags); + } + + BOOL GetMenuString(UINT nIDItem, BSTR& bstrText, UINT nFlags) const + { + USES_CONVERSION; + 
ATLASSERT(::IsMenu(m_hMenu)); + ATLASSERT(bstrText == NULL); + + int nLen = GetMenuStringLen(nIDItem, nFlags); + if(nLen == 0) + { + bstrText = ::SysAllocString(OLESTR("")); + return (bstrText != NULL) ? TRUE : FALSE; + } + + nLen++; // increment to include terminating NULL char + ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff; + LPTSTR lpszText = buff.Allocate(nLen); + if(lpszText == NULL) + return FALSE; + + if(!GetMenuString(nIDItem, lpszText, nLen, nFlags)) + return FALSE; + + bstrText = ::SysAllocString(T2OLE(lpszText)); + return (bstrText != NULL) ? TRUE : FALSE; + } + +#ifdef __ATLSTR_H__ + int GetMenuString(UINT nIDItem, ATL::CString& strText, UINT nFlags) const + { + ATLASSERT(::IsMenu(m_hMenu)); + + int nLen = GetMenuStringLen(nIDItem, nFlags); + if(nLen == 0) + return 0; + + nLen++; // increment to include terminating NULL char + LPTSTR lpstr = strText.GetBufferSetLength(nLen); + if(lpstr == NULL) + return 0; + int nRet = GetMenuString(nIDItem, lpstr, nLen, nFlags); + strText.ReleaseBuffer(); + return nRet; + } +#endif // __ATLSTR_H__ + + CMenuHandle GetSubMenu(int nPos) const + { + ATLASSERT(::IsMenu(m_hMenu)); + return CMenuHandle(::GetSubMenu(m_hMenu, nPos)); + } + + BOOL InsertMenu(UINT nPosition, UINT nFlags, UINT_PTR nIDNewItem = 0, LPCTSTR lpszNewItem = NULL) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::InsertMenu(m_hMenu, nPosition, nFlags, nIDNewItem, lpszNewItem); + } + + BOOL InsertMenu(UINT nPosition, UINT nFlags, HMENU hSubMenu, LPCTSTR lpszNewItem) + { + ATLASSERT(::IsMenu(m_hMenu)); + ATLASSERT(::IsMenu(hSubMenu)); + return ::InsertMenu(m_hMenu, nPosition, nFlags | MF_POPUP, (UINT_PTR)hSubMenu, lpszNewItem); + } + + BOOL InsertMenu(UINT nPosition, UINT nFlags, UINT_PTR nIDNewItem, HBITMAP hBmp) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::InsertMenu(m_hMenu, nPosition, nFlags | MF_BITMAP, nIDNewItem, (LPCTSTR)hBmp); + } + + BOOL InsertMenu(UINT nPosition, UINT nFlags, HMENU hSubMenu, HBITMAP hBmp) + { + 
ATLASSERT(::IsMenu(m_hMenu)); + ATLASSERT(::IsMenu(hSubMenu)); + return ::InsertMenu(m_hMenu, nPosition, nFlags | (MF_BITMAP | MF_POPUP), (UINT_PTR)hSubMenu, (LPCTSTR)hBmp); + } + + BOOL ModifyMenu(UINT nPosition, UINT nFlags, UINT_PTR nIDNewItem = 0, LPCTSTR lpszNewItem = NULL) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::ModifyMenu(m_hMenu, nPosition, nFlags, nIDNewItem, lpszNewItem); + } + + BOOL ModifyMenu(UINT nPosition, UINT nFlags, HMENU hSubMenu, LPCTSTR lpszNewItem) + { + ATLASSERT(::IsMenu(m_hMenu)); + ATLASSERT(::IsMenu(hSubMenu)); + return ::ModifyMenu(m_hMenu, nPosition, nFlags | MF_POPUP, (UINT_PTR)hSubMenu, lpszNewItem); + } + + BOOL ModifyMenu(UINT nPosition, UINT nFlags, UINT_PTR nIDNewItem, HBITMAP hBmp) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::ModifyMenu(m_hMenu, nPosition, nFlags | MF_BITMAP, nIDNewItem, (LPCTSTR)hBmp); + } + + BOOL ModifyMenu(UINT nPosition, UINT nFlags, HMENU hSubMenu, HBITMAP hBmp) + { + ATLASSERT(::IsMenu(m_hMenu)); + ATLASSERT(::IsMenu(hSubMenu)); + return ::ModifyMenu(m_hMenu, nPosition, nFlags | (MF_BITMAP | MF_POPUP), (UINT_PTR)hSubMenu, (LPCTSTR)hBmp); + } + + BOOL RemoveMenu(UINT nPosition, UINT nFlags) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::RemoveMenu(m_hMenu, nPosition, nFlags); + } + + BOOL SetMenuItemBitmaps(UINT nPosition, UINT nFlags, HBITMAP hBmpUnchecked, HBITMAP hBmpChecked) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::SetMenuItemBitmaps(m_hMenu, nPosition, nFlags, hBmpUnchecked, hBmpChecked); + } + + BOOL CheckMenuRadioItem(UINT nIDFirst, UINT nIDLast, UINT nIDItem, UINT nFlags) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::CheckMenuRadioItem(m_hMenu, nIDFirst, nIDLast, nIDItem, nFlags); + } + + BOOL GetMenuItemInfo(UINT uItem, BOOL bByPosition, LPMENUITEMINFO lpmii) const + { + ATLASSERT(::IsMenu(m_hMenu)); + return (BOOL)::GetMenuItemInfo(m_hMenu, uItem, bByPosition, lpmii); + } + + BOOL SetMenuItemInfo(UINT uItem, BOOL bByPosition, LPMENUITEMINFO lpmii) + { + 
ATLASSERT(::IsMenu(m_hMenu)); + return (BOOL)::SetMenuItemInfo(m_hMenu, uItem, bByPosition, lpmii); + } + + BOOL InsertMenuItem(UINT uItem, BOOL bByPosition, LPMENUITEMINFO lpmii) + { + ATLASSERT(::IsMenu(m_hMenu)); + return (BOOL)::InsertMenuItem(m_hMenu, uItem, bByPosition, lpmii); + } + + UINT GetMenuDefaultItem(BOOL bByPosition = FALSE, UINT uFlags = 0U) const + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::GetMenuDefaultItem(m_hMenu, (UINT)bByPosition, uFlags); + } + + BOOL SetMenuDefaultItem(UINT uItem = (UINT)-1, BOOL bByPosition = FALSE) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::SetMenuDefaultItem(m_hMenu, uItem, (UINT)bByPosition); + } + + BOOL GetMenuItemRect(HWND hWnd, UINT uItem, LPRECT lprcItem) const + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::GetMenuItemRect(hWnd, m_hMenu, uItem, lprcItem); + } + + int MenuItemFromPoint(HWND hWnd, POINT point) const + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::MenuItemFromPoint(hWnd, m_hMenu, point); + } + +// Context Help Functions + BOOL SetMenuContextHelpId(DWORD dwContextHelpId) + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::SetMenuContextHelpId(m_hMenu, dwContextHelpId); + } + + DWORD GetMenuContextHelpId() const + { + ATLASSERT(::IsMenu(m_hMenu)); + return ::GetMenuContextHelpId(m_hMenu); + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// CAccelerator + +template <bool t_bManaged> +class CAcceleratorT +{ +public: + HACCEL m_hAccel; + +// Constructor/destructor/operators + CAcceleratorT(HACCEL hAccel = NULL) : m_hAccel(hAccel) + { } + + ~CAcceleratorT() + { + if(t_bManaged && (m_hAccel != NULL)) + ::DestroyAcceleratorTable(m_hAccel); + } + + CAcceleratorT<t_bManaged>& operator =(HACCEL hAccel) + { + Attach(hAccel); + return *this; + } + + void Attach(HACCEL hAccel) + { + if(t_bManaged && (m_hAccel != NULL)) + ::DestroyAcceleratorTable(m_hAccel); + m_hAccel = hAccel; + } + + HACCEL Detach() + { + HACCEL hAccel = m_hAccel; + m_hAccel = NULL; + return hAccel; + 
} + + operator HACCEL() const { return m_hAccel; } + + bool IsNull() const { return m_hAccel == NULL; } + +// Create/destroy methods + HACCEL LoadAccelerators(ATL::_U_STRINGorID accel) + { + ATLASSERT(m_hAccel == NULL); + m_hAccel = ::LoadAccelerators(ModuleHelper::GetResourceInstance(), accel.m_lpstr); + return m_hAccel; + } + + HACCEL CreateAcceleratorTable(LPACCEL pAccel, int cEntries) + { + ATLASSERT(m_hAccel == NULL); + ATLASSERT(pAccel != NULL); + m_hAccel = ::CreateAcceleratorTable(pAccel, cEntries); + return m_hAccel; + } + + void DestroyObject() + { + if(m_hAccel != NULL) + { + ::DestroyAcceleratorTable(m_hAccel); + m_hAccel = NULL; + } + } + +// Operations + int CopyAcceleratorTable(LPACCEL lpAccelDst, int cEntries) + { + ATLASSERT(m_hAccel != NULL); + ATLASSERT(lpAccelDst != NULL); + return ::CopyAcceleratorTable(m_hAccel, lpAccelDst, cEntries); + } + + int GetEntriesCount() const + { + ATLASSERT(m_hAccel != NULL); + return ::CopyAcceleratorTable(m_hAccel, NULL, 0); + } + + BOOL TranslateAccelerator(HWND hWnd, LPMSG pMsg) + { + ATLASSERT(m_hAccel != NULL); + ATLASSERT(::IsWindow(hWnd)); + ATLASSERT(pMsg != NULL); + return ::TranslateAccelerator(hWnd, m_hAccel, pMsg); + } +}; + +typedef CAcceleratorT<false> CAcceleratorHandle; +typedef CAcceleratorT<true> CAccelerator; + + +/////////////////////////////////////////////////////////////////////////////// +// CIcon + +template <bool t_bManaged> +class CIconT +{ +public: + HICON m_hIcon; + +// Constructor/destructor/operators + CIconT(HICON hIcon = NULL) : m_hIcon(hIcon) + { } + + ~CIconT() + { + if(t_bManaged && (m_hIcon != NULL)) + ::DestroyIcon(m_hIcon); + } + + CIconT<t_bManaged>& operator =(HICON hIcon) + { + Attach(hIcon); + return *this; + } + + void Attach(HICON hIcon) + { + if(t_bManaged && (m_hIcon != NULL)) + ::DestroyIcon(m_hIcon); + m_hIcon = hIcon; + } + + HICON Detach() + { + HICON hIcon = m_hIcon; + m_hIcon = NULL; + return hIcon; + } + + operator HICON() const { return m_hIcon; } + + bool 
IsNull() const { return m_hIcon == NULL; } + +// Create/destroy methods + HICON LoadIcon(ATL::_U_STRINGorID icon) + { + ATLASSERT(m_hIcon == NULL); + m_hIcon = ::LoadIcon(ModuleHelper::GetResourceInstance(), icon.m_lpstr); + return m_hIcon; + } + + HICON LoadIcon(ATL::_U_STRINGorID icon, int cxDesired, int cyDesired, UINT fuLoad = 0) + { + ATLASSERT(m_hIcon == NULL); + m_hIcon = (HICON) ::LoadImage(ModuleHelper::GetResourceInstance(), icon.m_lpstr, IMAGE_ICON, cxDesired, cyDesired, fuLoad); + return m_hIcon; + } + + HICON LoadOEMIcon(LPCTSTR lpstrIconName) + { + ATLASSERT(m_hIcon == NULL); + ATLASSERT(IsOEMIcon(lpstrIconName)); + m_hIcon = ::LoadIcon(NULL, lpstrIconName); + return m_hIcon; + } + + HICON CreateIcon(int nWidth, int nHeight, BYTE cPlanes, BYTE cBitsPixel, CONST BYTE* lpbANDbits, CONST BYTE *lpbXORbits) + { + ATLASSERT(m_hIcon == NULL); + ATLASSERT(lpbANDbits != NULL); + ATLASSERT(lpbXORbits != NULL); + m_hIcon = ::CreateIcon(ModuleHelper::GetResourceInstance(), nWidth, nHeight, cPlanes, cBitsPixel, lpbANDbits, lpbXORbits); + return m_hIcon; + } + + HICON CreateIconFromResource(PBYTE pBits, DWORD dwResSize, DWORD dwVersion = 0x00030000) + { + ATLASSERT(m_hIcon == NULL); + ATLASSERT(pBits != NULL); + m_hIcon = ::CreateIconFromResource(pBits, dwResSize, TRUE, dwVersion); + return m_hIcon; + } + + HICON CreateIconFromResourceEx(PBYTE pbBits, DWORD cbBits, DWORD dwVersion = 0x00030000, int cxDesired = 0, int cyDesired = 0, UINT uFlags = LR_DEFAULTCOLOR) + { + ATLASSERT(m_hIcon == NULL); + ATLASSERT(pbBits != NULL); + ATLASSERT(cbBits > 0); + m_hIcon = ::CreateIconFromResourceEx(pbBits, cbBits, TRUE, dwVersion, cxDesired, cyDesired, uFlags); + return m_hIcon; + } + + HICON CreateIconIndirect(PICONINFO pIconInfo) + { + ATLASSERT(m_hIcon == NULL); + ATLASSERT(pIconInfo != NULL); + m_hIcon = ::CreateIconIndirect(pIconInfo); + return m_hIcon; + } + + HICON ExtractIcon(LPCTSTR lpszExeFileName, UINT nIconIndex) + { + ATLASSERT(m_hIcon == NULL); + 
ATLASSERT(lpszExeFileName != NULL); + m_hIcon = ::ExtractIcon(ModuleHelper::GetModuleInstance(), lpszExeFileName, nIconIndex); + return m_hIcon; + } + + HICON ExtractAssociatedIcon(HINSTANCE hInst, LPTSTR lpIconPath, LPWORD lpiIcon) + { + ATLASSERT(m_hIcon == NULL); + ATLASSERT(lpIconPath != NULL); + ATLASSERT(lpiIcon != NULL); + m_hIcon = ::ExtractAssociatedIcon(hInst, lpIconPath, lpiIcon); + return m_hIcon; + } + + BOOL DestroyIcon() + { + ATLASSERT(m_hIcon != NULL); + BOOL bRet = ::DestroyIcon(m_hIcon); + if(bRet != FALSE) + m_hIcon = NULL; + return bRet; + } + +// Operations + HICON CopyIcon() + { + ATLASSERT(m_hIcon != NULL); + return ::CopyIcon(m_hIcon); + } + + HICON DuplicateIcon() + { + ATLASSERT(m_hIcon != NULL); + return ::DuplicateIcon(NULL, m_hIcon); + } + + BOOL DrawIcon(HDC hDC, int x, int y) + { + ATLASSERT(m_hIcon != NULL); + return ::DrawIcon(hDC, x, y, m_hIcon); + } + + BOOL DrawIcon(HDC hDC, POINT pt) + { + ATLASSERT(m_hIcon != NULL); + return ::DrawIcon(hDC, pt.x, pt.y, m_hIcon); + } + + BOOL DrawIconEx(HDC hDC, int x, int y, int cxWidth, int cyWidth, UINT uStepIfAniCur = 0, HBRUSH hbrFlickerFreeDraw = NULL, UINT uFlags = DI_NORMAL) + { + ATLASSERT(m_hIcon != NULL); + return ::DrawIconEx(hDC, x, y, m_hIcon, cxWidth, cyWidth, uStepIfAniCur, hbrFlickerFreeDraw, uFlags); + } + + BOOL DrawIconEx(HDC hDC, POINT pt, SIZE size, UINT uStepIfAniCur = 0, HBRUSH hbrFlickerFreeDraw = NULL, UINT uFlags = DI_NORMAL) + { + ATLASSERT(m_hIcon != NULL); + return ::DrawIconEx(hDC, pt.x, pt.y, m_hIcon, size.cx, size.cy, uStepIfAniCur, hbrFlickerFreeDraw, uFlags); + } + + BOOL GetIconInfo(PICONINFO pIconInfo) const + { + ATLASSERT(m_hIcon != NULL); + ATLASSERT(pIconInfo != NULL); + return ::GetIconInfo(m_hIcon, pIconInfo); + } + +#if (_WIN32_WINNT >= 0x0600) + BOOL GetIconInfoEx(PICONINFOEX pIconInfo) const + { + ATLASSERT(m_hIcon != NULL); + ATLASSERT(pIconInfo != NULL); + return ::GetIconInfoEx(m_hIcon, pIconInfo); + } +#endif // (_WIN32_WINNT >= 0x0600) + +#if 
defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) + HRESULT LoadIconMetric(ATL::_U_STRINGorID icon, int lims) + { + ATLASSERT(m_hIcon == NULL); + USES_CONVERSION; + return ::LoadIconMetric(ModuleHelper::GetResourceInstance(), T2CW(icon.m_lpstr), lims, &m_hIcon); + } + + HRESULT LoadIconWithScaleDown(ATL::_U_STRINGorID icon, int cx, int cy) + { + ATLASSERT(m_hIcon == NULL); + USES_CONVERSION; + return ::LoadIconWithScaleDown(ModuleHelper::GetResourceInstance(), T2CW(icon.m_lpstr), cx, cy, &m_hIcon); + } + + HRESULT LoadOEMIconMetric(LPCTSTR lpstrIconName, int lims) + { + ATLASSERT(m_hIcon == NULL); + ATLASSERT(IsOEMIcon(lpstrIconName)); + return ::LoadIconMetric(NULL, (LPCWSTR)lpstrIconName, lims, &m_hIcon); + } + + HRESULT LoadOEMIconWithScaleDown(LPCTSTR lpstrIconName, int cx, int cy) + { + ATLASSERT(m_hIcon == NULL); + ATLASSERT(IsOEMIcon(lpstrIconName)); + USES_CONVERSION; + return ::LoadIconWithScaleDown(NULL, (LPCWSTR)lpstrIconName, cx, cy, &m_hIcon); + } +#endif // defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) + + // Helper + static bool IsOEMIcon(LPCTSTR lpstrIconName) + { +#if (WINVER >= 0x0600) + return ((lpstrIconName == IDI_APPLICATION) || (lpstrIconName == IDI_ASTERISK) || (lpstrIconName == IDI_EXCLAMATION) || + (lpstrIconName == IDI_HAND) || (lpstrIconName == IDI_QUESTION) || (lpstrIconName == IDI_WINLOGO) || + (lpstrIconName == IDI_SHIELD)); +#else // !(WINVER >= 0x0600) + return ((lpstrIconName == IDI_APPLICATION) || (lpstrIconName == IDI_ASTERISK) || (lpstrIconName == IDI_EXCLAMATION) || + (lpstrIconName == IDI_HAND) || (lpstrIconName == IDI_QUESTION) || (lpstrIconName == IDI_WINLOGO)); +#endif // !(WINVER >= 0x0600) + } +}; + +typedef CIconT<false> CIconHandle; +typedef CIconT<true> CIcon; + + +/////////////////////////////////////////////////////////////////////////////// +// CCursor + +// protect template member from a winuser.h macro +#ifdef CopyCursor + #undef CopyCursor +#endif + +template <bool t_bManaged> +class 
CCursorT +{ +public: + HCURSOR m_hCursor; + +// Constructor/destructor/operators + CCursorT(HCURSOR hCursor = NULL) : m_hCursor(hCursor) + { } + + ~CCursorT() + { + if(t_bManaged && (m_hCursor != NULL)) + DestroyCursor(); + } + + CCursorT<t_bManaged>& operator =(HCURSOR hCursor) + { + Attach(hCursor); + return *this; + } + + void Attach(HCURSOR hCursor) + { + if(t_bManaged && (m_hCursor != NULL)) + DestroyCursor(); + m_hCursor = hCursor; + } + + HCURSOR Detach() + { + HCURSOR hCursor = m_hCursor; + m_hCursor = NULL; + return hCursor; + } + + operator HCURSOR() const { return m_hCursor; } + + bool IsNull() const { return m_hCursor == NULL; } + +// Create/destroy methods + HCURSOR LoadCursor(ATL::_U_STRINGorID cursor) + { + ATLASSERT(m_hCursor == NULL); + m_hCursor = ::LoadCursor(ModuleHelper::GetResourceInstance(), cursor.m_lpstr); + return m_hCursor; + } + + HCURSOR LoadSysCursor(LPCTSTR lpstrCursorName) + { + ATLASSERT(m_hCursor == NULL); + ATLASSERT((lpstrCursorName == IDC_ARROW) || (lpstrCursorName == IDC_IBEAM) || (lpstrCursorName == IDC_WAIT) || + (lpstrCursorName == IDC_CROSS) || (lpstrCursorName == IDC_UPARROW) || (lpstrCursorName == IDC_SIZE) || + (lpstrCursorName == IDC_ICON) || (lpstrCursorName == IDC_SIZENWSE) || (lpstrCursorName == IDC_SIZENESW) || + (lpstrCursorName == IDC_SIZEWE) || (lpstrCursorName == IDC_SIZENS) || (lpstrCursorName == IDC_SIZEALL) || + (lpstrCursorName == IDC_NO) || (lpstrCursorName == IDC_APPSTARTING) || (lpstrCursorName == IDC_HELP) || + (lpstrCursorName == IDC_HAND)); + m_hCursor = ::LoadCursor(NULL, lpstrCursorName); + return m_hCursor; + } + + // deprecated + HCURSOR LoadOEMCursor(LPCTSTR lpstrCursorName) + { + return LoadSysCursor(lpstrCursorName); + } + + HCURSOR LoadCursor(ATL::_U_STRINGorID cursor, int cxDesired, int cyDesired, UINT fuLoad = 0) + { + ATLASSERT(m_hCursor == NULL); + m_hCursor = (HCURSOR) ::LoadImage(ModuleHelper::GetResourceInstance(), cursor.m_lpstr, IMAGE_CURSOR, cxDesired, cyDesired, fuLoad); + return 
m_hCursor; + } + + HCURSOR LoadCursorFromFile(LPCTSTR pstrFilename) + { + ATLASSERT(m_hCursor == NULL); + ATLASSERT(pstrFilename != NULL); + m_hCursor = ::LoadCursorFromFile(pstrFilename); + return m_hCursor; + } + + HCURSOR CreateCursor(int xHotSpot, int yHotSpot, int nWidth, int nHeight, CONST VOID *pvANDPlane, CONST VOID *pvXORPlane) + { + ATLASSERT(m_hCursor == NULL); + m_hCursor = ::CreateCursor(ModuleHelper::GetResourceInstance(), xHotSpot, yHotSpot, nWidth, nHeight, pvANDPlane, pvXORPlane); + return m_hCursor; + } + + HCURSOR CreateCursorFromResource(PBYTE pBits, DWORD dwResSize, DWORD dwVersion = 0x00030000) + { + ATLASSERT(m_hCursor == NULL); + ATLASSERT(pBits != NULL); + m_hCursor = (HCURSOR)::CreateIconFromResource(pBits, dwResSize, FALSE, dwVersion); + return m_hCursor; + } + + HCURSOR CreateCursorFromResourceEx(PBYTE pbBits, DWORD cbBits, DWORD dwVersion = 0x00030000, int cxDesired = 0, int cyDesired = 0, UINT uFlags = LR_DEFAULTCOLOR) + { + ATLASSERT(m_hCursor == NULL); + ATLASSERT(pbBits != NULL); + ATLASSERT(cbBits > 0); + m_hCursor = (HCURSOR)::CreateIconFromResourceEx(pbBits, cbBits, FALSE, dwVersion, cxDesired, cyDesired, uFlags); + return m_hCursor; + } + + BOOL DestroyCursor() + { + ATLASSERT(m_hCursor != NULL); + BOOL bRet = ::DestroyCursor(m_hCursor); + if(bRet != FALSE) + m_hCursor = NULL; + return bRet; + } + +// Operations + HCURSOR CopyCursor() + { + ATLASSERT(m_hCursor != NULL); + return (HCURSOR)::CopyIcon((HICON)m_hCursor); + } + + BOOL GetCursorInfo(LPCURSORINFO pCursorInfo) + { + ATLASSERT(m_hCursor != NULL); + ATLASSERT(pCursorInfo != NULL); + return ::GetCursorInfo(pCursorInfo); + } +}; + +typedef CCursorT<false> CCursorHandle; +typedef CCursorT<true> CCursor; + + +/////////////////////////////////////////////////////////////////////////////// +// CResource - Wraps a generic Windows resource. +// Use it with custom resource types other than the +// standard RT_CURSOR, RT_BITMAP, etc. 
+ +class CResource +{ +public: + HGLOBAL m_hGlobal; + HRSRC m_hResource; + +// Constructor/destructor + CResource() : m_hGlobal(NULL), m_hResource(NULL) + { } + + ~CResource() + { + Release(); + } + +// Load methods + bool Load(ATL::_U_STRINGorID Type, ATL::_U_STRINGorID ID) + { + ATLASSERT(m_hResource == NULL); + ATLASSERT(m_hGlobal == NULL); + + m_hResource = ::FindResource(ModuleHelper::GetResourceInstance(), ID.m_lpstr, Type.m_lpstr); + if(m_hResource == NULL) + return false; + + m_hGlobal = ::LoadResource(ModuleHelper::GetResourceInstance(), m_hResource); + if(m_hGlobal == NULL) + { + m_hResource = NULL; + return false; + } + + return true; + } + + bool LoadEx(ATL::_U_STRINGorID ID, ATL::_U_STRINGorID Type, WORD wLanguage) + { + ATLASSERT(m_hResource == NULL); + ATLASSERT(m_hGlobal == NULL); + + m_hResource = ::FindResourceEx(ModuleHelper::GetResourceInstance(), Type.m_lpstr, ID.m_lpstr, wLanguage); + if(m_hResource == NULL) + return false; + + m_hGlobal = ::LoadResource(ModuleHelper::GetResourceInstance(), m_hResource); + if(m_hGlobal == NULL) + { + m_hResource = NULL; + return false; + } + + return true; + } + +// Misc. 
operations + DWORD GetSize() const + { + ATLASSERT(m_hResource != NULL); + return ::SizeofResource(ModuleHelper::GetResourceInstance(), m_hResource); + } + + LPVOID Lock() + { + ATLASSERT(m_hResource != NULL); + ATLASSERT(m_hGlobal != NULL); + LPVOID pVoid = ::LockResource(m_hGlobal); + ATLASSERT(pVoid != NULL); + return pVoid; + } + + void Release() + { + if(m_hGlobal != NULL) + { + FreeResource(m_hGlobal); + m_hGlobal = NULL; + m_hResource = NULL; + } + } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// Toolbar resource descriptor + +struct _AtlToolBarData +{ + WORD wVersion; + WORD wWidth; + WORD wHeight; + WORD wItemCount; + + WORD* items() + { return (WORD*)(this+1); } +}; + + +/////////////////////////////////////////////////////////////////////////////// +// Global functions for loading resources + +inline HACCEL AtlLoadAccelerators(ATL::_U_STRINGorID table) +{ + return ::LoadAccelerators(ModuleHelper::GetResourceInstance(), table.m_lpstr); +} + +inline HMENU AtlLoadMenu(ATL::_U_STRINGorID menu) +{ + return ::LoadMenu(ModuleHelper::GetResourceInstance(), menu.m_lpstr); +} + +inline HBITMAP AtlLoadBitmap(ATL::_U_STRINGorID bitmap) +{ + return ::LoadBitmap(ModuleHelper::GetResourceInstance(), bitmap.m_lpstr); +} + +#ifdef OEMRESOURCE +inline HBITMAP AtlLoadSysBitmap(ATL::_U_STRINGorID bitmap) +{ +#ifdef _DEBUG + WORD wID = LOWORD(bitmap.m_lpstr); + ATLASSERT((wID >= 32734) && (wID <= 32767)); +#endif // _DEBUG + return ::LoadBitmap(NULL, bitmap.m_lpstr); +} +#endif // OEMRESOURCE + +inline HCURSOR AtlLoadCursor(ATL::_U_STRINGorID cursor) +{ + return ::LoadCursor(ModuleHelper::GetResourceInstance(), cursor.m_lpstr); +} + +inline HCURSOR AtlLoadSysCursor(LPCTSTR lpCursorName) +{ + ATLASSERT((lpCursorName == IDC_ARROW) || (lpCursorName == IDC_IBEAM) || (lpCursorName == IDC_WAIT) || + (lpCursorName == IDC_CROSS) || (lpCursorName == IDC_UPARROW) || (lpCursorName == IDC_SIZE) || + (lpCursorName == IDC_ICON) || 
(lpCursorName == IDC_SIZENWSE) || (lpCursorName == IDC_SIZENESW) || + (lpCursorName == IDC_SIZEWE) || (lpCursorName == IDC_SIZENS) || (lpCursorName == IDC_SIZEALL) || + (lpCursorName == IDC_NO) || (lpCursorName == IDC_APPSTARTING) || (lpCursorName == IDC_HELP) || + (lpCursorName == IDC_HAND)); + return ::LoadCursor(NULL, lpCursorName); +} + +inline HICON AtlLoadIcon(ATL::_U_STRINGorID icon) +{ + return ::LoadIcon(ModuleHelper::GetResourceInstance(), icon.m_lpstr); +} + +inline HICON AtlLoadSysIcon(LPCTSTR lpIconName) +{ +#if (WINVER >= 0x0600) + ATLASSERT((lpIconName == IDI_APPLICATION) || (lpIconName == IDI_ASTERISK) || (lpIconName == IDI_EXCLAMATION) || + (lpIconName == IDI_HAND) || (lpIconName == IDI_QUESTION) || (lpIconName == IDI_WINLOGO) || + (lpIconName == IDI_SHIELD)); +#else // !(WINVER >= 0x0600) + ATLASSERT((lpIconName == IDI_APPLICATION) || (lpIconName == IDI_ASTERISK) || (lpIconName == IDI_EXCLAMATION) || + (lpIconName == IDI_HAND) || (lpIconName == IDI_QUESTION) || (lpIconName == IDI_WINLOGO)); +#endif // !(WINVER >= 0x0600) + return ::LoadIcon(NULL, lpIconName); +} + +inline HBITMAP AtlLoadBitmapImage(ATL::_U_STRINGorID bitmap, UINT fuLoad = LR_DEFAULTCOLOR) +{ + return (HBITMAP)::LoadImage(ModuleHelper::GetResourceInstance(), bitmap.m_lpstr, IMAGE_BITMAP, 0, 0, fuLoad); +} + +inline HCURSOR AtlLoadCursorImage(ATL::_U_STRINGorID cursor, UINT fuLoad = LR_DEFAULTCOLOR | LR_DEFAULTSIZE, int cxDesired = 0, int cyDesired = 0) +{ + return (HCURSOR)::LoadImage(ModuleHelper::GetResourceInstance(), cursor.m_lpstr, IMAGE_CURSOR, cxDesired, cyDesired, fuLoad); +} + +inline HICON AtlLoadIconImage(ATL::_U_STRINGorID icon, UINT fuLoad = LR_DEFAULTCOLOR | LR_DEFAULTSIZE, int cxDesired = 0, int cyDesired = 0) +{ + return (HICON)::LoadImage(ModuleHelper::GetResourceInstance(), icon.m_lpstr, IMAGE_ICON, cxDesired, cyDesired, fuLoad); +} + +#ifdef OEMRESOURCE +inline HBITMAP AtlLoadSysBitmapImage(WORD wBitmapID, UINT fuLoad = LR_DEFAULTCOLOR) +{ + ATLASSERT((wBitmapID 
>= 32734) && (wBitmapID <= 32767)); + ATLASSERT((fuLoad & LR_LOADFROMFILE) == 0); // this one doesn't load from a file + return (HBITMAP)::LoadImage(NULL, MAKEINTRESOURCE(wBitmapID), IMAGE_BITMAP, 0, 0, fuLoad); +} +#endif // OEMRESOURCE + +inline HCURSOR AtlLoadSysCursorImage(ATL::_U_STRINGorID cursor, UINT fuLoad = LR_DEFAULTCOLOR | LR_DEFAULTSIZE, int cxDesired = 0, int cyDesired = 0) +{ +#ifdef _DEBUG + WORD wID = LOWORD(cursor.m_lpstr); + ATLASSERT(((wID >= 32512) && (wID <= 32516)) || ((wID >= 32640) && (wID <= 32648)) || (wID == 32650) || (wID == 32651)); + ATLASSERT((fuLoad & LR_LOADFROMFILE) == 0); // this one doesn't load from a file +#endif // _DEBUG + return (HCURSOR)::LoadImage(NULL, cursor.m_lpstr, IMAGE_CURSOR, cxDesired, cyDesired, fuLoad); +} + +inline HICON AtlLoadSysIconImage(ATL::_U_STRINGorID icon, UINT fuLoad = LR_DEFAULTCOLOR | LR_DEFAULTSIZE, int cxDesired = 0, int cyDesired = 0) +{ +#ifdef _DEBUG + WORD wID = LOWORD(icon.m_lpstr); + ATLASSERT((wID >= 32512) && (wID <= 32517)); + ATLASSERT((fuLoad & LR_LOADFROMFILE) == 0); // this one doesn't load from a file +#endif // _DEBUG + return (HICON)::LoadImage(NULL, icon.m_lpstr, IMAGE_ICON, cxDesired, cyDesired, fuLoad); +} + +inline bool AtlLoadString(UINT uID, BSTR& bstrText) +{ + USES_CONVERSION; + ATLASSERT(bstrText == NULL); + + LPTSTR lpstrText = NULL; + int nRes = 0; + for(int nLen = 256; ; nLen *= 2) + { + ATLTRY(lpstrText = new TCHAR[nLen]); + if(lpstrText == NULL) + break; + nRes = ::LoadString(ModuleHelper::GetResourceInstance(), uID, lpstrText, nLen); + if(nRes < nLen - 1) + break; + delete [] lpstrText; + lpstrText = NULL; + } + + if(lpstrText != NULL) + { + if(nRes != 0) + bstrText = ::SysAllocString(T2OLE(lpstrText)); + delete [] lpstrText; + } + + return (bstrText != NULL) ? 
true : false; +} + +} // namespace WTL + +#endif // __ATLUSER_H__ diff --git a/Examples/WhisperDesktop/Utils/WTL/atlwinx.h b/Examples/WhisperDesktop/Utils/WTL/atlwinx.h new file mode 100644 index 0000000..b89c513 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/WTL/atlwinx.h @@ -0,0 +1,623 @@ +// Windows Template Library - WTL version 10.0 +// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved. +// +// This file is a part of the Windows Template Library. +// The use and distribution terms for this software are covered by the +// Microsoft Public License (http://opensource.org/licenses/MS-PL) +// which can be found in the file MS-PL.txt at the root folder. + +#ifndef __ATLWINX_H__ +#define __ATLWINX_H__ + +#pragma once + +#ifndef __ATLAPP_H__ + #error atlwinx.h requires atlapp.h to be included first +#endif + +#include <atlwin.h> + + +/////////////////////////////////////////////////////////////////////////////// +// Classes in this file: +// +// CWindowEx + + +///////////////////////////////////////////////////////////////////////////// +// Additional macros needed for template classes + +#ifndef DECLARE_WND_CLASS_EX2 + #define DECLARE_WND_CLASS_EX2(WndClassName, EnclosingClass, style, bkgnd) \ + static ATL::CWndClassInfo& GetWndClassInfo() \ + { \ + static ATL::CWndClassInfo wc = \ + { \ + { sizeof(WNDCLASSEX), style, EnclosingClass::StartWindowProc, \ + 0, 0, NULL, NULL, NULL, (HBRUSH)(bkgnd + 1), NULL, WndClassName, NULL }, \ + NULL, NULL, IDC_ARROW, TRUE, 0, _T("") \ + }; \ + return wc; \ + } +#endif // DECLARE_WND_CLASS_EX2 + +#ifndef DECLARE_WND_SUPERCLASS2 + #define DECLARE_WND_SUPERCLASS2(WndClassName, EnclosingClass, OrigWndClassName) \ + static ATL::CWndClassInfo& GetWndClassInfo() \ + { \ + static ATL::CWndClassInfo wc = \ + { \ + { sizeof(WNDCLASSEX), 0, EnclosingClass::StartWindowProc, \ + 0, 0, NULL, NULL, NULL, NULL, NULL, WndClassName, NULL }, \ + OrigWndClassName, NULL, NULL, TRUE, 0, _T("") \ + }; \ + return wc; \ + } +#endif // 
DECLARE_WND_SUPERCLASS2 + + +/////////////////////////////////////////////////////////////////////////////// +// Command Chaining Macros + +#define CHAIN_COMMANDS(theChainClass) \ + if(uMsg == WM_COMMAND) \ + CHAIN_MSG_MAP(theChainClass) + +#define CHAIN_COMMANDS_ALT(theChainClass, msgMapID) \ + if(uMsg == WM_COMMAND) \ + CHAIN_MSG_MAP_ALT(theChainClass, msgMapID) + +#define CHAIN_COMMANDS_MEMBER(theChainMember) \ + if(uMsg == WM_COMMAND) \ + CHAIN_MSG_MAP_MEMBER(theChainMember) + +#define CHAIN_COMMANDS_ALT_MEMBER(theChainMember, msgMapID) \ + if(uMsg == WM_COMMAND) \ + CHAIN_MSG_MAP_ALT_MEMBER(theChainMember, msgMapID) + + +/////////////////////////////////////////////////////////////////////////////// +// Macros for parent message map to selectively reflect control messages + +// NOTE: ReflectNotifications is a member of ATL's CWindowImplRoot +// (and overridden in 2 cases - CContainedWindowT and CAxHostWindow) +// Since we can't modify ATL, we'll provide the needed additions +// in a separate function (that is not a member of CWindowImplRoot) + +namespace WTL +{ + +inline LRESULT WtlReflectNotificationsFiltered(HWND hWndParent, UINT uMsg, WPARAM wParam, LPARAM lParam, BOOL& bHandled, + UINT uMsgFilter = WM_NULL, UINT_PTR idFromFilter = 0, HWND hWndChildFilter = NULL) +{ + if((uMsgFilter != WM_NULL) && (uMsgFilter != uMsg)) + { + // The notification message doesn't match the filter. 
+ bHandled = FALSE; + return 1; + } + + HWND hWndChild = NULL; + UINT_PTR idFrom = 0; + + switch(uMsg) + { + case WM_COMMAND: + if(lParam != NULL) // not from a menu + { + hWndChild = (HWND)lParam; + idFrom = (UINT_PTR)LOWORD(wParam); + } + break; + case WM_NOTIFY: + hWndChild = ((LPNMHDR)lParam)->hwndFrom; + idFrom = ((LPNMHDR)lParam)->idFrom; + break; + case WM_PARENTNOTIFY: + switch(LOWORD(wParam)) + { + case WM_CREATE: + case WM_DESTROY: + hWndChild = (HWND)lParam; + idFrom = (UINT_PTR)HIWORD(wParam); + break; + default: + hWndChild = ::GetDlgItem(hWndParent, HIWORD(wParam)); + idFrom = (UINT_PTR)::GetDlgCtrlID(hWndChild); + break; + } + break; + case WM_DRAWITEM: + if(wParam) // not from a menu + { + hWndChild = ((LPDRAWITEMSTRUCT)lParam)->hwndItem; + idFrom = (UINT_PTR)wParam; + } + break; + case WM_MEASUREITEM: + if(wParam) // not from a menu + { + hWndChild = ::GetDlgItem(hWndParent, ((LPMEASUREITEMSTRUCT)lParam)->CtlID); + idFrom = (UINT_PTR)wParam; + } + break; + case WM_COMPAREITEM: + if(wParam) // not from a menu + { + hWndChild = ((LPCOMPAREITEMSTRUCT)lParam)->hwndItem; + idFrom = (UINT_PTR)wParam; + } + break; + case WM_DELETEITEM: + if(wParam) // not from a menu + { + hWndChild = ((LPDELETEITEMSTRUCT)lParam)->hwndItem; + idFrom = (UINT_PTR)wParam; + } + break; + case WM_VKEYTOITEM: + case WM_CHARTOITEM: + case WM_HSCROLL: + case WM_VSCROLL: + case WM_CTLCOLORBTN: + case WM_CTLCOLORDLG: + case WM_CTLCOLOREDIT: + case WM_CTLCOLORLISTBOX: + case WM_CTLCOLORMSGBOX: + case WM_CTLCOLORSCROLLBAR: + case WM_CTLCOLORSTATIC: + hWndChild = (HWND)lParam; + idFrom = (UINT_PTR)::GetDlgCtrlID(hWndChild); + break; + default: + break; + } + + if((hWndChild == NULL) || + ((hWndChildFilter != NULL) && (hWndChildFilter != hWndChild))) + { + // Either hWndChild isn't valid, or + // hWndChild doesn't match the filter. + bHandled = FALSE; + return 1; + } + + if((idFromFilter != 0) && (idFromFilter != idFrom)) + { + // The dialog control id doesn't match the filter. 
+ bHandled = FALSE; + return 1; + } + + ATLASSERT(::IsWindow(hWndChild)); + LRESULT lResult = ::SendMessage(hWndChild, OCM__BASE + uMsg, wParam, lParam); + if((lResult == 0) && (uMsg >= WM_CTLCOLORMSGBOX) && (uMsg <= WM_CTLCOLORSTATIC)) + { + // Try to prevent problems with WM_CTLCOLOR* messages when + // the message wasn't really handled + bHandled = FALSE; + } + + return lResult; +} + +} // namespace WTL + +// Try to prevent problems with WM_CTLCOLOR* messages when +// the message wasn't really handled +#define REFLECT_NOTIFICATIONS_EX() \ +{ \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if((lResult == 0) && (uMsg >= WM_CTLCOLORMSGBOX) && (uMsg <= WM_CTLCOLORSTATIC)) \ + bHandled = FALSE; \ + if(bHandled) \ + return TRUE; \ +} + +#define REFLECT_NOTIFICATIONS_MSG_FILTERED(uMsgFilter) \ + { \ + bHandled = TRUE; \ + lResult = WTL::WtlReflectNotificationsFiltered(this->m_hWnd, uMsg, wParam, lParam, bHandled, uMsgFilter, 0, NULL); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_NOTIFICATIONS_ID_FILTERED(idFromFilter) \ + { \ + bHandled = TRUE; \ + lResult = WTL::WtlReflectNotificationsFiltered(this->m_hWnd, uMsg, wParam, lParam, bHandled, WM_NULL, idFromFilter, NULL); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_NOTIFICATIONS_HWND_FILTERED(hWndChildFilter) \ + { \ + bHandled = TRUE; \ + lResult = WTL::WtlReflectNotificationsFiltered(this->m_hWnd, uMsg, wParam, lParam, bHandled, WM_NULL, 0, hWndChildFilter); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_NOTIFICATIONS_MSG_ID_FILTERED(uMsgFilter, idFromFilter) \ + { \ + bHandled = TRUE; \ + lResult = WTL::WtlReflectNotificationsFiltered(this->m_hWnd, uMsg, wParam, lParam, bHandled, uMsgFilter, idFromFilter, NULL); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_NOTIFICATIONS_MSG_HWND_FILTERED(uMsgFilter, hWndChildFilter) \ + { \ + bHandled = TRUE; \ + lResult = WTL::WtlReflectNotificationsFiltered(this->m_hWnd, uMsg, 
wParam, lParam, bHandled, uMsgFilter, 0, hWndChildFilter); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_COMMAND(id, code) \ + if((uMsg == WM_COMMAND) && (id == LOWORD(wParam)) && (code == HIWORD(wParam))) \ + { \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_COMMAND_ID(id) \ + if((uMsg == WM_COMMAND) && (id == LOWORD(wParam))) \ + { \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_COMMAND_CODE(code) \ + if((uMsg == WM_COMMAND) && (code == HIWORD(wParam))) \ + { \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_COMMAND_RANGE(idFirst, idLast) \ + if((uMsg == WM_COMMAND) && (LOWORD(wParam) >= idFirst) && (LOWORD(wParam) <= idLast)) \ + { \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_COMMAND_RANGE_CODE(idFirst, idLast, code) \ + if((uMsg == WM_COMMAND) && (code == HIWORD(wParam)) && (LOWORD(wParam) >= idFirst) && (LOWORD(wParam) <= idLast)) \ + { \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_NOTIFY(id, cd) \ + if((uMsg == WM_NOTIFY) && (id == ((LPNMHDR)lParam)->idFrom) && (cd == ((LPNMHDR)lParam)->code)) \ + { \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_NOTIFY_ID(id) \ + if((uMsg == WM_NOTIFY) && (id == ((LPNMHDR)lParam)->idFrom)) \ + { \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_NOTIFY_CODE(cd) \ + if((uMsg == 
WM_NOTIFY) && (cd == ((LPNMHDR)lParam)->code)) \ + { \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_NOTIFY_RANGE(idFirst, idLast) \ + if((uMsg == WM_NOTIFY) && (((LPNMHDR)lParam)->idFrom >= idFirst) && (((LPNMHDR)lParam)->idFrom <= idLast)) \ + { \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if(bHandled) \ + return TRUE; \ + } + +#define REFLECT_NOTIFY_RANGE_CODE(idFirst, idLast, cd) \ + if((uMsg == WM_NOTIFY) && (cd == ((LPNMHDR)lParam)->code) && (((LPNMHDR)lParam)->idFrom >= idFirst) && (((LPNMHDR)lParam)->idFrom <= idLast)) \ + { \ + bHandled = TRUE; \ + lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \ + if(bHandled) \ + return TRUE; \ + } + + +/////////////////////////////////////////////////////////////////////////////// +// GetClassLong/SetClassLong redefinition to avoid problems with class members + +#ifdef SetClassLongPtrA + #undef SetClassLongPtrA + inline LONG_PTR SetClassLongPtrA(HWND hWnd, int nIndex, LONG_PTR dwNewLong) + { + return ::SetClassLongA(hWnd, nIndex, LONG(dwNewLong)); + } +#endif + +#ifdef SetClassLongPtrW + #undef SetClassLongPtrW + inline LONG_PTR SetClassLongPtrW(HWND hWnd, int nIndex, LONG_PTR dwNewLong) + { + return ::SetClassLongW(hWnd, nIndex, LONG(dwNewLong)); + } +#endif + +#ifdef GetClassLongPtrA + #undef GetClassLongPtrA + inline LONG_PTR GetClassLongPtrA(HWND hWnd, int nIndex) + { + return ::GetClassLongA(hWnd, nIndex); + } +#endif + +#ifdef GetClassLongPtrW + #undef GetClassLongPtrW + inline LONG_PTR GetClassLongPtrW(HWND hWnd, int nIndex) + { + return ::GetClassLongW(hWnd, nIndex); + } +#endif + + +/////////////////////////////////////////////////////////////////////////////// +// CWindowEx - extension of ATL::CWindow + +namespace WTL +{ + +class CWindowEx : public ATL::CWindow +{ +public: + CWindowEx(HWND hWnd = NULL) : ATL::CWindow(hWnd) + { } + 
+ CWindowEx& operator =(HWND hWnd) + { + m_hWnd = hWnd; + return *this; + } + + operator HWND() const + { + return m_hWnd; + } + +// Methods + BOOL PrintWindow(HDC hDC, UINT uFlags = 0) + { + ATLASSERT(::IsWindow(m_hWnd)); + return ::PrintWindow(m_hWnd, hDC, uFlags); + } + + BOOL DragDetect(POINT pt) + { + ATLASSERT(::IsWindow(m_hWnd)); + return ::DragDetect(m_hWnd, pt); + } + + BOOL DragDetect() + { + ATLASSERT(::IsWindow(m_hWnd)); + + POINT pt = {}; + ::GetCursorPos(&pt); + return ::DragDetect(m_hWnd, pt); + } + + CWindowEx GetAncestor(UINT uFlags) const + { + ATLASSERT(::IsWindow(m_hWnd)); + return CWindowEx(::GetAncestor(m_hWnd, uFlags)); + } + + // Note: Does not work properly on Vista Aero and above + BOOL AnimateWindow(DWORD dwFlags, DWORD dwTime = 200) + { + ATLASSERT(::IsWindow(m_hWnd)); + return ::AnimateWindow(m_hWnd, dwTime, dwFlags); + } + + BOOL FlashWindowEx(DWORD dwFlags, UINT uCount, DWORD dwTimeout = 0) + { + ATLASSERT(::IsWindow(m_hWnd)); + + FLASHWINFO fi = { sizeof(FLASHWINFO) }; + fi.hwnd = m_hWnd; + fi.dwFlags = dwFlags; + fi.uCount = uCount; + fi.dwTimeout = dwTimeout; + return ::FlashWindowEx(&fi); + } + + BOOL StopFlashWindowEx() + { + ATLASSERT(::IsWindow(m_hWnd)); + + FLASHWINFO fi = { sizeof(FLASHWINFO) }; + fi.hwnd = m_hWnd; + fi.dwFlags = FLASHW_STOP; + return ::FlashWindowEx(&fi); + } + +// Class long properties + DWORD GetClassLong(int nIndex) const + { + ATLASSERT(::IsWindow(m_hWnd)); + return ::GetClassLong(m_hWnd, nIndex); + } + + DWORD SetClassLong(int nIndex, LONG dwNewLong) + { + ATLASSERT(::IsWindow(m_hWnd)); + return ::SetClassLong(m_hWnd, nIndex, dwNewLong); + } + + ULONG_PTR GetClassLongPtr(int nIndex) const + { + ATLASSERT(::IsWindow(m_hWnd)); + return ::GetClassLongPtr(m_hWnd, nIndex); + } + + ULONG_PTR SetClassLongPtr(int nIndex, LONG_PTR dwNewLong) + { + ATLASSERT(::IsWindow(m_hWnd)); + return ::SetClassLongPtr(m_hWnd, nIndex, dwNewLong); + } + +// Layered windows + BOOL SetLayeredWindowAttributes(COLORREF crlKey, BYTE 
byteAlpha, DWORD dwFlags) + { + ATLASSERT(::IsWindow(m_hWnd)); + ATLASSERT((GetExStyle() & WS_EX_LAYERED) != 0); + + return ::SetLayeredWindowAttributes(m_hWnd, crlKey, byteAlpha, dwFlags); + } + + BOOL UpdateLayeredWindow(HDC hdcDst, LPPOINT pptDst, LPSIZE psize, HDC hdcSrc, LPPOINT pptSrc, COLORREF crlKey, BLENDFUNCTION* pblend, DWORD dwFlags) + { + ATLASSERT(::IsWindow(m_hWnd)); + ATLASSERT((GetExStyle() & WS_EX_LAYERED) != 0); + + return ::UpdateLayeredWindow(m_hWnd, hdcDst, pptDst, psize, hdcSrc, pptSrc, crlKey, pblend, dwFlags); + } + + BOOL UpdateLayeredWindow(LPPOINT pptDst = NULL) + { + ATLASSERT(::IsWindow(m_hWnd)); + ATLASSERT((GetExStyle() & WS_EX_LAYERED) != 0); + + return ::UpdateLayeredWindow(m_hWnd, NULL, pptDst, NULL, NULL, NULL, CLR_NONE, NULL, 0); + } + + BOOL GetLayeredWindowAttributes(COLORREF* pcrlKey, BYTE* pbyteAlpha, DWORD* pdwFlags) const + { + ATLASSERT(::IsWindow(m_hWnd)); + ATLASSERT((GetExStyle() & WS_EX_LAYERED) != 0); + + return ::GetLayeredWindowAttributes(m_hWnd, pcrlKey, pbyteAlpha, pdwFlags); + } + +// Mouse tracking + BOOL StartTrackMouseLeave() + { + ATLASSERT(::IsWindow(m_hWnd)); + + TRACKMOUSEEVENT tme = {}; + tme.cbSize = sizeof(TRACKMOUSEEVENT); + tme.dwFlags = TME_LEAVE; + tme.hwndTrack = m_hWnd; + return ::TrackMouseEvent(&tme); + } + + BOOL StartTrackMouse(DWORD dwFlags, DWORD dwHoverTime = HOVER_DEFAULT) + { + ATLASSERT(::IsWindow(m_hWnd)); + + TRACKMOUSEEVENT tme = {}; + tme.cbSize = sizeof(TRACKMOUSEEVENT); + tme.dwFlags = dwFlags; + tme.hwndTrack = m_hWnd; + tme.dwHoverTime = dwHoverTime; + return ::TrackMouseEvent(&tme); + } + + BOOL CancelTrackMouse(DWORD dwType) + { + ATLASSERT(::IsWindow(m_hWnd)); + + TRACKMOUSEEVENT tme = {}; + tme.cbSize = sizeof(TRACKMOUSEEVENT); + tme.dwFlags = TME_CANCEL | dwType; + tme.hwndTrack = m_hWnd; + return ::TrackMouseEvent(&tme); + } + +// CString support +#ifdef __ATLSTR_H__ + int GetWindowText(ATL::CString& strText) const + { + int nLength = GetWindowTextLength(); + LPTSTR 
pszText = strText.GetBuffer(nLength + 1); + nLength = ::GetWindowText(m_hWnd, pszText, nLength + 1); + strText.ReleaseBuffer(nLength); + + return nLength; + } + + UINT GetDlgItemText(int nID, ATL::CString& strText) const + { + ATLASSERT(::IsWindow(m_hWnd)); + + HWND hItem = GetDlgItem(nID); + if(hItem != NULL) + { + int nLength = ::GetWindowTextLength(hItem); + LPTSTR pszText = strText.GetBuffer(nLength + 1); + nLength = ::GetWindowText(hItem, pszText, nLength + 1); + strText.ReleaseBuffer(nLength); + + return nLength; + } + else + { + strText.Empty(); + + return 0; + } + } +#endif // __ATLSTR_H__ + +// Dialog window only + UINT DlgGetDefID() const + { + ATLASSERT(::IsWindow(m_hWnd)); + + LRESULT lRet = ::SendMessage(m_hWnd, DM_GETDEFID, 0, 0L); + UINT uID = 0U; + if(HIWORD(lRet) == DC_HASDEFID) + uID = LOWORD(lRet); + + return uID; + } + + void DlgSetDefID(UINT uID) + { + ATLASSERT(::IsWindow(m_hWnd)); + + ::SendMessage(m_hWnd, DM_SETDEFID, uID, 0L); + } + + void DlgReposition() + { + ATLASSERT(::IsWindow(m_hWnd)); + + ::SendMessage(m_hWnd, DM_REPOSITION, 0, 0L); + } +}; + +} // namespace WTL + +#endif // __ATLWINX_H__ diff --git a/Examples/WhisperDesktop/Utils/logger.cpp b/Examples/WhisperDesktop/Utils/logger.cpp new file mode 100644 index 0000000..5c7c257 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/logger.cpp @@ -0,0 +1,71 @@ +#include "stdafx.h" +#include "logger.h" +#include "miscUtils.h" + +namespace +{ + using namespace Whisper; + + // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] + // Lowest is red, middle is yellow, highest is green. 
+ static const std::array<const char*, 10> k_colors = + { + "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m", + "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m", + }; + + static int colorIndex( const sToken& tok ) + { + const float p = tok.probability; + const float p3 = p * p * p; + int col = (int)( p3 * float( k_colors.size() ) ); + col = std::max( 0, std::min( (int)k_colors.size() - 1, col ) ); + return col; + } +} + +void printTimeStamp( CStringA& rdi, Whisper::sTimeSpan ts ) +{ + sTimeSpanFields fields = ts; + uint32_t msec = fields.ticks / 10'000; + uint32_t hr = fields.days * 24 + fields.hours; + uint32_t min = fields.minutes; + uint32_t sec = fields.seconds; + rdi.AppendFormat( "%02d:%02d:%02d.%03d", hr, min, sec, msec ); +} + +HRESULT logNewSegments( const iTranscribeResult* results, size_t newSegments, bool printSpecial ) +{ + sTranscribeLength length; + CHECK( results->getSize( length ) ); + + const size_t len = length.countSegments; + size_t i = len - newSegments; + + const sSegment* const segments = results->getSegments(); + const sToken* const tokens = results->getTokens(); + + CStringA str; + for( ; i < len; i++ ) + { + const sSegment& seg = segments[ i ]; + str = "["; + printTimeStamp( str, seg.time.begin ); + str += " --> "; + printTimeStamp( str, seg.time.end ); + str += "] "; + + for( uint32_t j = 0; j < seg.countTokens; j++ ) + { + const sToken& tok = tokens[ seg.firstToken + j ]; + if( !printSpecial && ( tok.flags & eTokenFlags::Special ) ) + continue; + str += k_colors[ colorIndex( tok ) ]; + str += tok.text; + str += "\033[0m"; + } + logInfo( u8"%s", cstr( str ) ); + } + + return S_OK; +}
\ No newline at end of file diff --git a/Examples/WhisperDesktop/Utils/logger.h b/Examples/WhisperDesktop/Utils/logger.h new file mode 100644 index 0000000..07ec012 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/logger.h @@ -0,0 +1,36 @@ +#pragma once +#include <whisperWindows.h> +#include <cstdarg> + +void logMessage( Whisper::eLogLevel lvl, const char8_t* pszFormat, va_list args ); + +#define LOG_MESSAGE_IMPL( lvl ) \ + std::va_list args; \ + va_start( args, pszFormat ); \ + logMessage( lvl, pszFormat, args ); \ + va_end( args ) + +inline void logError( const char8_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( Whisper::eLogLevel::Error ); +} +inline void logWarning( const char8_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( Whisper::eLogLevel::Warning ); +} +inline void logInfo( const char8_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( Whisper::eLogLevel::Info ); +} +inline void logDebug( const char8_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( Whisper::eLogLevel::Debug ); +} + +#undef LOG_MESSAGE_IMPL + +HRESULT logNewSegments( const Whisper::iTranscribeResult* results, size_t newSegments, bool printSpecial = false ); + +void clearLastError(); +bool getLastError( CString& rdi ); +void printTimeStamp( CStringA& rdi, Whisper::sTimeSpan ts );
\ No newline at end of file diff --git a/Examples/WhisperDesktop/Utils/miscUtils.cpp b/Examples/WhisperDesktop/Utils/miscUtils.cpp new file mode 100644 index 0000000..485cf00 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/miscUtils.cpp @@ -0,0 +1,254 @@ +#include "stdafx.h" +#include "miscUtils.h" + +namespace +{ + wchar_t* formatMessage( HRESULT hr ) + { + wchar_t* err; + if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + hr, + MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ), + (LPTSTR)&err, + 0, + NULL ) ) + return err; + return nullptr; + } +} + +CString formatErrorMessage( HRESULT hr ) +{ + CString message; + const wchar_t* err = formatMessage( hr ); + if( nullptr != err ) + { + message = err; + LocalFree( (HLOCAL)err ); + message.TrimRight(); + } + else + message.Format( L"Error code %i (0x%08X)", hr, hr ); + + return message; +} + +void reportFatalError( const char* what, HRESULT hr ) +{ + CString message; + message.Format( L"%S\n%S\n", "Unable to start the application.", what ); + message += formatErrorMessage( hr ); + ::MessageBox( nullptr, message, L"Whisper Desktop Startup", MB_OK | MB_ICONERROR ); +} + +namespace +{ + using Whisper::eModelImplementation; + + struct sImplString + { + eModelImplementation val; + LPCTSTR str; + }; + static const std::array<sImplString, 3> s_implStrings = + { + sImplString{ eModelImplementation::GPU, L"GPU" }, + sImplString{ eModelImplementation::Hybrid, L"Hybrid" }, + sImplString{ eModelImplementation::Reference, L"Reference" }, + }; +} + +HRESULT implParse( const CString& s, eModelImplementation& rdi ) +{ + for( const auto& is : s_implStrings ) + { + if( 0 != s.CompareNoCase( is.str ) ) + continue; + rdi = is.val;; + return S_OK; + } + return E_INVALIDARG; +} + +LPCTSTR implString( eModelImplementation i ) +{ + for( const auto& is : s_implStrings ) + if( is.val == i ) + return is.str; + return nullptr; +} + +void implPopulateCombobox( CComboBox& cb, Whisper::eModelImplementation impl ) 
+{ + int curSel = 0; + int idx = 0; + for( const auto& is : s_implStrings ) + { + cb.AddString( is.str ); + if( is.val == impl ) + curSel = idx; + idx++; + } + cb.SetCurSel( curSel ); +} + +Whisper::eModelImplementation implGetValue( CComboBox& cb ) +{ + int curSel = cb.GetCurSel(); + if( curSel < 0 ) + return (Whisper::eModelImplementation)0; + return s_implStrings[ curSel ].val; +} + +ThreadPoolWork::~ThreadPoolWork() +{ + if( nullptr != work ) + { + CloseThreadpoolWork( work ); + work = nullptr; + } +} + +void __stdcall ThreadPoolWork::callback( PTP_CALLBACK_INSTANCE Instance, PVOID Context, PTP_WORK Work ) +{ + iThreadPoolCallback* cb = (iThreadPoolCallback*)Context; + cb->poolCallback(); +} + +HRESULT ThreadPoolWork::create( iThreadPoolCallback* cb ) +{ + if( nullptr == cb ) + return E_POINTER; + if( nullptr != work ) + return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); + + work = CreateThreadpoolWork( &callback, cb, nullptr ); + if( nullptr != work ) + return S_OK; + + return HRESULT_FROM_WIN32( GetLastError() ); +} + +HRESULT ThreadPoolWork::post() +{ + if( nullptr == work ) + return OLE_E_BLANK; + SubmitThreadpoolWork( work ); + return S_OK; +} + +void makeUtf16( CString& rdi, const char* utf8 ) +{ + const size_t length = strlen( utf8 ); + int count = MultiByteToWideChar( CP_UTF8, 0, utf8, (int)length, nullptr, 0 ); + wchar_t* p = rdi.GetBufferSetLength( count ); + MultiByteToWideChar( CP_UTF8, 0, utf8, (int)length, p, count ); + rdi.ReleaseBuffer(); +} + +void makeUtf8( CStringA& rdi, const CString& utf16 ) +{ + int count = WideCharToMultiByte( CP_UTF8, 0, utf16, utf16.GetLength(), nullptr, 0, nullptr, nullptr ); + char* s = rdi.GetBufferSetLength( count + 1 ); + count = WideCharToMultiByte( CP_UTF8, 0, utf16, utf16.GetLength(), s, count, nullptr, nullptr ); + rdi.ReleaseBufferSetLength( count ); +} + +constexpr int ofnBufferLength = 2048; + +bool getOpenFileName( HWND owner, LPCTSTR title, LPCTSTR filter, CString& path ) +{ + wchar_t buffer[ 
ofnBufferLength ]; + buffer[ 0 ] = 0; + OPENFILENAME ofn; + memset( &ofn, 0, sizeof( ofn ) ); + ofn.lStructSize = sizeof( OPENFILENAME ); + ofn.hwndOwner = owner; + ofn.lpstrFilter = filter; + ofn.lpstrTitle = title; + ofn.Flags = OFN_EXPLORER | OFN_FILEMUSTEXIST | OFN_PATHMUSTEXIST; + ofn.lpstrFile = buffer; + ofn.nMaxFile = ofnBufferLength - 1; + + CString dir; + if( path.GetLength() > 0 && path.GetLength() < ofnBufferLength ) + wcsncpy_s( buffer, path, path.GetLength() ); + + if( !GetOpenFileName( &ofn ) ) + { + path = L""; + return false; + } + else + { + path = ofn.lpstrFile; + return true; + } +} + +bool getSaveFileName( HWND owner, LPCTSTR title, LPCTSTR filter, CString& path, DWORD* filterIndex ) +{ + wchar_t buffer[ ofnBufferLength ]; + buffer[ 0 ] = 0; + + OPENFILENAME ofn; + memset( &ofn, 0, sizeof( ofn ) ); + ofn.lStructSize = sizeof( OPENFILENAME ); + ofn.hwndOwner = owner; + ofn.lpstrFilter = filter; + ofn.lpstrTitle = title; + ofn.Flags = OFN_EXPLORER | OFN_PATHMUSTEXIST; + ofn.lpstrFile = buffer; + ofn.nMaxFile = ofnBufferLength - 1; + if( nullptr != filterIndex ) + ofn.nFilterIndex = *filterIndex + 1; + + if( path.GetLength() > 0 && path.GetLength() < ofnBufferLength ) + wcsncpy_s( buffer, path, path.GetLength() ); + + if( !GetSaveFileName( &ofn ) ) + return false; + + path = ofn.lpstrFile; + + if( nullptr != filterIndex ) + *filterIndex = ofn.nFilterIndex - 1; + + return true; +} + +void reportError( HWND owner, LPCTSTR text, LPCTSTR title, HRESULT hr ) +{ + if( nullptr == title ) + title = L"Operation Failed"; + + CString message = text; + message.TrimRight(); + if( FAILED( hr ) ) + { + message += L"\n"; + message += formatErrorMessage( hr ); + } + + ::MessageBox( owner, message, title, MB_OK | MB_ICONWARNING ); +} + +HRESULT writeUtf8Bom( CAtlFile& file ) +{ + const std::array<uint8_t, 3> bom = { 0xEF, 0xBB, 0xBF }; + return file.Write( bom.data(), 3 ); +} + +bool isInvalidTranslate( HWND owner, uint32_t lang, bool translate ) +{ + if( 
!translate ) + return false; + constexpr uint32_t english = 0x6E65; + if( lang != english ) + return false; + + LPCTSTR message = L"The translate feature translates speech to English.\nIt’s not available when the audio language is already English."; + MessageBox( owner, message, L"Incompatible parameters", MB_OK | MB_ICONINFORMATION ); + return true; +}
\ No newline at end of file diff --git a/Examples/WhisperDesktop/Utils/miscUtils.h b/Examples/WhisperDesktop/Utils/miscUtils.h new file mode 100644 index 0000000..8cf8151 --- /dev/null +++ b/Examples/WhisperDesktop/Utils/miscUtils.h @@ -0,0 +1,72 @@ +#pragma once +#include <iContext.h> +#include "logger.h" + +CString formatErrorMessage( HRESULT hr ); + +void reportFatalError( const char* what, HRESULT hr ); + +#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } + +HRESULT implParse( const CString& s, Whisper::eModelImplementation& rdi ); + +LPCTSTR implString( Whisper::eModelImplementation i ); + +void implPopulateCombobox( CComboBox& cb, Whisper::eModelImplementation impl ); + +Whisper::eModelImplementation implGetValue( CComboBox& cb ); + +__interface iThreadPoolCallback +{ + void __stdcall poolCallback() noexcept; +}; + +class ThreadPoolWork +{ + PTP_WORK work = nullptr; + static void __stdcall callback( PTP_CALLBACK_INSTANCE Instance, PVOID Context, PTP_WORK Work ); + +public: + + ~ThreadPoolWork(); + HRESULT create( iThreadPoolCallback* cb ); + HRESULT post(); +}; + +void makeUtf16( CString& rdi, const char* utf8 ); +void makeUtf8( CStringA& rdi, const CString& utf16 ); + +bool getOpenFileName( HWND owner, LPCTSTR title, LPCTSTR filter, CString& path ); + +bool getSaveFileName( HWND owner, LPCTSTR title, LPCTSTR filter, CString& path, DWORD* filterIndex = nullptr ); + +#define ON_BUTTON_CLICK( id, func ) \ + if( uMsg == WM_COMMAND && \ + id == LOWORD( wParam ) ) \ + { \ + bHandled = TRUE; \ + func(); \ + lResult = 0; \ + return TRUE; \ + } + +void reportError( HWND owner, LPCTSTR text, LPCTSTR title, HRESULT hr = S_FALSE ); + +inline const wchar_t* cstr( const CString& s ) { return s; } +inline const char* cstr( const CStringA& s ) { return s; } + +inline HRESULT getLastHr() +{ + return HRESULT_FROM_WIN32( GetLastError() ); +} + +HRESULT writeUtf8Bom( CAtlFile& file ); + +// Flip order of bytes from RGB to BGR, or vice versa 
+inline uint32_t flipRgb( uint32_t val ) +{ + val = _byteswap_ulong( val ); + return val >> 8; +} + +bool isInvalidTranslate( HWND owner, uint32_t lang, bool translate );
\ No newline at end of file diff --git a/Examples/WhisperDesktop/WhisperDesktop.cpp b/Examples/WhisperDesktop/WhisperDesktop.cpp new file mode 100644 index 0000000..ddef5b6 --- /dev/null +++ b/Examples/WhisperDesktop/WhisperDesktop.cpp @@ -0,0 +1,63 @@ +#include "stdafx.h" +#include "AppState.h" +#include "Utils/miscUtils.h" +#include "LoadModelDlg.h" +#include "TranscribeDlg.h" +#include "CaptureDlg.h" + +HRESULT dialogLoadModel( AppState& appState ) +{ + LoadModelDlg loadDialog{ appState }; + HRESULT hr = loadDialog.show(); + if( FAILED( hr ) ) + { + reportFatalError( "Error loading the model", hr ); + return hr; + } + appState.automaticallyLoadModel = false; + return hr; +} + +HRESULT dialogTranscribe( AppState& appState ) +{ + TranscribeDlg dialog{ appState }; + return dialog.show(); +} + +HRESULT dialogCapture( AppState& appState ) +{ + CaptureDlg dialog{ appState }; + return dialog.show(); +} + +using pfnDialog = HRESULT( * )( AppState& appState ); +static const std::array<pfnDialog, 4> s_dialogs = +{ + nullptr, // S_OK + &dialogLoadModel, // SCREEN_MODEL + &dialogTranscribe, // SCREEN_TRANSCRIBE + &dialogCapture, // SCREEN_CAPTURE +}; + +int __stdcall wWinMain( HINSTANCE hInstance, HINSTANCE hPrevInstance, LPWSTR lpCmdLine, int nCmdShow ) +{ + AppState appState; + HRESULT hr = appState.startup(); + if( FAILED( hr ) ) + return hr; + + appState.findModelSource(); + + hr = SCREEN_MODEL; + while( true ) + { + pfnDialog pfn = s_dialogs[ hr ]; + if( nullptr == pfn ) + return S_OK; + hr = pfn( appState ); + if( FAILED( hr ) ) + return hr; + if( hr == SCREEN_MODEL ) + appState.model = nullptr; + } +}
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<!-- Examples/WhisperDesktop/WhisperDesktop.manifest -->
<assembly xmlns="urn:schemas-microsoft-com:asm.v1"
          xmlns:asmv3="urn:schemas-microsoft-com:asm.v3"
          manifestVersion="1.0">
  <!-- NOTE(review): the name and the description below still hold template placeholder text; confirm that is intended -->
  <assemblyIdentity type="win32"
                    name="CompanyName.ProductName.YourApplication"
                    version="1.0.0.0"
                    processorArchitecture="amd64" />
  <description>Your application description here.</description>
  <!-- Depend on version 6 of the common controls library -->
  <dependency>
    <dependentAssembly>
      <assemblyIdentity type="win32"
                        name="Microsoft.Windows.Common-Controls"
                        version="6.0.0.0"
                        processorArchitecture="amd64"
                        publicKeyToken="6595b64144ccf1df"
                        language="*" />
    </dependentAssembly>
  </dependency>
  <!-- DPI awareness is declared twice, once in the 2005 settings namespace and once in the 2016 one -->
  <asmv3:application>
    <asmv3:windowsSettings>
      <dpiAware xmlns="http://schemas.microsoft.com/SMI/2005/WindowsSettings">true</dpiAware>
      <dpiAwareness xmlns="http://schemas.microsoft.com/SMI/2016/WindowsSettings">PerMonitorV2</dpiAwareness>
    </asmv3:windowsSettings>
  </asmv3:application>
</assembly>
<?xml version="1.0" encoding="utf-8"?>
<!-- Examples/WhisperDesktop/WhisperDesktop.vcxproj -->
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <!-- Only two configurations exist, both for x64 -->
  <ItemGroup Label="ProjectConfigurations">
    <ProjectConfiguration Include="Debug|x64">
      <Configuration>Debug</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Release|x64">
      <Configuration>Release</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
  </ItemGroup>
  <PropertyGroup Label="Globals">
    <VCProjectVersion>16.0</VCProjectVersion>
    <Keyword>Win32Proj</Keyword>
    <ProjectGuid>{cd9e49f0-75a3-4f91-ac71-336109ee39c6}</ProjectGuid>
    <RootNamespace>WhisperDesktop</RootNamespace>
    <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
  <!-- Both configurations build a Unicode application with the v143 toolset; Release adds whole program optimization -->
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
    <ConfigurationType>Application</ConfigurationType>
    <UseDebugLibraries>true</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
    <ConfigurationType>Application</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
  <ImportGroup Label="ExtensionSettings">
  </ImportGroup>
  <ImportGroup Label="Shared">
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <PropertyGroup Label="UserMacros" />
  <!-- The include path is the same in both configurations: the project folder, and the API folder of the Whisper project -->
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <IncludePath>$(ProjectDir);$(SolutionDir)Whisper\API\;$(IncludePath)</IncludePath>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <IncludePath>$(ProjectDir);$(SolutionDir)Whisper\API\;$(IncludePath)</IncludePath>
  </PropertyGroup>
  <!-- Debug: C++20, precompiled headers, no RuntimeLibrary override -->
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>_DEBUG;_WINDOWS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
      <PrecompiledHeader>Use</PrecompiledHeader>
      <LanguageStandard>stdcpp20</LanguageStandard>
      <MultiProcessorCompilation>true</MultiProcessorCompilation>
    </ClCompile>
    <Link>
      <SubSystem>Windows</SubSystem>
      <GenerateDebugInformation>true</GenerateDebugInformation>
      <ManifestFile>WhisperDesktop.manifest</ManifestFile>
    </Link>
  </ItemDefinitionGroup>
  <!-- Release: same language settings, plus size-favouring optimization, RuntimeLibrary set to MultiThreaded, and link-time code generation -->
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>NDEBUG;_WINDOWS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
      <PrecompiledHeader>Use</PrecompiledHeader>
      <LanguageStandard>stdcpp20</LanguageStandard>
      <MultiProcessorCompilation>true</MultiProcessorCompilation>
      <FavorSizeOrSpeed>Size</FavorSizeOrSpeed>
      <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
    </ClCompile>
    <Link>
      <SubSystem>Windows</SubSystem>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <GenerateDebugInformation>true</GenerateDebugInformation>
      <ManifestFile>WhisperDesktop.manifest</ManifestFile>
      <LinkTimeCodeGeneration>UseLinkTimeCodeGeneration</LinkTimeCodeGeneration>
    </Link>
  </ItemDefinitionGroup>
  <!-- Header files, including the WTL headers under Utils\WTL -->
  <ItemGroup>
    <ClInclude Include="CircleIndicator.h" />
    <ClInclude Include="AppState.h" />
    <ClInclude Include="CaptureDlg.h" />
    <ClInclude Include="Utils\TranslateCheckbox.h" />
    <ClInclude Include="Utils\DebugConsole.h" />
    <ClInclude Include="framework.h" />
    <ClInclude Include="Utils\logger.h" />
    <ClInclude Include="Utils\PendingState.h" />
    <ClInclude Include="Utils\LanguageDropdown.h" />
    <ClInclude Include="LoadModelDlg.h" />
    <ClInclude Include="TranscribeDlg.h" />
    <ClInclude Include="Utils\miscUtils.h" />
    <ClInclude Include="stdafx.h" />
    <ClInclude Include="Resource.h" />
    <ClInclude Include="targetver.h" />
    <ClInclude Include="Utils\WTL\atlapp.h" />
    <ClInclude Include="Utils\WTL\atlcrack.h" />
    <ClInclude Include="Utils\WTL\atlctrls.h" />
    <ClInclude Include="Utils\WTL\atlddx.h" />
    <ClInclude Include="Utils\WTL\atlgdi.h" />
    <ClInclude Include="Utils\WTL\atlres.h" />
    <ClInclude Include="Utils\WTL\atluser.h" />
    <ClInclude Include="Utils\WTL\atlwinx.h" />
  </ItemGroup>
  <!-- Source files; stdafx.cpp is the one which creates the precompiled header the other sources use -->
  <ItemGroup>
    <ClCompile Include="CircleIndicator.cpp" />
    <ClCompile Include="AppState.cpp" />
    <ClCompile Include="CaptureDlg.cpp" />
    <ClCompile Include="Utils\TranslateCheckbox.cpp" />
    <ClCompile Include="Utils\DebugConsole.cpp" />
    <ClCompile Include="Utils\logger.cpp" />
    <ClCompile Include="Utils\PendingState.cpp" />
    <ClCompile Include="Utils\LanguageDropdown.cpp" />
    <ClCompile Include="LoadModelDlg.cpp" />
    <ClCompile Include="TranscribeDlg.cpp" />
    <ClCompile Include="Utils\miscUtils.cpp" />
    <ClCompile Include="stdafx.cpp">
      <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
      <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
    </ClCompile>
    <ClCompile Include="WhisperDesktop.cpp" />
  </ItemGroup>
  <ItemGroup>
    <ResourceCompile Include="WhisperDesktop.rc" />
  </ItemGroup>
  <ItemGroup>
    <Image Include="sunflower.ico" />
  </ItemGroup>
  <ItemGroup>
    <Manifest Include="WhisperDesktop.manifest" />
  </ItemGroup>
  <!-- Reference to the Whisper project two folders up -->
  <ItemGroup>
    <ProjectReference Include="..\..\Whisper\Whisper.vcxproj">
      <Project>{701df8c8-e4a5-43ec-9c6b-747bbf4d8e71}</Project>
    </ProjectReference>
  </ItemGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
  <ImportGroup Label="ExtensionTargets">
  </ImportGroup>
</Project>
\ No newline at end of file diff --git a/Examples/WhisperDesktop/WhisperDesktop.vcxproj.filters b/Examples/WhisperDesktop/WhisperDesktop.vcxproj.filters new file mode 100644 index 0000000..7a36c86 --- /dev/null +++ b/Examples/WhisperDesktop/WhisperDesktop.vcxproj.filters @@ -0,0 +1,142 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup> + <Filter Include="Source Files"> + <UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier> + <Extensions>cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx</Extensions> + </Filter> + <Filter Include="Header Files"> + <UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier> + <Extensions>h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd</Extensions> + </Filter> + <Filter Include="Resource Files"> + <UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier> + <Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions> + </Filter> + </ItemGroup> + <ItemGroup> + <ClInclude Include="framework.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="targetver.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Resource.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="stdafx.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="LoadModelDlg.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="AppState.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\miscUtils.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\WTL\atlapp.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\WTL\atlddx.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\WTL\atlctrls.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\WTL\atlgdi.h"> + 
<Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\WTL\atlres.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\WTL\atluser.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\WTL\atlwinx.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="TranscribeDlg.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\LanguageDropdown.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\PendingState.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\DebugConsole.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\logger.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="CaptureDlg.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="CircleIndicator.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\WTL\atlcrack.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="Utils\TranslateCheckbox.h"> + <Filter>Header Files</Filter> + </ClInclude> + </ItemGroup> + <ItemGroup> + <ClCompile Include="WhisperDesktop.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="stdafx.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="LoadModelDlg.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="AppState.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="Utils\miscUtils.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="TranscribeDlg.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="Utils\LanguageDropdown.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="Utils\PendingState.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="Utils\DebugConsole.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile 
Include="Utils\logger.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="CaptureDlg.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="CircleIndicator.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="Utils\TranslateCheckbox.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + </ItemGroup> + <ItemGroup> + <ResourceCompile Include="WhisperDesktop.rc"> + <Filter>Resource Files</Filter> + </ResourceCompile> + </ItemGroup> + <ItemGroup> + <Image Include="sunflower.ico"> + <Filter>Resource Files</Filter> + </Image> + </ItemGroup> + <ItemGroup> + <Manifest Include="WhisperDesktop.manifest" /> + </ItemGroup> +</Project>
\ No newline at end of file diff --git a/Examples/WhisperDesktop/framework.h b/Examples/WhisperDesktop/framework.h new file mode 100644 index 0000000..b52a644 --- /dev/null +++ b/Examples/WhisperDesktop/framework.h @@ -0,0 +1,22 @@ +#pragma once +#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers +#define NOMINMAX +// Windows Header Files +#include "targetver.h" +#include <windows.h> +// ATL header files +#include <atlstr.h> +#include <atlfile.h> +#include <atlbase.h> +#include <atlwin.h> + +// C RunTime Header Files +#include <stdlib.h> +#include <malloc.h> +#include <memory.h> +#include <tchar.h> +#include <assert.h> +// C++ headers +#include <vector> +#include <array> +#include <algorithm>
\ No newline at end of file diff --git a/Examples/WhisperDesktop/stdafx.cpp b/Examples/WhisperDesktop/stdafx.cpp new file mode 100644 index 0000000..1577c4e --- /dev/null +++ b/Examples/WhisperDesktop/stdafx.cpp @@ -0,0 +1 @@ +#include "stdafx.h"
\ No newline at end of file diff --git a/Examples/WhisperDesktop/stdafx.h b/Examples/WhisperDesktop/stdafx.h new file mode 100644 index 0000000..6f46ad0 --- /dev/null +++ b/Examples/WhisperDesktop/stdafx.h @@ -0,0 +1,8 @@ +#pragma once +#include "framework.h" + +#include <whisperWindows.h> + +#include "resource.h" +#include "Utils/WTL/atlapp.h" +#include "Utils/WTL/atlctrls.h"
\ No newline at end of file diff --git a/Examples/WhisperDesktop/sunflower.ico b/Examples/WhisperDesktop/sunflower.ico Binary files differnew file mode 100644 index 0000000..b6404e1 --- /dev/null +++ b/Examples/WhisperDesktop/sunflower.ico diff --git a/Examples/WhisperDesktop/targetver.h b/Examples/WhisperDesktop/targetver.h new file mode 100644 index 0000000..17d4075 --- /dev/null +++ b/Examples/WhisperDesktop/targetver.h @@ -0,0 +1,6 @@ +#pragma once +// Setup Windows SDK to only enable features available since Windows 8.0 +#include <WinSDKVer.h> +#define _WIN32_WINNT _WIN32_WINNT_WIN8 +#define NTDDI_VERSION NTDDI_WIN8 +#include <SDKDDKVer.h> diff --git a/Examples/main/main.cpp b/Examples/main/main.cpp new file mode 100644 index 0000000..c9eacf2 --- /dev/null +++ b/Examples/main/main.cpp @@ -0,0 +1,315 @@ +#include "params.h" +#include "../../Whisper/API/iContext.cl.h" +#include "../../Whisper/API/iMediaFoundation.cl.h" +#include "../../ComLightLib/comLightClient.h" +#include "miscUtils.h" +#include <array> +#include <atomic> +using namespace Whisper; + +#define STREAM_AUDIO 1 + +static HRESULT loadWhisperModel( const wchar_t* path, iModel** pp ) +{ + using namespace Whisper; + constexpr eModelImplementation impl = eModelImplementation::GPU; + // constexpr eModelImplementation impl = eModelImplementation::Reference; + return Whisper::loadModel( path, impl, nullptr, pp ); +} + +namespace +{ + struct sPrintUserData + { + const whisper_params* params; + // const std::vector<std::vector<float>>* pcmf32s; + }; + + // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] + // Lowest is red, middle is yellow, highest is green. 
+ static const std::array<const char*, 10> k_colors = + { + "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m", + "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m", + }; + + std::string to_timestamp( sTimeSpan ts, bool comma = false ) + { + sTimeSpanFields fields = ts; + uint32_t msec = fields.ticks / 10'000; + uint32_t hr = fields.days * 24 + fields.hours; + uint32_t min = fields.minutes; + uint32_t sec = fields.seconds; + + char buf[ 32 ]; + snprintf( buf, sizeof( buf ), "%02d:%02d:%02d%s%03d", hr, min, sec, comma ? "," : ".", msec ); + return std::string( buf ); + } + + static int colorIndex( const sToken& tok ) + { + const float p = tok.probability; + const float p3 = p * p * p; + int col = (int)( p3 * float( k_colors.size() ) ); + col = std::max( 0, std::min( (int)k_colors.size() - 1, col ) ); + return col; + } + + HRESULT __cdecl newSegmentCallback( iContext* context, uint32_t n_new, void* user_data ) noexcept + { + ComLight::CComPtr<iTranscribeResult> results; + CHECK( context->getResults( eResultFlags::Timestamps | eResultFlags::Tokens, &results ) ); + + sTranscribeLength length; + CHECK( results->getSize( length ) ); + + const whisper_params& params = *( (sPrintUserData*)user_data )->params; + // const std::vector<std::vector<float>>& pcmf32s = *( (sPrintUserData*)user_data )->pcmf32s; + + // print the last n_new segments + const uint32_t s0 = length.countSegments - n_new; + if( s0 == 0 ) + printf( "\n" ); + + const sSegment* const segments = results->getSegments(); + const sToken* const tokens = results->getTokens(); + + for( uint32_t i = s0; i < length.countSegments; i++ ) + { + const sSegment& seg = segments[ i ]; + + if( params.no_timestamps ) + { + if( params.print_colors ) + { + for( uint32_t j = 0; j < seg.countTokens; j++ ) + { + const sToken& tok = tokens[ seg.firstToken + j ]; + if( !params.print_special && ( tok.flags & eTokenFlags::Special ) ) + continue; + wprintf( 
L"%S%s%S", k_colors[ colorIndex( tok ) ], utf16( tok.text ).c_str(), "\033[0m" ); + } + } + else + wprintf( L"%s", utf16( seg.text ).c_str() ); + fflush( stdout ); + continue; + } + + std::string speaker = ""; +#if 0 + if( params.diarize && pcmf32s.size() == 2 ) + { + const size_t n_samples = pcmf32s[ 0 ].size(); + const int64_t is0 = SourceAudio::sampleFromTimestamp( seg.time.begin, n_samples ); + const int64_t is1 = SourceAudio::sampleFromTimestamp( seg.time.end, n_samples ); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for( int64_t j = is0; j < is1; j++ ) + { + energy0 += fabs( pcmf32s[ 0 ][ j ] ); + energy1 += fabs( pcmf32s[ 1 ][ j ] ); + } + + if( energy0 > 1.1 * energy1 ) + speaker = "(speaker 0)"; + else if( energy1 > 1.1 * energy0 ) + speaker = "(speaker 1)"; + else + speaker = "(speaker ?)"; + + //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str()); + } +#endif + + if( params.print_colors ) + { + printf( "[%s --> %s] ", to_timestamp( seg.time.begin ).c_str(), to_timestamp( seg.time.end ).c_str() ); + + for( uint32_t j = 0; j < seg.countTokens; j++ ) + { + const sToken& tok = tokens[ seg.firstToken + j ]; + if( !params.print_special && ( tok.flags & eTokenFlags::Special ) ) + continue; + wprintf( L"%S%S%s%S", speaker.c_str(), k_colors[ colorIndex( tok ) ], utf16( tok.text ).c_str(), "\033[0m" ); + } + printf( "\n" ); + } + else + wprintf( L"[%S --> %S] %S%s\n", to_timestamp( seg.time.begin ).c_str(), to_timestamp( seg.time.end ).c_str(), speaker.c_str(), utf16( seg.text ).c_str() ); + } + return S_OK; + } + + HRESULT __cdecl beginSegmentCallback( iContext* context, void* user_data ) noexcept + { + std::atomic_bool* flag = (std::atomic_bool*)user_data; + bool aborted = flag->load(); + return aborted ? 
S_FALSE : S_OK; + } + + HRESULT setupConsoleColors() + { + HANDLE h = GetStdHandle( STD_OUTPUT_HANDLE ); + if( h == INVALID_HANDLE_VALUE ) + return HRESULT_FROM_WIN32( GetLastError() ); + + DWORD mode = 0; + if( !GetConsoleMode( h, &mode ) ) + return HRESULT_FROM_WIN32( GetLastError() ); + if( 0 != ( mode & ENABLE_VIRTUAL_TERMINAL_PROCESSING ) ) + return S_FALSE; + + mode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + if( !SetConsoleMode( h, mode ) ) + return HRESULT_FROM_WIN32( GetLastError() ); + return S_OK; + } +} + +int wmain( int argc, wchar_t* argv[] ) +{ + // Whisper::dbgCompareTraces( LR"(C:\Temp\2remove\Whisper\ref.bin)", LR"(C:\Temp\2remove\Whisper\gpu.bin )" ); return 0; + + // Tell logger to use the standard output stream for the messages + { + Whisper::sLoggerSetup logSetup; + logSetup.flags = eLoggerFlags::UseStandardError; + logSetup.level = eLogLevel::Debug; + Whisper::setupLogger( logSetup ); + } + + whisper_params params; + if( !params.parse( argc, argv ) ) + return 1; + + if( params.print_colors ) + { + if( FAILED( setupConsoleColors() ) ) + params.print_colors = false; + } + + if( params.fname_inp.empty() ) + { + fprintf( stderr, "error: no input files specified\n" ); + whisper_print_usage( argc, argv, params ); + return 2; + } + + if( Whisper::findLanguageKeyA( params.language.c_str() ) == UINT_MAX ) + { + fprintf( stderr, "error: unknown language '%s'\n", params.language.c_str() ); + whisper_print_usage( argc, argv, params ); + return 3; + } + + ComLight::CComPtr<iModel> model; + HRESULT hr = loadWhisperModel( params.model.c_str(), &model ); + if( FAILED( hr ) ) + { + printError( "failed to load the model", hr ); + return 4; + } + + ComLight::CComPtr<iContext> context; + hr = model->createContext( &context ); + if( FAILED( hr ) ) + { + printError( "failed to initialize whisper context", hr ); + return 5; + } + + ComLight::CComPtr<iMediaFoundation> mf; + hr = initMediaFoundation( &mf ); + if( FAILED( hr ) ) + { + printError( "failed to initialize 
Media Foundation runtime", hr ); + return 5; + } + + for( const std::wstring& fname : params.fname_inp ) + { + // print some info about the processing + { + if( model->isMultilingual() == S_FALSE ) + { + if( params.language != "en" || params.translate ) + { + params.language = "en"; + params.translate = false; + fprintf( stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__ ); + } + } + + /* + fwprintf( stderr, L"%S: processing '%s' (%zu samples, %.1f sec), %d threads, %d processors, lang = %S, task = %S, timestamps = %d ...\n", + __func__, fname.c_str(), audio.pcmf32.size(), audio.seconds(), + params.n_threads, params.n_processors, + params.language.c_str(), + params.translate ? "translate" : "transcribe", + params.no_timestamps ? 0 : 1 ); + */ + } + + // run the inference + Whisper::sFullParams wparams; + context->fullDefaultParams( eSamplingStrategy::Greedy, &wparams ); + + wparams.resetFlag( eFullParamsFlags::PrintRealtime | eFullParamsFlags::PrintProgress ); + wparams.setFlag( eFullParamsFlags::PrintTimestamps, !params.no_timestamps ); + wparams.setFlag( eFullParamsFlags::PrintSpecial, params.print_special ); + wparams.setFlag( eFullParamsFlags::Translate, params.translate ); + wparams.language = Whisper::makeLanguageKey( params.language.c_str() ); + wparams.cpuThreads = params.n_threads; + if( params.max_context != UINT_MAX ) + wparams.n_max_text_ctx = params.max_context; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.setFlag( eFullParamsFlags::TokenTimestamps, params.output_wts || params.max_len > 0 ); + wparams.thold_pt = params.word_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 
60 : params.max_len; + + wparams.setFlag( eFullParamsFlags::SpeedupAudio, params.speed_up ); + // sPrintUserData user_data = { ¶ms, &audio.pcmf32s }; + sPrintUserData user_data = { ¶ms }; + + // this callback is called on each new segment + if( !wparams.flag( eFullParamsFlags::PrintRealtime ) ) + { + wparams.new_segment_callback = &newSegmentCallback; + wparams.new_segment_callback_user_data = &user_data; + } + + // example for abort mechanism + // in this example, we do not abort the processing, but we could if the flag is set to true + // the callback is called before every encoder run - if it returns false, the processing is aborted + std::atomic_bool is_aborted = false; + { + wparams.encoder_begin_callback = &beginSegmentCallback; + wparams.encoder_begin_callback_user_data = &is_aborted; + } + +#if STREAM_AUDIO + ComLight::CComPtr<iAudioReader> reader; + CHECK( mf->openAudioFile( fname.c_str(), params.diarize, &reader ) ); + sProgressSink progressSink{ nullptr, nullptr }; + hr = context->runStreamed( wparams, progressSink, reader ); +#else + ComLight::CComPtr<iAudioBuffer> buffer; + CHECK( mf->loadAudioFile( fname.c_str(), params.diarize, &buffer ) ); + hr = context->runFull( wparams, buffer ); +#endif + if( FAILED( hr ) ) + { + fwprintf( stderr, L"%s: failed to process audio\n", argv[ 0 ] ); + return 10; + } + } + + context->timingsPrint(); + context = nullptr; + return 0; +}
\ No newline at end of file diff --git a/Examples/main/main.vcxproj b/Examples/main/main.vcxproj new file mode 100644 index 0000000..4945b88 --- /dev/null +++ b/Examples/main/main.vcxproj @@ -0,0 +1,93 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup Label="ProjectConfigurations"> + <ProjectConfiguration Include="Debug|x64"> + <Configuration>Debug</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|x64"> + <Configuration>Release</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + </ItemGroup> + <PropertyGroup Label="Globals"> + <VCProjectVersion>16.0</VCProjectVersion> + <Keyword>Win32Proj</Keyword> + <ProjectGuid>{4cca7042-eb15-4f7a-b77b-5cafd2df47b2}</ProjectGuid> + <RootNamespace>main</RootNamespace> + <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" /> + <ImportGroup Label="ExtensionSettings"> + </ImportGroup> + <ImportGroup Label="Shared"> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <Import 
Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <PropertyGroup Label="UserMacros" /> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>NOMINMAX;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + <LanguageStandard>stdcpp20</LanguageStandard> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>NOMINMAX;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + <LanguageStandard>stdcpp20</LanguageStandard> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemGroup> + <ClCompile Include="main.cpp" /> + <ClCompile Include="miscUtils.cpp" /> + <ClCompile Include="params.cpp" /> + </ItemGroup> + <ItemGroup> + <ClInclude Include="miscUtils.h" /> + <ClInclude Include="params.h" /> + </ItemGroup> + <ItemGroup> + 
<ProjectReference Include="..\..\Whisper\Whisper.vcxproj"> + <Project>{701df8c8-e4a5-43ec-9c6b-747bbf4d8e71}</Project> + </ProjectReference> + </ItemGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> + <ImportGroup Label="ExtensionTargets"> + </ImportGroup> +</Project>
\ No newline at end of file diff --git a/Examples/main/main.vcxproj.filters b/Examples/main/main.vcxproj.filters new file mode 100644 index 0000000..94cd8a1 --- /dev/null +++ b/Examples/main/main.vcxproj.filters @@ -0,0 +1,12 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup> + <ClCompile Include="main.cpp" /> + <ClCompile Include="params.cpp" /> + <ClCompile Include="miscUtils.cpp" /> + </ItemGroup> + <ItemGroup> + <ClInclude Include="params.h" /> + <ClInclude Include="miscUtils.h" /> + </ItemGroup> +</Project>
\ No newline at end of file diff --git a/Examples/main/miscUtils.cpp b/Examples/main/miscUtils.cpp new file mode 100644 index 0000000..3ebda20 --- /dev/null +++ b/Examples/main/miscUtils.cpp @@ -0,0 +1,48 @@ +#include "miscUtils.h" +#define WIN32_LEAN_AND_MEAN +#include <windows.h> + +std::string utf8( const std::wstring& utf16 ) +{ + int count = WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), (int)utf16.length(), nullptr, 0, nullptr, nullptr ); + std::string str( count, 0 ); + WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), -1, &str[ 0 ], count, nullptr, nullptr ); + return str; +} + +std::wstring utf16( const std::string& u8 ) +{ + int count = MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), nullptr, 0 ); + std::wstring str( count, 0 ); + MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), &str[ 0 ], count ); + return str; +} + +namespace +{ + wchar_t* formatMessage( HRESULT hr ) + { + wchar_t* err; + if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + hr, + MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ), + (LPTSTR)&err, + 0, + NULL ) ) + return err; + return nullptr; + } +} + +void printError( const char* what, HRESULT hr ) +{ + const wchar_t* err = formatMessage( hr ); + if( nullptr != err ) + { + fwprintf( stderr, L"%S: %s\n", what, err ); + LocalFree( (HLOCAL)err ); + } + else + fprintf( stderr, "%s: error code %i (0x%08X)\n", what, hr, hr ); +}
// File: Examples/main/miscUtils.h
#pragma once
#include <string>

// Same underlying type as in Windows SDK; declared here so this header doesn't need <windows.h>
using HRESULT = long;

// Convert a UTF-16 string into UTF-8
std::string utf8( const std::wstring& utf16 );

// Convert a UTF-8 string into UTF-16
std::wstring utf16( const std::string& u8 );

// Print an error message to stderr: the `what` string, followed by the description of the `hr` status code
void printError( const char* what, HRESULT hr );
\ No newline at end of file diff --git a/Examples/main/params.cpp b/Examples/main/params.cpp new file mode 100644 index 0000000..ff1cfdd --- /dev/null +++ b/Examples/main/params.cpp @@ -0,0 +1,101 @@ +#include "params.h" +#include <algorithm> +#include <thread> +#include "miscUtils.h" + +whisper_params::whisper_params() +{ +#ifdef _DEBUG + n_threads = 2; +#else + n_threads = std::min( 4u, std::thread::hardware_concurrency() ); +#endif +} + +namespace +{ + const char* cstr( bool b ) + { + return b ? "true" : "false"; + } +} + +void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params ) +{ + fprintf( stderr, "\n" ); + fprintf( stderr, "usage: %S [options] file0.wav file1.wav ...\n", argv[ 0 ] ); + fprintf( stderr, "\n" ); + fprintf( stderr, "options:\n" ); + fprintf( stderr, " -h, --help [default] show this help message and exit\n" ); + fprintf( stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads ); + fprintf( stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors ); + fprintf( stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms ); + fprintf( stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n ); + fprintf( stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms ); + fprintf( stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context ); + fprintf( stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len ); + fprintf( stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold ); + fprintf( stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", cstr( params.speed_up ) ); + fprintf( stderr, " -tr, --translate [%-7s] translate from source language to english\n", cstr( params.translate ) ); + 
fprintf( stderr, " -di, --diarize [%-7s] stereo audio diarization\n", cstr( params.diarize ) ); + fprintf( stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", cstr( params.output_txt ) ); + fprintf( stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", cstr( params.output_vtt ) ); + fprintf( stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", cstr( params.output_srt ) ); + fprintf( stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", cstr( params.output_wts ) ); + fprintf( stderr, " -ps, --print-special [%-7s] print special tokens\n", cstr( params.print_special ) ); + fprintf( stderr, " -nc, --no-colors [%-7s] do not print colors\n", cstr( !params.print_colors ) ); + fprintf( stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", cstr( params.no_timestamps ) ); + fprintf( stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str() ); + fprintf( stderr, " -m FNAME, --model FNAME [%-7S] model path\n", params.model.c_str() ); + fprintf( stderr, " -f FNAME, --file FNAME [%-7s] path of the input audio file\n", "" ); + fprintf( stderr, "\n" ); +} + +bool whisper_params::parse( int argc, wchar_t* argv[] ) +{ + for( int i = 1; i < argc; i++ ) + { + std::wstring arg = argv[ i ]; + + if( arg[ 0 ] != '-' ) + { + fname_inp.push_back( arg ); + continue; + } + + if( arg == L"-h" || arg == L"--help" ) + { + whisper_print_usage( argc, argv, *this ); + return false; + } + + else if( arg == L"-t" || arg == L"--threads" ) { n_threads = std::stoul( argv[ ++i ] ); } + else if( arg == L"-p" || arg == L"--processors" ) { n_processors = std::stoul( argv[ ++i ] ); } + else if( arg == L"-ot" || arg == L"--offset-t" ) { offset_t_ms = std::stoul( argv[ ++i ] ); } + else if( arg == L"-on" || arg == L"--offset-n" ) { offset_n = std::stoul( argv[ ++i ] ); } + else if( arg == L"-d" || arg == L"--duration" ) { duration_ms = std::stoul( argv[ ++i ] ); } + else if( arg == L"-mc" || 
arg == L"--max-context" ) { max_context = std::stoul( argv[ ++i ] ); } + else if( arg == L"-ml" || arg == L"--max-len" ) { max_len = std::stoul( argv[ ++i ] ); } + else if( arg == L"-wt" || arg == L"--word-thold" ) { word_thold = std::stof( argv[ ++i ] ); } + else if( arg == L"-su" || arg == L"--speed-up" ) { speed_up = true; } + else if( arg == L"-tr" || arg == L"--translate" ) { translate = true; } + else if( arg == L"-di" || arg == L"--diarize" ) { diarize = true; } + else if( arg == L"-otxt" || arg == L"--output-txt" ) { output_txt = true; } + else if( arg == L"-ovtt" || arg == L"--output-vtt" ) { output_vtt = true; } + else if( arg == L"-osrt" || arg == L"--output-srt" ) { output_srt = true; } + else if( arg == L"-owts" || arg == L"--output-words" ) { output_wts = true; } + else if( arg == L"-ps" || arg == L"--print-special" ) { print_special = true; } + else if( arg == L"-nc" || arg == L"--no-colors" ) { print_colors = false; } + else if( arg == L"-nt" || arg == L"--no-timestamps" ) { no_timestamps = true; } + else if( arg == L"-l" || arg == L"--language" ) { language = utf8( argv[ ++i ] ); } + else if( arg == L"-m" || arg == L"--model" ) { model = argv[ ++i ]; } + else if( arg == L"-f" || arg == L"--file" ) { fname_inp.push_back( argv[ ++i ] ); } + else + { + fprintf( stderr, "error: unknown argument: %S\n", arg.c_str() ); + whisper_print_usage( argc, argv, *this ); + return false; + } + } + return true; +}
// File: Examples/main/params.h
#pragma once
#include <climits>	// UINT_MAX
#include <cstdint>	// uint32_t
#include <string>
#include <vector>

// command-line parameters
struct whisper_params
{
	// No in-class default; the constructor sets the value
	uint32_t n_threads;
	uint32_t n_processors = 1;
	uint32_t offset_t_ms = 0;
	uint32_t offset_n = 0;
	uint32_t duration_ms = 0;
	uint32_t max_context = UINT_MAX;
	uint32_t max_len = 0;

	float word_thold = 0.01f;

	bool speed_up = false;
	bool translate = false;
	bool diarize = false;
	bool output_txt = false;
	bool output_vtt = false;
	bool output_srt = false;
	bool output_wts = false;
	bool print_special = false;
	bool print_colors = true;
	bool no_timestamps = false;

	std::string language = "en";
	std::wstring model = L"models/ggml-base.en.bin";
	// Paths of the input files
	std::vector<std::wstring> fname_inp;

	whisper_params();

	// Parse the command-line arguments into this structure; returns false when the process should exit
	bool parse( int argc, wchar_t* argv[] );
};

// Print the usage message to stderr, including the current values from `params`
void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params );
// File: Tools/CompressShaders/Cabinet.cs
using System.ComponentModel;
using System.Runtime.InteropServices;

namespace CompressShaders
{
	/// <summary>Lossless data compressor implemented by <c>Cabinet.dll</c> Windows component</summary>
	/// <remarks>
	/// <para>Whisper.dll consumes that component in runtime, to decompress these shader binaries</para>
	/// <para>If you wonder why not gzip — because the OS doesn’t include an API for that, at least not an API usable from C or C++.<br/>
	/// .NET standard library includes gzip algorithm, but we don't want Whisper.dll to depend on .NET.</para>
	/// </remarks>
	static class Cabinet
	{
		/// <summary>Compression algorithm</summary>
		/// <seealso href="https://learn.microsoft.com/en-us/windows/win32/cmpapi/using-the-compression-api#selecting-the-compression-algorithm" />
		enum eCompressionAlgorithm: uint
		{
			MSZIP = 2,
			XPRESS = 3,
			XPRESS_HUFF = 4,
			LZMS = 5,
		}
		/// <summary>The value should match <c>constexpr DWORD compressionAlgorithm</c> constant,<br/>in <c>Whisper/D3D/shaders.cpp</c> source file</summary>
		const eCompressionAlgorithm algo = eCompressionAlgorithm.MSZIP;

		/// <summary>Win32 error code reported by <c>Compress</c> when the output buffer is too small</summary>
		const int ERROR_INSUFFICIENT_BUFFER = 122;

		[DllImport( "Cabinet.dll", SetLastError = true )]
		static extern bool CreateCompressor( eCompressionAlgorithm Algorithm, IntPtr AllocationRoutines, out IntPtr CompressorHandle );

		[DllImport( "Cabinet.dll", SetLastError = true )]
		static extern bool CloseCompressor( IntPtr CompressorHandle );

		[DllImport( "Cabinet.dll", SetLastError = true )]
		static extern bool Compress( IntPtr CompressorHandle, [In] byte[] UncompressedData, IntPtr UncompressedDataSize, [Out] byte[]? CompressedBuffer, IntPtr CompressedBufferSize, out IntPtr CompressedDataSize );

		/// <summary>Compress an array of bytes into another, smaller array of bytes</summary>
		/// <remarks>In practice, the compression ratio is about 7.1 for the shader binaries in Release configuration.</remarks>
		/// <param name="src">Data to compress, must not be empty</param>
		/// <returns>A new array with the compressed data</returns>
		/// <exception cref="ArgumentException">The source buffer is empty</exception>
		/// <exception cref="Win32Exception">The OS component failed to create the compressor, or to compress the data</exception>
		public static byte[] compressBuffer( byte[] src )
		{
			if( src.Length <= 0 )
				throw new ArgumentException( "The source buffer is empty" );
			IntPtr hCompressor;
			if( !CreateCompressor( algo, IntPtr.Zero, out hCompressor ) )
				throw new Win32Exception( "Unable to create the compressor" );
			try
			{
				IntPtr srcSize = new IntPtr( src.Length );

				// Ask the compressor for the required capacity instead of guessing it.
				// A guess of 2x the input is too small for tiny inputs due to the format overhead, and overflows int32 for inputs over 1GB.
				IntPtr capacity;
				if( !Compress( hCompressor, src, srcSize, null, IntPtr.Zero, out capacity ) )
				{
					// Insufficient buffer is the expected outcome of the size query
					int code = Marshal.GetLastWin32Error();
					if( code != ERROR_INSUFFICIENT_BUFFER )
						throw new Win32Exception( code, "Compress failed" );
				}

				byte[] dest = new byte[ (int)capacity ];
				IntPtr destSize;
				if( !Compress( hCompressor, src, srcSize, dest, capacity, out destSize ) )
					throw new Win32Exception( "Compress failed" );
				// The capacity is an upper bound, trim the array to the actual compressed size
				Array.Resize( ref dest, (int)destSize );
				return dest;
			}
			finally
			{
				CloseCompressor( hCompressor );
			}
		}
	}
}
// File: Tools/CompressShaders/CompressShaders.cs
using System.Runtime.CompilerServices;
namespace CompressShaders;

/// <summary>A compiled shader binary loaded from disk, along with the name derived from the file name</summary>
record struct sShaderBinary
{
	/// <summary>File name without the extension</summary>
	public string name;
	/// <summary>Complete content of the file</summary>
	public byte[] data;

	public sShaderBinary( string path )
	{
		name = Path.GetFileNameWithoutExtension( path );
		data = File.ReadAllBytes( path );
	}

	/// <summary>True when the name ends with "64"</summary>
	public bool wave64 => name.EndsWith( "64" );
	/// <summary>The name with the "64" suffix removed, when present</summary>
	public string uniqueName => wave64 ? name.Substring( 0, name.Length - 2 ) : name;
}

/// <summary>Sorted shader binaries, their distinct names, and per-name indices of the binaries to use</summary>
sealed class FoundShaders
{
	public readonly sShaderBinary[] binaries;
	public readonly string[] names;
	/// <summary>For every entry of <see cref="names" />, index into <see cref="binaries" /></summary>
	public readonly int[] wave32, wave64;

	public FoundShaders( IEnumerable<sShaderBinary> found )
	{
		binaries = found
			.OrderBy( b => b.name )
			.ToArray();

		names = binaries
			.Select( b => b.uniqueName )
			.Distinct()
			.ToArray();

		wave32 = new int[ names.Length ];
		wave64 = new int[ names.Length ];
		for( int i = 0; i < names.Length; i++ )
		{
			int i32 = findIndex( names[ i ], false );
			int i64 = findIndex( names[ i ], true );
			if( i32 >= 0 && i64 >= 0 )
			{
				wave32[ i ] = i32;
				wave64[ i ] = i64;
				continue;
			}
			if( i32 >= 0 )
			{
				// No dedicated Wave64 version, use the same binary for both
				wave32[ i ] = wave64[ i ] = i32;
				continue;
			}
			throw new ApplicationException( $"Wave64 shader {names[ i ]} doesn't have the corresponding Wave32 one" );
		}
	}

	/// <summary>Index of the binary with the specified unique name and wave64 flag, or -1 when not found</summary>
	int findIndex( string name, bool wave64 )
	{
		for( int i = 0; i < binaries.Length; i++ )
		{
			sShaderBinary sb = binaries[ i ];
			if( sb.uniqueName != name )
				continue;
			if( sb.wave64 == wave64 )
				return i;
		}
		return -1;
	}
}

class Program
{
	/// <summary>Root folder of the solution, 3 levels above the path of this source file as recorded by the compiler</summary>
	static string getSolutionRoot( [CallerFilePath] string? path = null )
	{
		string? dir = Path.GetDirectoryName( path );
		dir = Path.GetDirectoryName( dir );
		dir = Path.GetDirectoryName( dir );
		return dir ?? throw new ApplicationException();
	}

#if DEBUG
	const string config = "Debug";
#else
	const string config = "Release";
#endif

	/// <summary>The FP64 bitmap in the generated source is a single <c>uint64_t</c>, one bit per shader binary</summary>
	const int maxShaderBinaries = 64;

	static string shadersBinDir( string root )
	{
		return Path.Combine( root, "ComputeShaders", "x64", config );
	}

	static IEnumerable<sShaderBinary> readShaders( string root )
	{
		string dir = shadersBinDir( root );
		foreach( string path in Directory.EnumerateFiles( dir, "*.cso" ) )
			yield return new sShaderBinary( path );
	}

	/// <summary>Generate <c>shaderNames.h</c> with the <c>eComputeShader</c> enum</summary>
	static void writeHeader( string root, IEnumerable<string> names )
	{
		string path = Path.Combine( root, "Whisper", "D3D", "shaderNames.h" );
		using var stream = File.CreateText( path );
		stream.WriteLine( @"// This header is generated by a tool
#pragma once
#include <stdint.h>

namespace DirectCompute
{
	enum struct eComputeShader: uint16_t
	{" );

		int id = 0;
		foreach( string name in names )
		{
			stream.WriteLine( "\t\t{0} = {1},", name, id );
			id++;
		}
		stream.Write( @"	};

	const char* computeShaderName( eComputeShader cs );
}" );
	}

	/// <summary>Generate <c>shaderNames.cpp</c></summary>
	static void writeCpp( string root, IEnumerable<string> names )
	{
		string path = Path.Combine( root, "Whisper", "D3D", "shaderNames.cpp" );
		ShaderNames.write( path, names );
	}

	/// <summary>Write a C++ array of <c>uint8_t</c> indices, 16 values per line</summary>
	static void writePayloadIDs( StreamWriter stream, string varName, int[] ids )
	{
		stream.Write( @"
static const std::array<uint8_t, {0}> {1} = {{", ids.Length, varName );

		for( int i = 0; i < ids.Length; i++ )
		{
			if( 0 == i % 16 )
				stream.Write( "\r\n\t" );
			else
				stream.Write( ' ' );
			stream.Write( "{0},", ids[ i ] );
		}
		stream.Write( @"
};" );
	}

	/// <summary>Concatenate and compress the binaries, and generate <c>shaderData-{config}.inl</c></summary>
	/// <param name="cbSource">Receives count of bytes before the compression</param>
	/// <param name="cbCompressed">Receives count of bytes after the compression</param>
	static void writePayload( string root, FoundShaders shaders, out int cbSource, out int cbCompressed )
	{
		// C# masks the shift count of 64-bit integers to the low 6 bits, so `1 << 64` would silently set bit 0.
		// Fail before writing anything, instead of generating a wrong bitmap.
		if( shaders.binaries.Length > maxShaderBinaries )
			throw new ApplicationException( $"Found {shaders.binaries.Length} shader binaries, but the FP64 bitmap only has {maxShaderBinaries} bits" );

		MemoryStream ms = new MemoryStream();
		List<int> offsets = new List<int>();
		foreach( var bin in shaders.binaries )
		{
			offsets.Add( (int)ms.Length );
			ms.Write( bin.data );
		}
		// One more entry, for the end of the complete blob
		offsets.Add( (int)ms.Length );

		byte[] dxbc = ms.ToArray();
		byte[] compressed = Cabinet.compressBuffer( dxbc );
		cbSource = dxbc.Length;
		cbCompressed = compressed.Length;

		string path = Path.Combine( root, "Whisper", "D3D", $"shaderData-{config}.inl" );
		using var stream = File.CreateText( path );
		stream.Write( @"// This source file is generated by a tool

// This array contains concatenated and compressed DXBC binaries for all compiled compute shaders
static const std::array<uint8_t, {0}> s_compressedShaders =
{{", compressed.Length );

		for( int i = 0; i < compressed.Length; i++ )
		{
			if( 0 == i % 16 )
				stream.Write( "\r\n\t" );
			else
				stream.Write( ' ' );
			stream.Write( "0x{0:X02},", compressed[ i ] );
		}

		stream.Write( @"
}};

// This array contains start offsets of shader binaries in the decompressed DXBC blob.
// It includes one more entry for the end of the complete decompressed blob.
static const std::array<uint32_t, {0}> s_shaderOffsets = {{", offsets.Count );

		for( int i = 0; i < offsets.Count; i++ )
		{
			if( 0 == i % 16 )
				stream.Write( "\r\n\t" );
			else
				stream.Write( ' ' );
			stream.Write( "{0},", offsets[ i ] );
		}
		stream.Write( @"
};" );

		stream.Write( @"
// Index = eComputeShader enum value, value = index of the shader binary to use on nVidia and Intel GPUs" );
		writePayloadIDs( stream, "s_shaderBlobs32", shaders.wave32 );
		stream.Write( @"
// Index = eComputeShader enum value, value = index of the shader binary to use on AMD GPUs" );
		writePayloadIDs( stream, "s_shaderBlobs64", shaders.wave64 );

		ulong fp64Flags = 0;
		for( int i = 0; i < shaders.binaries.Length; i++ )
		{
			bool fp64 = DetectFp64.usesFp64( shaders.binaries[ i ].data );
			if( fp64 )
				fp64Flags |= (ulong)1 << i;
		}

		stream.Write( @"
// Bitmap of the shader binaries which use FP64 arithmetic instructions
constexpr uint64_t fp64ShadersBitmap = 0x{0:X}ull;", fp64Flags );
	}

	static void mainImpl()
	{
		string root = getSolutionRoot();
		LanguageCodes.produce( root );

		FoundShaders shaders = new FoundShaders( readShaders( root ) );

		writeHeader( root, shaders.names );
		writeCpp( root, shaders.names );
		writePayload( root, shaders, out int cbIn, out int cbOut );
		Console.WriteLine( "Compressed {0} compute shaders, {1:F1} kb -> {2:F1} kb", shaders.binaries.Length, cbIn / 1024.0, cbOut / 1024.0 );
	}

	/// <summary>Entry point; returns 0 on success, or the HRESULT of the exception on failure</summary>
	static int Main( string[] args )
	{
		try
		{
			mainImpl();
			return 0;
		}
		catch( Exception ex )
		{
			Console.WriteLine( ex.Message );
			return ex.HResult;
		}
	}
}
\ No newline at end of file diff --git a/Tools/CompressShaders/CompressShaders.csproj b/Tools/CompressShaders/CompressShaders.csproj new file mode 100644 index 0000000..dee1710 --- /dev/null +++ b/Tools/CompressShaders/CompressShaders.csproj @@ -0,0 +1,10 @@ +<Project Sdk="Microsoft.NET.Sdk"> + <PropertyGroup> + <OutputType>Exe</OutputType> + <TargetFramework>net6.0</TargetFramework> + <ImplicitUsings>enable</ImplicitUsings> + <Nullable>enable</Nullable> + <CheckForOverflowUnderflow>true</CheckForOverflowUnderflow> + <AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath> + </PropertyGroup> +</Project>
// File: Tools/CompressShaders/DetectFp64.cs
#pragma warning disable CS0649
using System.Runtime.InteropServices;

namespace CompressShaders
{
	/// <summary>Utility to find out whether a compiled shader binary declares usage of FP64 arithmetic</summary>
	static class DetectFp64
	{
		struct DXBCHeader
		{
			public uint FourCC;    // Four character code "DXBC"
			public uint Hash0;     // 32-bit hash of the DXBC file
			public uint Hash1;     // 32-bit hash of the DXBC file
			public uint Hash2;     // 32-bit hash of the DXBC file
			public uint Hash3;     // 32-bit hash of the DXBC file
			public uint unknownOne;
			public uint TotalSize; // Total size of the DXBC file in bytes
			public int NumChunks;  // Number of chunks in the DXBC file
		};

		/// <summary>Four character code "DXBC", as a little-endian integer</summary>
		const uint DXBC = 0x43425844;

		/// <summary>Look for the <c>SFI0</c> chunk in the binary, and test the lowest bit of the first 4 bytes of the chunk payload</summary>
		/// <param name="dxbc">Complete content of the compiled shader file</param>
		/// <returns>False when the binary has no <c>SFI0</c> chunk, or when the bit is not set</returns>
		/// <exception cref="ApplicationException">The data is too small, doesn't start with "DXBC", or has an invalid chunk</exception>
		public static bool usesFp64( ReadOnlySpan<byte> dxbc )
		{
			int cbHeader = Marshal.SizeOf<DXBCHeader>();
			if( dxbc.Length < cbHeader )
				throw new ApplicationException( "The shader binary is too small to contain a DXBC header" );

			DXBCHeader dxbcHeader = MemoryMarshal.Read<DXBCHeader>( dxbc );
			if( dxbcHeader.FourCC != DXBC )
				throw new ApplicationException( "The shader binary doesn't start with the DXBC magic number" );

			int nChunks = dxbcHeader.NumChunks;
			if( nChunks < 0 )
				throw new ApplicationException( "The DXBC header has negative count of chunks" );

			// The header is followed by the table with 4-byte offsets of the chunks
			ReadOnlySpan<int> chunkOffsets = MemoryMarshal.Cast<byte, int>( dxbc.Slice( cbHeader, nChunks * 4 ) );
			foreach( int off in chunkOffsets )
			{
				// Every chunk starts with a four character code, followed by 4-byte size of the payload
				uint id = MemoryMarshal.Read<uint>( dxbc.Slice( off, 4 ) );
				const uint SFI0 = 0x30494653;
				if( id != SFI0 )
					continue;
				int size = MemoryMarshal.Read<int>( dxbc.Slice( off + 4, 4 ) );
				if( size < 4 )
					throw new ApplicationException();
				uint data = MemoryMarshal.Read<uint>( dxbc.Slice( off + 8, 4 ) );
				return 0 != ( data & 1u );
			}
			return false;
		}
	}
}
// File: Tools/CompressShaders/LanguageCodes.cs
using System.Globalization;
using System.Text.RegularExpressions;

namespace CompressShaders
{
	/// <summary>Generates C++ and C# sources from the tab-separated table of languages</summary>
	static class LanguageCodes
	{
		/// <summary>A row of the source table</summary>
		record struct Row
		{
			/// <summary>First field of the row, up to 4 ASCII characters</summary>
			public string keySource;
			/// <summary>The same key packed into an integer, see <see cref="makeKey" /></summary>
			public uint keyValue;
			/// <summary>Second field of the row</summary>
			public int code;
			/// <summary>Third field of the row</summary>
			public string name;
		}

		/// <summary>Pack up to 4 ASCII characters into a 32-bit integer, first character in the lowest byte</summary>
		static uint makeKey( string str )
		{
			if( str.Length > 4 )
				throw new ArgumentException();
			uint k = 0;
			int shift = 0;
			foreach( char c in str )
			{
				if( c >= 0x80 )
					throw new ArgumentException();
				uint u = (uint)c;
				k |= ( u << shift );
				shift += 8;
			}
			return k;
		}

		/// <summary>Parse the table; blank lines are skipped</summary>
		static IEnumerable<Row> load( string path )
		{
			using var stm = File.OpenText( path );
			int lineNumber = 0;
			while( true )
			{
				string? line = stm.ReadLine();
				if( null == line )
					break;
				lineNumber++;
				if( string.IsNullOrWhiteSpace( line ) )
					continue;
				string[] fields = line.Split( '\t' );
				if( fields.Length < 3 )
					throw new ApplicationException( $"Line {lineNumber} of \"{path}\" has {fields.Length} fields, expected 3" );
				yield return new Row()
				{
					keySource = fields[ 0 ],
					keyValue = makeKey( fields[ 0 ] ),
					// The table is a source file in the repository, the parsing should not depend on the locale of the build machine
					code = int.Parse( fields[ 1 ], NumberStyles.Integer, CultureInfo.InvariantCulture ),
					name = fields[ 2 ]
				};
			}
		}

		/// <summary>Generate the C++ include file, one <c>Lang{ }</c> initializer per row</summary>
		static void writeCpp( string inl, Row[] data )
		{
			// TODO [very low]: sort them by the key here, then in C++ use binary search instead of the hash map
			using var stm = File.CreateText( inl );
			stm.WriteLine( "// This file is generated by a tool, from the `languageCodez.tsv` file in this repository" );
			foreach( Row row in data )
				stm.WriteLine( "Lang{{ 0x{0:X}, {1}, \"{2}\" }},", row.keyValue, row.code, row.name );
		}

		static readonly CultureInfo ci = new CultureInfo( "en-US", false );
		static string titleCase( this string name ) =>
			ci.TextInfo.ToTitleCase( name.ToLower( ci ) );

		/// <summary>Generate the C# source file with the <c>eLanguage</c> enum</summary>
		static void writeCs( string cs, Row[] data )
		{
			using var stm = File.CreateText( cs );
			stm.WriteLine( @"// This file is generated by a tool, from the `languageCodez.tsv` file in this repository
namespace Whisper
{
	/// <summary>Supported languages</summary>
	public enum eLanguage: uint
	{" );

			foreach( Row row in data )
			{
				string tc = row.name.titleCase();
				stm.WriteLine( "		/// <summary>{0}</summary>", tc );
				// The enum member is the title-cased name without whitespace
				tc = Regex.Replace( tc, @"\s+", string.Empty );
				stm.WriteLine( "		{0} = 0x{1:X},", tc, row.keyValue );
			}
			stm.Write( @"	}
}" );
		}

		static void produce( string tsv, string inl, string cs )
		{
			Row[] data = load( tsv ).OrderBy( r => r.name ).ToArray();
			writeCpp( inl, data );
			writeCs( cs, data );
		}

		/// <summary>Read the table and generate both source files, at the paths relative to the solution root</summary>
		public static void produce( string solutionRoot )
		{
			string tsv = Path.Combine( solutionRoot, "Whisper\\Whisper\\languageCodez.tsv" );
			string inl = Path.Combine( solutionRoot, "Whisper\\Whisper\\languageCodez.inl" );
			string cs = Path.Combine( solutionRoot, "WhisperNet\\API\\eLanguage.cs" );
			produce( tsv, inl, cs );
		}
	}
}
\ No newline at end of file diff --git a/Tools/CompressShaders/Readme.txt b/Tools/CompressShaders/Readme.txt new file mode 100644 index 0000000..69ef35a --- /dev/null +++ b/Tools/CompressShaders/Readme.txt @@ -0,0 +1,10 @@ +This project builds a C# console app which serves as a code generator for a few pieces of Whisper.dll and WhisperNet.dll. + +Specifically, it generates two things. + +1. It compresses the compiled DXBC shaders into a blob of bytes, and prints std::array with these bytes into shaderData-Release.inl and shaderData-Debug.inl C++ files. + +2. It parses the `languageCodez.tsv`, and generates both C++ and C# code with the data from that table. + +The tool uses relative paths across source files. +These paths will break if you move the source of the tool, or the source data of the tool.
// File: Tools/CompressShaders/ShaderNames.cs
/// <summary>Generator of the C++ source file with the names of the compute shaders</summary>
static class ShaderNames
{
	/// <summary>Write a C++ source with the <c>s_shaderNames</c> array, and the <c>computeShaderName</c> function which indexes that array</summary>
	/// <param name="path">Path of the file to create or overwrite</param>
	/// <param name="names">Shader names, in the order of the array entries</param>
	public static void write( string path, IEnumerable<string> names )
	{
		// Materialize the sequence first: the array length is written before the entries
		string[] shaderNames = names.ToArray();
		using StreamWriter writer = File.CreateText( path );

		const string prologue = @"// This source file is generated by a tool
#include ""stdafx.h""
#include ""shaderNames.h""
";
		writer.WriteLine( prologue );

		writer.WriteLine( "static const std::array<const char*, {0}> s_shaderNames = ", shaderNames.Length );
		writer.WriteLine( "{" );
		for( int i = 0; i < shaderNames.Length; i++ )
			writer.WriteLine( @"	""{0}"",", shaderNames[ i ] );

		const string epilogue = @"};

const char* DirectCompute::computeShaderName( eComputeShader cs )
{
	const uint16_t i = (uint16_t)cs;
	if( i < s_shaderNames.size() )
		return s_shaderNames[ i ];
	return nullptr;
}";
		writer.Write( epilogue );
	}
}
\ No newline at end of file diff --git a/Tools/compareTraces/CommandLineArgs.cpp b/Tools/compareTraces/CommandLineArgs.cpp new file mode 100644 index 0000000..5f26bdb --- /dev/null +++ b/Tools/compareTraces/CommandLineArgs.cpp @@ -0,0 +1,51 @@ +#include "stdafx.h" +#include "CommandLineArgs.h" +#include <charconv> + +static bool printUsage() +{ + fprintf( stderr, "Usage: compareTraces.exe trace1.bin trace2.bin [-diff N]\n" ); + return false; +} + +bool CommandLineArgs::parse( int argc, wchar_t* argv[] ) +{ + size_t idx = 0; + CString sw; + CStringA tmp; + for( int i = 1; i < argc; i++ ) + { + if( argv[ i ][ 0 ] != L'-' ) + { + if( idx >= 2 ) + return printUsage(); + inputs[ idx ] = argv[ i ]; + idx++; + continue; + } + sw = argv[ i ]; + if( 0 == sw.CompareNoCase( L"-diff" ) ) + { + i++; + if( i >= argc ) + return printUsage(); + tmp.Format( "%S", argv[ i ] ); + tmp.Trim(); + uint64_t v; + auto res = std::from_chars( tmp, cstr( tmp ) + tmp.GetLength(), v ); + if( res.ec != (std::errc)0 ) + { + fprintf( stderr, "Unable to parse string into number\n" ); + return false; + } + printDiff = v; + continue; + } + return printUsage(); + } + + if( idx != 2 ) + return printUsage(); + + return true; +}
// File: Tools/compareTraces/CommandLineArgs.h
#pragma once

// Parsed command-line arguments of the tool
struct CommandLineArgs
{
	// Value of the optional `-diff N` switch; stays -1 when the switch is not specified
	int64_t printDiff = -1;
	// The two positional arguments, in the order of the command line
	std::array<CString, 2> inputs;

	// Parse the arguments; returns false when they are invalid
	bool parse( int argc, wchar_t* argv[] );
};
\ No newline at end of file diff --git a/Tools/compareTraces/Readme.txt b/Tools/compareTraces/Readme.txt new file mode 100644 index 0000000..035d658 --- /dev/null +++ b/Tools/compareTraces/Readme.txt @@ -0,0 +1,9 @@ +This project builds a C++ console tool which compares debug traces of the model. + +Tracing files easily exceed 1GB, and by default they’re disabled with a preprocessor macro in stdafx.h of the Whisper project. + +When enabled, the main GPU implementation saves a trace into C:\Temp\2remove\Whisper\gpu.bin + +The reference CPU implementation saves a trace into C:\Temp\2remove\Whisper\ref.bin + +The code in this project is optimized for development speed. For this reason it requires an AVX2 CPU, uses memory-mapped IO instead of proper parsing, and performs little to no error checking.
\ No newline at end of file diff --git a/Tools/compareTraces/TraceReader.cpp b/Tools/compareTraces/TraceReader.cpp new file mode 100644 index 0000000..b4b9681 --- /dev/null +++ b/Tools/compareTraces/TraceReader.cpp @@ -0,0 +1,46 @@ +#include "stdafx.h" +#include "TraceReader.h" +using namespace Tracing; + +const sTraceItem& TraceReader::operator[]( size_t idx ) const +{ + if( idx >= countItems ) + throw E_BOUNDS; + return items[ idx ]; +} + +CStringA TraceReader::getName( const sTraceItem& item ) const +{ + const size_t idx = item.stringIndex; + if( idx >= countStrings ) + throw E_BOUNDS; + const char* const source = stringData + stringIndex[ idx ]; + CStringA res; + res.Format( source, item.formatArgs[ 0 ], item.formatArgs[ 1 ], item.formatArgs[ 2 ], item.formatArgs[ 3 ] ); + return res; +} + +HRESULT TraceReader::open( LPCTSTR path ) +{ + CHECK( file.Create( path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING ) ); + CHECK( mapping.MapFile( file ) ); + + const uint8_t* rsi = mapping; + const sFileHeader& header = *(const sFileHeader*)rsi; + if( header.magic != header.correctMagic ) + return E_INVALIDARG; + countItems = header.countItems; + countStrings = header.countStrings; + + rsi += sizeof( sFileHeader ); + payloadPointer = rsi; + + rsi += header.bytesPayload; + stringIndex = (const uint32_t*)( rsi ); + stringData = (const char*)( rsi + countStrings * 4 ); + + rsi += header.bytesStrings; + items = (const sTraceItem*)rsi; + + return S_OK; +}
// File: Tools/compareTraces/TraceReader.h
#pragma once
#include "../../Whisper/Utils/Trace/TraceStructures.h"
#include <atlstr.h>
#include <atlfile.h>

namespace Tracing
{
	// Reader of the trace files, implemented on top of a memory-mapped file.
	// The pointers below are set by open(), and point into the mapping owned by this object.
	class TraceReader
	{
		// Start of the payload section of the file
		const uint8_t* payloadPointer = nullptr;
		// Array of the items, countItems long
		const sTraceItem* items = nullptr;
		size_t countItems = 0;
		size_t countStrings = 0;
		// For every string, offset of the first character relative to stringData
		const uint32_t* stringIndex = nullptr;
		const char* stringData = nullptr;

		CAtlFile file;
		CAtlFileMapping<uint8_t> mapping;

	public:

		TraceReader() = default;
		~TraceReader() = default;

		// Open and map the file; fails with E_INVALIDARG when the magic number in the header is wrong
		HRESULT open( LPCTSTR path );
		// Count of items in the trace
		size_t size() const { return countItems; }
		// Get an item by index; throws E_BOUNDS when the index is out of range
		const sTraceItem& operator[]( size_t idx ) const;
		// Format the name of the item; throws E_BOUNDS when the string index of the item is out of range
		CStringA getName( const sTraceItem& item ) const;

		// Pointer to the payload data of the item. NOTE: the offset is not validated.
		const void* payload( const sTraceItem& item ) const
		{
			return payloadPointer + item.payloadOffset;
		}
	};
}
\ No newline at end of file diff --git a/Tools/compareTraces/compare.cpp b/Tools/compareTraces/compare.cpp new file mode 100644 index 0000000..ec2a6ef --- /dev/null +++ b/Tools/compareTraces/compare.cpp @@ -0,0 +1,364 @@ +#include "stdafx.h" +#include "../../Whisper/API/iContext.cl.h" +#include "TraceReader.h" +#include "../../Whisper/ML/testUtils.h" +#include "compare.h" +using namespace Tracing; +using namespace DirectCompute; + +namespace +{ + inline const char* cstr( eItemType it ) + { + switch( it ) + { + case eItemType::Buffer: return "Buffer"; + case eItemType::Tensor: return "Tensor"; + } + throw E_INVALIDARG; + } + inline const char* cstr( const CStringA& s ) { return s; } + + inline int tensorDims( __m128i vec ) + { + const __m128i one = _mm_set1_epi32( 1 ); + const uint32_t bitmapOnes = (uint32_t)_mm_movemask_ps( _mm_castsi128_ps( _mm_cmpeq_epi32( vec, one ) ) ); + const uint32_t bitmapNotOnes = bitmapOnes ^ 0b1111u; + unsigned long idx; + if( !_BitScanReverse( &idx, bitmapNotOnes ) ) + return 0; + return idx + 1; + } + + int printSize( __m128i vec ) + { + const int sz = tensorDims( vec ); + switch( sz ) + { + case 0: + printf( "[ scalar ]" ); + break; + case 1: + printf( "[ %i ]", _mm_cvtsi128_si32( vec ) ); + break; + case 2: + printf( "[ %i, %i ]", _mm_cvtsi128_si32( vec ), _mm_extract_epi32( vec, 1 ) ); + break; + case 3: + printf( "[ %i, %i, %i ]", _mm_cvtsi128_si32( vec ), _mm_extract_epi32( vec, 1 ), _mm_extract_epi32( vec, 2 ) ); + break; + case 4: + printf( "[ %i, %i, %i, %i ]", _mm_cvtsi128_si32( vec ), _mm_extract_epi32( vec, 1 ), _mm_extract_epi32( vec, 2 ), _mm_extract_epi32( vec, 3 ) ); + break; + default: + throw E_UNEXPECTED; + } + return sz; + } + + class Comparer + { + TraceReader& readerA; + TraceReader& readerB; + + bool diffBuffers( size_t i, const sTraceItem& a, const sTraceItem& b, const CStringA& name ) + { + const size_t lenA = *(const uint64_t*)a.size.data(); + const size_t lenB = *(const uint64_t*)b.size.data(); + if( lenA != 
lenB ) + { + printf( "Buffer %zu \"%s\": different size, %zu in trace A, %zu in trace B\n", i, cstr( name ), lenA, lenB ); + return false; + } + if( a.dataType != b.dataType ) + { + printf( "Buffer %zu \"%s\": different data types\n", i, cstr( name ) ); + return false; + } + + switch( a.dataType ) + { + case eDataType::FP32: + return buffersFp32( i, name, (const float*)readerA.payload( a ), (const float*)readerB.payload( b ), lenA ); + } + throw E_NOTIMPL; + } + + bool diffTensors( size_t i, const sTraceItem& a, const sTraceItem& b, const CStringA& name ) + { + const __m128i ne1 = load( a.size ); + const __m128i ne2 = load( b.size ); + if( !vectorEqual( ne1, ne2 ) ) + { + printf( "Tensor %zu \"%s\" - different size: trace A size is ", i, cstr( name ) ); + printSize( ne1 ); + printf( ", trace B size is " ); + printSize( ne2 ); + printf( "\n" ); + return false; + } + + const __m128i stride1 = load( a.stride ); + const __m128i stride2 = load( b.stride ); + if( !vectorEqual( stride1, stride2 ) ) + { + printf( "Tensor %zu \"%s\" - different memory layout\n", i, cstr( name ) ); + return false; + } + + if( a.dataType != b.dataType ) + { + printf( "Tensor %zu \"%s\": different data types\n", i, cstr( name ) ); + return false; + } + + size_t elements = (uint32_t)_mm_cvtsi128_si32( ne1 ); + elements *= (uint32_t)_mm_extract_epi32( ne1, 1 ); + elements *= (uint32_t)_mm_extract_epi32( ne1, 2 ); + elements *= (uint32_t)_mm_extract_epi32( ne1, 3 ); + + switch( a.dataType ) + { + case eDataType::FP32: + return tensorsFp32( i, name, (const float*)readerA.payload( a ), (const float*)readerB.payload( b ), elements, ne1, stride1 ); + } + throw E_NOTIMPL; + } + + protected: + virtual bool buffersFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length ) = 0; + virtual bool tensorsFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length, __m128i ne, __m128i nb ) = 0; + + public: + + Comparer( TraceReader& t1, TraceReader& t2 ) : 
+ readerA( t1 ), readerB( t2 ) { } + + bool compare( size_t i ) + { + const sTraceItem& a = readerA[ i ]; + const sTraceItem& b = readerB[ i ]; + CStringA name1 = readerA.getName( a ); + CStringA name2 = readerB.getName( b ); + + if( a.itemType != b.itemType ) + { + printf( "Item %zu: different type, trace A %s \"%s\", trace B %s \"%s\"\n", i, + cstr( a.itemType ), cstr( name1 ), cstr( b.itemType ), cstr( name2 ) ); + return false; + } + + if( name1 != name2 ) + { + printf( "%s %zu: different names, they are \"%s\" and \"%s\"\n", cstr( a.itemType ), i, cstr( name1 ), cstr( name2 ) ); + return false; + } + + switch( a.itemType ) + { + case eItemType::Buffer: + return diffBuffers( i, a, b, name1 ); + case eItemType::Tensor: + return diffTensors( i, a, b, name1 ); + default: + throw E_INVALIDARG; + } + } + }; + + class PrintSummary : public Comparer + { + bool buffersFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length ) override; + bool tensorsFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length, __m128i ne, __m128i nb ) override; + + public: + PrintSummary( TraceReader& a, TraceReader& b ) : Comparer( a, b ) { } + }; + + bool PrintSummary::buffersFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length ) + { + sTensorDiff diff = computeDiff( a, b, length ); + printf( "%s %zu \"%s\": ", cstr( eItemType::Buffer ), idx, cstr( name ) ); + diff.print(); + return true; + } + + bool PrintSummary::tensorsFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length, __m128i ne, __m128i nb ) + { + printSize( ne ); + printf( " " ); + sTensorDiff diff = computeDiff( a, b, length ); + printf( "%s %zu \"%s\": ", cstr( eItemType::Tensor ), idx, cstr( name ) ); + diff.print(); + return true; + } + + class PrintDiff : public Comparer + { + bool buffersFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length ) override; + bool 
tensorsFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length, __m128i ne, __m128i nb ) override; + public: + PrintDiff( TraceReader& a, TraceReader& b ) : Comparer( a, b ) { } + }; + + bool PrintDiff::buffersFp32( size_t idx, const CStringA& name, const float* A, const float* B, size_t length ) + { + printf( "idx\tA\tB\tA(hex)\tB(hex)\tdiff\n" ); + for( size_t i = 0; i < length; i++ ) + { + const float a = *A; + const float b = *B; + __m128 vf = _mm_setr_ps( a, b, 0, 0 ); + __m128i vi = _mm_castps_si128( vf ); + const float diff = std::abs( a - b ); + printf( "%zu\t%g\t%g\t0x%08X\t0x%08X\t%g\n", + i, a, b, _mm_cvtsi128_si32( vi ), _mm_extract_epi32( vi, 1 ), diff ); + } + return true; + } + + std::array<uint32_t, 4> storeSize( __m128i v ) + { + std::array<uint32_t, 4> a; + _mm_storeu_si128( ( __m128i* )a.data(), v ); + return a; + } + + std::array<size_t, 4> storeStrides( __m128i v ) + { + const __m128i zero = _mm_setzero_si128(); + std::array<size_t, 4> a; + _mm_storeu_si128( ( __m128i* ) & a[ 0 ], _mm_unpacklo_epi32( v, zero ) ); + _mm_storeu_si128( ( __m128i* ) & a[ 2 ], _mm_unpackhi_epi32( v, zero ) ); + return a; + } + + bool PrintDiff::tensorsFp32( size_t idx, const CStringA& name, const float* A, const float* B, size_t length, __m128i ne, __m128i nb ) + { + const int dims = tensorDims( ne ); + const std::array<uint32_t, 4> size = storeSize( ne ); + const std::array<size_t, 4> strides = storeStrides( ne ); + CStringA line; + if( dims > 4 ) + throw E_UNEXPECTED; + + for( int i = 0; i < dims; i++ ) + { + const char c = "xyzw"[ i ]; + line.AppendChar( c ); + line.AppendChar( '\t' ); + } + line += "A\tB\tA(hex)\tB(hex)\tdiff\n"; + printf( "%s", cstr( line ) ); + + if( 0 == dims ) + { + const float a = *A; + const float b = *B; + __m128 vf = _mm_setr_ps( a, b, 0, 0 ); + __m128i vi = _mm_castps_si128( vf ); + const float diff = std::abs( a - b ); + printf( "%g\t%g\t0x%08X\t0x%08X\t%g\n", + a, b, _mm_cvtsi128_si32( vi ), 
_mm_extract_epi32( vi, 1 ), diff ); + return true; + } + + size_t offLayer2 = 0; + for( uint32_t w = 0; w < size[ 3 ]; w++, offLayer2 += strides[ 3 ] ) + { + size_t offLayer = offLayer2; + for( uint32_t z = 0; z < size[ 2 ]; z++, offLayer += strides[ 2 ] ) + { + size_t offRow = offLayer; + for( uint32_t y = 0; y < size[ 1 ]; y++, offRow += strides[ 1 ] ) + { + size_t off = offRow; + for( uint32_t x = 0; x < size[ 0 ]; x++, off += strides[ 0 ] ) + { + line.Format( "%i\t", x ); + if( dims > 1 ) + line.AppendFormat( "%i\t", y ); + if( dims > 2 ) + line.AppendFormat( "%i\t", z ); + if( dims > 3 ) + line.AppendFormat( "%i\t", w ); + + const float a = A[ off ]; + const float b = B[ off ]; + __m128 vf = _mm_setr_ps( a, b, 0, 0 ); + __m128i vi = _mm_castps_si128( vf ); + const float diff = std::abs( a - b ); + line.AppendFormat( "%g\t%g\t0x%08X\t0x%08X\t%g\n", + a, b, _mm_cvtsi128_si32( vi ), _mm_extract_epi32( vi, 1 ), diff ); + printf( "%s", cstr( line ) ); + } + } + } + } + return true; + } +} + +HRESULT compareTraces( const CommandLineArgs& arguments ) +{ + const wchar_t* pathA = arguments.inputs[ 0 ]; + const wchar_t* pathB = arguments.inputs[ 1 ]; + + TraceReader a, b; + HRESULT hr = a.open( pathA ); + if( FAILED( hr ) ) + { + fwprintf( stderr, L"Unable to load trace A from \"%s\"", pathA ); + printError( hr ); + return hr; + } + + hr = b.open( pathB ); + if( FAILED( hr ) ) + { + fwprintf( stderr, L"Unable to load trace B from \"%s\"", pathA ); + printError( hr ); + return hr; + } + + wprintf( L"Trace A: %s\n", pathA ); + wprintf( L"Trace B: %s\n", pathB ); + const size_t sizeA = a.size(); + const size_t sizeB = b.size(); + const size_t count = std::min( sizeA, sizeB ); + + if( arguments.printDiff >= 0 ) + { + if( arguments.printDiff >= (int64_t)count ) + { + fprintf( stderr, "Trace A has %zu entries, trace B %zu entries; entry %zu ain't there\n", + sizeA, sizeB, (size_t)arguments.printDiff ); + return E_INVALIDARG; + } + try + { + PrintDiff print{ a, b }; + 
print.compare( arguments.printDiff ); + return S_OK; + } + catch( HRESULT hr ) + { + return hr; + } + } + + printf( "Trace A has %zu entries, trace B %zu entries, comparing first %zu\n", sizeA, sizeB, count ); + + try + { + PrintSummary print{ a, b }; + for( size_t i = 0; i < count; i++ ) + if( !print.compare( i ) ) + return S_FALSE; + return S_OK; + } + catch( HRESULT hr ) + { + return hr; + } +}
// File: Tools/compareTraces/compare.h
#pragma once
#include "CommandLineArgs.h"

// Load and compare the two traces specified in the arguments, printing the results to the console
HRESULT compareTraces( const CommandLineArgs& arguments );
\ No newline at end of file diff --git a/Tools/compareTraces/compareTraces.cpp b/Tools/compareTraces/compareTraces.cpp new file mode 100644 index 0000000..8813500 --- /dev/null +++ b/Tools/compareTraces/compareTraces.cpp @@ -0,0 +1,16 @@ +#include "stdafx.h" +#include <stdio.h> +#include "compare.h" +#include "CommandLineArgs.h" + +int wmain( int argc, wchar_t* argv[] ) +{ + CommandLineArgs cla; + if( !cla.parse( argc, argv ) ) + return 1; + + HRESULT hr = compareTraces( cla ); + if( SUCCEEDED( hr ) ) + return 0; + return hr; +}
<?xml version="1.0" encoding="utf-8"?>
<!-- Tools/compareTraces/compareTraces.vcxproj -->
<!-- MSBuild project of the compareTraces console tool; only x64 Debug and Release configurations are defined. -->
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <ItemGroup Label="ProjectConfigurations">
    <ProjectConfiguration Include="Debug|x64">
      <Configuration>Debug</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Release|x64">
      <Configuration>Release</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
  </ItemGroup>
  <PropertyGroup Label="Globals">
    <VCProjectVersion>16.0</VCProjectVersion>
    <Keyword>Win32Proj</Keyword>
    <ProjectGuid>{8478a77c-d851-4c63-9511-1770cc82d33e}</ProjectGuid>
    <RootNamespace>compareTraces</RootNamespace>
    <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
  <!-- Both configurations build a Unicode console application with the v143 toolset -->
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
    <ConfigurationType>Application</ConfigurationType>
    <UseDebugLibraries>true</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
    <ConfigurationType>Application</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
  <ImportGroup Label="ExtensionSettings">
  </ImportGroup>
  <ImportGroup Label="Shared">
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <PropertyGroup Label="UserMacros" />
  <!-- Compiler and linker settings, Debug.
       NOTE(review): the instruction set is AVX, yet testUtils.cpp in this project calls _mm256_fmadd_ps,
       _mm256_cvtph_ps and _mm256_cmpeq_epi32; presumably the tool is only run on CPUs with FMA3, F16C and AVX2. Confirm. -->
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
      <LanguageStandard>stdcpp20</LanguageStandard>
      <PrecompiledHeader>Use</PrecompiledHeader>
      <EnableEnhancedInstructionSet>AdvancedVectorExtensions</EnableEnhancedInstructionSet>
    </ClCompile>
    <Link>
      <SubSystem>Console</SubSystem>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
  </ItemDefinitionGroup>
  <!-- Compiler and linker settings, Release: same language and instruction set options as Debug,
       plus function-level linking, intrinsics, COMDAT folding and reference elimination. -->
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
      <LanguageStandard>stdcpp20</LanguageStandard>
      <PrecompiledHeader>Use</PrecompiledHeader>
      <EnableEnhancedInstructionSet>AdvancedVectorExtensions</EnableEnhancedInstructionSet>
    </ClCompile>
    <Link>
      <SubSystem>Console</SubSystem>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
  </ItemDefinitionGroup>
  <!-- Translation units; stdafx.cpp is the one that creates the precompiled header consumed by the others -->
  <ItemGroup>
    <ClCompile Include="CommandLineArgs.cpp" />
    <ClCompile Include="compareTraces.cpp" />
    <ClCompile Include="compare.cpp" />
    <ClCompile Include="stdafx.cpp">
      <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
      <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
    </ClCompile>
    <ClCompile Include="testUtils.cpp" />
    <ClCompile Include="TraceReader.cpp" />
  </ItemGroup>
  <ItemGroup>
    <ClInclude Include="CommandLineArgs.h" />
    <ClInclude Include="compare.h" />
    <ClInclude Include="stdafx.h" />
    <ClInclude Include="TraceReader.h" />
  </ItemGroup>
  <ItemGroup>
    <Text Include="Readme.txt" />
  </ItemGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
  <ImportGroup Label="ExtensionTargets">
  </ImportGroup>
</Project>
<?xml version="1.0" encoding="utf-8"?>
<!-- Tools/compareTraces/compareTraces.vcxproj.filters -->
<!-- Item list of the compareTraces project; no Filter elements are defined here, so no item is assigned to a filter. -->
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <ItemGroup>
    <ClCompile Include="compareTraces.cpp" />
    <ClCompile Include="TraceReader.cpp" />
    <ClCompile Include="compare.cpp" />
    <ClCompile Include="stdafx.cpp" />
    <ClCompile Include="testUtils.cpp" />
    <ClCompile Include="CommandLineArgs.cpp" />
  </ItemGroup>
  <ItemGroup>
    <ClInclude Include="TraceReader.h" />
    <ClInclude Include="stdafx.h" />
    <ClInclude Include="compare.h" />
    <ClInclude Include="CommandLineArgs.h" />
  </ItemGroup>
  <ItemGroup>
    <Text Include="Readme.txt" />
  </ItemGroup>
</Project>
\ No newline at end of file diff --git a/Tools/compareTraces/stdafx.cpp b/Tools/compareTraces/stdafx.cpp new file mode 100644 index 0000000..5c7f6c9 --- /dev/null +++ b/Tools/compareTraces/stdafx.cpp @@ -0,0 +1,30 @@ +#include "stdafx.h" + +namespace +{ + wchar_t* formatMessage( HRESULT hr ) + { + wchar_t* err; + if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + hr, + MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ), + (LPTSTR)&err, + 0, + NULL ) ) + return err; + return nullptr; + } +} + +void printError( HRESULT hr ) +{ + const wchar_t* err = formatMessage( hr ); + if( nullptr != err ) + { + fwprintf( stderr, L"%s\n", err ); + LocalFree( (HLOCAL)err ); + } + else + fprintf( stderr, "Error code %i (0x%08X)\n", hr, hr ); +}
\ No newline at end of file diff --git a/Tools/compareTraces/stdafx.h b/Tools/compareTraces/stdafx.h new file mode 100644 index 0000000..1e496f3 --- /dev/null +++ b/Tools/compareTraces/stdafx.h @@ -0,0 +1,40 @@ +#pragma once +#include <stdint.h> +#include <assert.h> + +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include <windows.h> +#include <atlstr.h> +#include <d3d11.h> + +#include <vector> +#include <array> +#include <emmintrin.h> +#include <smmintrin.h> + +#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } + +inline __m128i load16( const int* rsi ) +{ + return _mm_loadu_si128( ( const __m128i* )rsi ); +} +inline __m128i load16( const uint32_t* rsi ) +{ + return _mm_loadu_si128( ( const __m128i* )rsi ); +} +inline __m128i load( const std::array<uint32_t, 4>& arr ) +{ + return load16( arr.data() ); +} + +inline bool vectorEqual( __m128i a, __m128i b ) +{ + __m128i xx = _mm_xor_si128( a, b ); + return (bool)_mm_testz_si128( xx, xx ); +} + +void printError( HRESULT hr ); + +inline const char* cstr( const CStringA& s ) { return s; } +inline const wchar_t* cstr( const CString& s ) { return s; }
\ No newline at end of file diff --git a/Tools/compareTraces/testUtils.cpp b/Tools/compareTraces/testUtils.cpp new file mode 100644 index 0000000..f9fa465 --- /dev/null +++ b/Tools/compareTraces/testUtils.cpp @@ -0,0 +1,224 @@ +#include "stdafx.h" +#include "../../Whisper/ML/testUtils.h" +#include <immintrin.h> +using namespace DirectCompute; + +namespace +{ + using DirectCompute::sTensorDiff; + + __forceinline __m256 load( const float* rsi ) + { + return _mm256_loadu_ps( rsi ); + } + + __forceinline __m256 load( const uint16_t* rsi ) + { + const __m128i iv = _mm_load_si128( ( const __m128i* )rsi ); + return _mm256_cvtph_ps( iv ); + } + + __forceinline void loadPartial( const uint16_t* x, const uint16_t* y, size_t count, __m256& fx, __m256& fy ) + { + __m128i ix, iy; + switch( count ) + { + case 1: // load 2 bytes + ix = _mm_cvtsi32_si128( *x ); + iy = _mm_cvtsi32_si128( *y ); + break; + case 2: // load 4 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + iy = _mm_cvtsi32_si128( *(const int*)y ); + break; + case 3: // load 6 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + iy = _mm_cvtsi32_si128( *(const int*)y ); + ix = _mm_insert_epi16( ix, x[ 2 ], 2 ); + iy = _mm_insert_epi16( iy, y[ 2 ], 2 ); + break; + case 4: // load 8 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + break; + case 5: // load 10 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + ix = _mm_insert_epi16( ix, x[ 4 ], 4 ); + iy = _mm_insert_epi16( iy, y[ 4 ], 4 ); + break; + case 6: // load 12 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 ); + break; + case 7: // load 14 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 
2 ); + iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 ); + ix = _mm_insert_epi16( ix, x[ 6 ], 6 ); + iy = _mm_insert_epi16( iy, y[ 6 ], 6 ); + break; + default: + fx = fy = _mm256_setzero_ps(); + return; + } + + fx = _mm256_cvtph_ps( ix ); + fy = _mm256_cvtph_ps( iy ); + } + + inline __m128 loadFloat2( const float* rsi ) + { + return _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) ); + } + inline __m128 loadFloat3( const float* rsi ) + { + __m128 f = loadFloat2( rsi ); + f = _mm_insert_ps( f, _mm_load_ss( rsi + 2 ), 0x20 ); + return f; + } + __forceinline void loadPartial( const float* x, const float* y, size_t count, __m256& fx, __m256& fy ) + { + __m128 low1, high1; + __m128 low2, high2; + high1 = high2 = _mm_setzero_ps(); + switch( count ) + { + case 1: + low1 = _mm_load_ss( x ); + low2 = _mm_load_ss( y ); + break; + case 2: + low1 = loadFloat2( x ); + low2 = loadFloat2( y ); + break; + case 3: + low1 = loadFloat3( x ); + low2 = loadFloat3( y ); + break; + case 4: + low1 = _mm_loadu_ps( x ); + low2 = _mm_loadu_ps( y ); + break; + case 5: + low1 = _mm_loadu_ps( x ); + low2 = _mm_loadu_ps( y ); + high1 = _mm_load_ss( x + 4 ); + high2 = _mm_load_ss( y + 4 ); + break; + case 6: + low1 = _mm_loadu_ps( x ); + low2 = _mm_loadu_ps( y ); + high1 = loadFloat2( x + 4 ); + high2 = loadFloat2( y + 4 ); + break; + case 7: // load 14 bytes + low1 = _mm_loadu_ps( x ); + low2 = _mm_loadu_ps( y ); + high1 = loadFloat3( x + 4 ); + high2 = loadFloat3( y + 4 ); + break; + default: + fx = fy = _mm256_setzero_ps(); + return; + } + + fx = _mm256_setr_m128( low1, high1 ); + fy = _mm256_setr_m128( low2, high2 ); + } + + __forceinline float horizontalMaximum( __m256 v ) + { + __m128 s = _mm256_extractf128_ps( v, 1 ); + s = _mm_max_ps( s, _mm256_castps256_ps128( v ) ); + s = _mm_max_ps( s, _mm_movehl_ps( s, s ) ); + s = _mm_max_ss( s, _mm_movehdup_ps( s ) ); + return _mm_cvtss_f32( s ); + } + + __forceinline double horizontalSum( __m256 v ) + { + __m256d d = _mm256_cvtps_pd( 
_mm256_extractf128_ps( v, 1 ) ); + d = _mm256_add_pd( d, _mm256_cvtps_pd( _mm256_castps256_ps128( v ) ) ); + + __m128d s = _mm256_extractf128_pd( d, 1 ); + s = _mm_add_pd( s, _mm256_castpd256_pd128( d ) ); + s = _mm_add_sd( s, _mm_unpackhi_pd( s, s ) ); + return _mm_cvtsd_f64( s ); + } + + __m256 maskInfNan( __m256 diff, __m256 a, __m256 b ) + { + __m256i ai = _mm256_castps_si256( a ); + __m256i bi = _mm256_castps_si256( b ); + __m256i eqi = _mm256_cmpeq_epi32( ai, bi ); + __m256 eq = _mm256_castsi256_ps( eqi ); + return _mm256_andnot_ps( eq, diff ); + } + + class DiffAcc + { + __m256 maxAbs = _mm256_setzero_ps(); + __m256 sumSquares = _mm256_setzero_ps(); + + public: + + __forceinline void add( __m256 a, __m256 b ) + { + const __m256 neg0 = _mm256_set1_ps( -0.0f ); + __m256 diff = _mm256_sub_ps( b, a ); + diff = maskInfNan( diff, a, b ); + sumSquares = _mm256_fmadd_ps( diff, diff, sumSquares ); + const __m256 absDiff = _mm256_andnot_ps( neg0, diff ); + maxAbs = _mm256_max_ps( maxAbs, absDiff ); + } + + __forceinline sTensorDiff reduce( size_t count ) + { + sTensorDiff res; + res.maxAbsDiff = horizontalMaximum( maxAbs ); + res.avgDiffSquared = (float)( horizontalSum( sumSquares ) / (double)(int64_t)count ); + res.length = count; + return res; + } + }; + + template<class E> + static sTensorDiff __declspec( noinline ) diffVectors( const E* a, const E* b, size_t length ) + { + // const E* const aEnd = a + length; + const E* const aEndAligned = a + ( length / 8 ) * 8; + const size_t remainder = length % 8; + + DiffAcc acc; + for( ; a < aEndAligned; a += 8, b += 8 ) + acc.add( load( a ), load( b ) ); + + if( remainder != 0 ) + { + __m256 va, vb; + loadPartial( a, b, remainder, va, vb ); + acc.add( va, vb ); + } + + return acc.reduce( length ); + } +} + +sTensorDiff DirectCompute::computeDiff( const float* a, const float* b, size_t length ) +{ + return diffVectors( a, b, length ); +} + +sTensorDiff DirectCompute::computeDiff( const uint16_t* a, const uint16_t* b, size_t 
length ) +{ + return diffVectors( a, b, length ); +} + +void DirectCompute::sTensorDiff::print() const +{ + printf( "%zu elements, maxAbsDiff = %g, avgDiffSquared = %g\n", length, maxAbsDiff, avgDiffSquared ); +}
// Whisper/API/MfStructs.h
// Plain structures, enums and callback types shared by both flavors of the iMediaFoundation interfaces.
// This header includes nothing itself; uint32_t, uint8_t and HRESULT must already be declared by the including file.
#pragma once

namespace Whisper
{
	struct sCaptureDevice
	{
		// The display name is suitable for showing to the user, but might not be unique.
		const wchar_t* displayName;

		// Endpoint ID for an audio capture device
		// It uniquely identifies the device on the system, but is not a readable string.
		const wchar_t* endpoint;
	};

	// Callback type accepted by iMediaFoundation::listCaptureDevices; `pv` is the value supplied to that method.
	// NOTE(review): presumably `buffer` points to `len` entries which are only valid during the call; confirm in the implementation.
	using pfnFoundCaptureDevices = HRESULT( __stdcall* )( int len, const sCaptureDevice* buffer, void* pv );

	// Flags for the audio capture
	enum struct eCaptureFlags : uint32_t
	{
		// When the capture device supports stereo, keep stereo PCM samples in addition to mono
		Stereo = 1,
	};

	// Parameters for audio capture
	// NOTE(review): the durations are presumably expressed in seconds; confirm in the implementation.
	struct sCaptureParams
	{
		float minDuration = 2.0f;
		float maxDuration = 3.0f;
		float dropStartSilence = 0.25f;
		float pauseDuration = 0.333f;
		// Flags for the audio capture
		// Declared as a raw integer; presumably a combination of eCaptureFlags values, confirm in the implementation.
		uint32_t flags = 0;
	};

	// Status reported through the pfnCaptureStatus callback.
	// The values are distinct bits; presumably several of them can be combined in one report, confirm in the implementation.
	enum struct eCaptureStatus : uint8_t
	{
		Listening = 1,
		Voice = 2,
		Transcribing = 4,
		Stalled = 0x80,
	};

	// Callback types for sCaptureCallbacks; both receive the `pv` value of that structure
	using pfnShouldCancel = HRESULT( __stdcall* )( void* pv ) noexcept;
	using pfnCaptureStatus = HRESULT( __stdcall* )( void* pv, eCaptureStatus status ) noexcept;

	// Set of callbacks passed to iContext::runCapture
	struct sCaptureCallbacks
	{
		pfnShouldCancel shouldCancel;
		pfnCaptureStatus captureStatus;
		// Caller-defined value passed to both callbacks
		void* pv;
	};
}
\ No newline at end of file diff --git a/Whisper/API/Readme.txt b/Whisper/API/Readme.txt new file mode 100644 index 0000000..7d40494 --- /dev/null +++ b/Whisper/API/Readme.txt @@ -0,0 +1,15 @@ +The headers in this folder define the complete public API of Whisper.dll. + +To consume the library in your C++ software, include exactly one of the following headers. + +1. If you’re building a Windows app, include the whisperWindows.h header, and you'll get the traditional Win32 COM projection of the API. + +2. If you’re porting to another OS, porting to a different C++ compiler, or already using the ComLight support library, include the whisperComLight.h header. +If you do that, in addition to this "Whisper/API" folder you will also need the "ComLightLib" dependency. +This will get you the ComLight flavor of these COM interfaces. + +Internally, the actual implementation uses the ComLight flavor of the interfaces, but that’s fine because the two flavors are binary compatible. + +The reason for the difference between these flavors: Visual Studio’s CComPtr<T> and other related utilities expect interface IDs to be specified with the __declspec(uuid) directive. + +That language extension is specific to Visual C++, and is not supported by the GCC or Clang compilers.
// Whisper/API/SpecialTokens.h
#pragma once

namespace Whisper
{
	// IDs of the special tokens of a model, filled by iModel::getSpecialTokens.
	// The trailing `token_*` names are the identifiers these fields were taken from.
	// NOTE(review): several descriptions below do not look consistent with the identifiers next to them
	// (e.g. `token_not`, `token_solm`, `token_beg`); verify against the upstream vocabulary before relying on them.
	struct SpecialTokens
	{
		// The end of a transcription, token_eot
		int TranscriptionEnd;
		// Start of a transcription, token_sot
		int TranscriptionStart;
		// Represents the previous word in the transcription. It is used to help the model predict the current word based on the context of the words that came before it.
		// NOTE(review): presumably the marker which starts the previous-text context rather than a word; confirm.
		int PreviousWord; // token_prev
		// Start of a sentence
		// NOTE(review): `solm` presumably stands for "start of LM"; confirm.
		int SentenceStart; // token_solm
		// Represents the word "not" in the transcription
		// NOTE(review): `token_not` presumably means "no timestamps" rather than the word "not"; confirm.
		int Not; // token_not
		// New transcription
		// NOTE(review): `token_beg` presumably is the first timestamp token; confirm.
		int TranscriptionBegin; // token_beg

		// token_translate
		int TaskTranslate;
		// token_transcribe
		int TaskTranscribe;
	};
}
// Whisper/API/TranscribeStructs.h
// Plain structures and enums describing the results of a transcription.
#pragma once
#include <stdint.h>
#include <assert.h>

namespace Whisper
{
	enum struct eModelImplementation : uint32_t
	{
		GPU = 1,
		Hybrid = 2,
		Reference = 3,
	};

	// A time span split into calendar-like fields
	struct sTimeSpanFields
	{
		uint32_t days;
		uint8_t hours, minutes, seconds;
		// Remainder below one second, in 100-nanosecond ticks
		uint32_t ticks;

		// Split a count of 100-nanosecond ticks into the fields
		sTimeSpanFields( uint64_t tt )
		{
			constexpr uint64_t ticksPerSecond = 10'000'000;
			const uint64_t totalSeconds = tt / ticksPerSecond;
			const uint64_t totalMinutes = totalSeconds / 60;
			const uint64_t totalHours = totalMinutes / 60;

			ticks = (uint32_t)( tt % ticksPerSecond );
			seconds = (uint8_t)( totalSeconds % 60 );
			minutes = (uint8_t)( totalMinutes % 60 );
			hours = (uint8_t)( totalHours % 24 );
			days = (uint32_t)( totalHours / 24 );
		}
	};

	// A time span stored as a count of ticks
	struct sTimeSpan
	{
		uint64_t ticks;

		operator sTimeSpanFields() const
		{
			return sTimeSpanFields{ ticks };
		}
		void operator=( uint64_t tt )
		{
			ticks = tt;
		}
		// Negative values are not supported, debug builds assert on them
		void operator=( int64_t tt )
		{
			assert( tt >= 0 );
			ticks = (uint64_t)tt;
		}
	};

	// Start and end times of the segment or token, expressed in 100-nanosecond ticks
	struct sTimeInterval
	{
		sTimeSpan begin, end;
	};

	// Segment data
	struct sSegment
	{
		// Segment text, null-terminated, and probably UTF-8 encoded
		const char* text;
		// Start and end times of the segment
		sTimeInterval time;
		uint32_t firstToken, countTokens;
	};

	enum eTokenFlags : uint32_t
	{
		None = 0,
		Special = 1,
	};
	// True when the two values have at least one bit in common
	inline bool operator &( eTokenFlags a, eTokenFlags b )
	{
		const uint32_t common = (uint32_t)a & (uint32_t)b;
		return common != 0;
	}

	// Token data
	struct sToken
	{
		// Token text, null-terminated, and probably UTF-8 encoded
		const char* text;
		// Start and end times of the token
		sTimeInterval time;
		// Probability of the token
		float probability;
		// Probability of the timestamp token
		float probabilityTimestamp;
		// Sum of probabilities of all timestamp tokens
		float ptsum;
		// Voice length of the token
		float vlen;
		// Token id
		int id;
		eTokenFlags flags;
	};

	struct sTranscribeLength
	{
		uint32_t countSegments, countTokens;
	};

	enum struct eResultFlags : uint32_t
	{
		None = 0,
		// Return individual tokens in addition to the segments
		Tokens = 1,
		// Return timestamps
		Timestamps = 2,

		// Create a new COM object for the results.
		// Without this flag, the context returns a pointer to the COM object stored in the context.
		// The content of that object is replaced every time you call iContext.getResults method
		NewObject = 0x100,
	};

	// Combine two sets of flags
	inline eResultFlags operator |( eResultFlags a, eResultFlags b )
	{
		const uint32_t merged = (uint32_t)a | (uint32_t)b;
		return (eResultFlags)merged;
	}

	// True when the two values have at least one bit in common
	inline bool operator &( eResultFlags a, eResultFlags b )
	{
		const uint32_t common = (uint32_t)a & (uint32_t)b;
		return common != 0;
	}
}
// Whisper/API/iContext.cl.h
// ComLight flavor of the iContext and iModel interfaces, and of the functions exported by the library.
// iContext.h declares the same interfaces, with the same interface IDs, for the Win32 COM flavor.
// The Readme in this folder states the two flavors are binary compatible, so the order
// and the signatures of the methods need to stay in sync between the two headers.
#pragma once
#include "../../ComLightLib/comLightCommon.h"
#include "iTranscribeResult.cl.h"
#include "SpecialTokens.h"
#include "loggerApi.h"
#include "sLanguageList.h"
#include "sLoadModelCallbacks.h"

namespace Whisper
{
	// Forward declarations; sFullParams, eSamplingStrategy and sProgressSink are defined in sFullParams.h, included at the end of this header
	struct iModel;
	struct iAudioBuffer;
	struct iAudioReader;
	struct iAudioCapture;
	struct sCaptureCallbacks;
	struct sFullParams;
	enum struct eModelImplementation : uint32_t;
	enum struct eSamplingStrategy : int;
	using whisper_token = int;
	struct sProgressSink;

	struct DECLSPEC_NOVTABLE iContext : public ComLight::IUnknown
	{
		DEFINE_INTERFACE_ID( "{b9956374-3b18-4943-90f2-2ab18a404537}" );

		// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
		// Uses the specified decoding strategy to obtain the text.
		virtual HRESULT COMLIGHTCALL runFull( const sFullParams& params, const iAudioBuffer* buffer ) = 0;
		// Same, the audio comes from an iAudioReader, and there is a sink to receive progress reports
		virtual HRESULT COMLIGHTCALL runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ) = 0;
		// Same, the audio comes from an iAudioCapture, and there is a set of capture callbacks
		virtual HRESULT COMLIGHTCALL runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ) = 0;

		// Get the results; see the comments on eResultFlags about the lifetime of the returned object
		virtual HRESULT COMLIGHTCALL getResults( eResultFlags flags, iTranscribeResult** pp ) const = 0;

		virtual HRESULT COMLIGHTCALL getModel( iModel** pp ) = 0;

		// Write parameters for the specified sampling strategy to the structure
		virtual HRESULT COMLIGHTCALL fullDefaultParams( eSamplingStrategy strategy, sFullParams* rdi ) = 0;

		// Performance information
		virtual HRESULT COMLIGHTCALL timingsPrint() = 0;
		virtual HRESULT COMLIGHTCALL timingsReset() = 0;
	};

	struct DECLSPEC_NOVTABLE iModel : public ComLight::IUnknown
	{
		DEFINE_INTERFACE_ID( "{abefb4c9-e8d8-46a3-8747-5afbadef1adb}" );

		virtual HRESULT COMLIGHTCALL createContext( iContext** pp ) = 0;

		// NOTE(review): the answer is presumably encoded as S_OK / S_FALSE; confirm in the implementation
		virtual HRESULT COMLIGHTCALL isMultilingual() = 0;

		virtual HRESULT COMLIGHTCALL getSpecialTokens( SpecialTokens& rdi ) = 0;

		// Token Id -> String
		virtual const char* COMLIGHTCALL stringFromToken( whisper_token token ) = 0;
	};

	HRESULT COMLIGHTCALL setupLogger( const sLoggerSetup& setup );
	HRESULT COMLIGHTCALL loadModel( const wchar_t* path, eModelImplementation impl, const sLoadModelCallbacks* callbacks, iModel** pp );

	// NOTE(review): these presumably return the same keys as makeLanguageKey in sFullParams.h; confirm
	uint32_t COMLIGHTCALL findLanguageKeyW( const wchar_t* lang );
	uint32_t COMLIGHTCALL findLanguageKeyA( const char* lang );

	HRESULT COMLIGHTCALL getSupportedLanguages( sLanguageList& rdi );
}

// Included last because it uses iContext and whisper_token, declared above
#include "sFullParams.h"
// Whisper/API/iContext.h
// Win32 COM flavor of the iContext and iModel interfaces, and of the functions exported by the library.
// iContext.cl.h declares the same interfaces, with the same interface IDs, for the ComLight flavor.
// The Readme in this folder states the two flavors are binary compatible, so the order
// and the signatures of the methods need to stay in sync between the two headers.
#pragma once
#include "iTranscribeResult.h"
#include "SpecialTokens.h"
#include "loggerApi.h"
#include "sLanguageList.h"
#include "sLoadModelCallbacks.h"

namespace Whisper
{
	// Forward declarations; sFullParams, eSamplingStrategy and sProgressSink are defined in sFullParams.h, included at the end of this header
	__interface iModel;
	__interface iAudioBuffer;
	__interface iAudioReader;
	__interface iAudioCapture;
	struct sCaptureCallbacks;
	struct sFullParams;
	enum struct eModelImplementation : uint32_t;
	enum struct eSamplingStrategy : int;
	using whisper_token = int;
	struct sProgressSink;

	__interface __declspec( novtable, uuid( "b9956374-3b18-4943-90f2-2ab18a404537" ) ) iContext : public IUnknown
	{
		// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
		// Uses the specified decoding strategy to obtain the text.
		HRESULT __stdcall runFull( const sFullParams& params, const iAudioBuffer* buffer );
		// Same, the audio comes from an iAudioReader, and there is a sink to receive progress reports
		HRESULT __stdcall runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader );
		// Same, the audio comes from an iAudioCapture, and there is a set of capture callbacks
		HRESULT __stdcall runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader );

		// Get the results; see the comments on eResultFlags about the lifetime of the returned object
		HRESULT __stdcall getResults( eResultFlags flags, iTranscribeResult** pp ) const;

		HRESULT __stdcall getModel( iModel** pp );

		// Write parameters for the specified sampling strategy to the structure
		HRESULT __stdcall fullDefaultParams( eSamplingStrategy strategy, sFullParams* rdi );

		// Performance information
		HRESULT __stdcall timingsPrint();
		HRESULT __stdcall timingsReset();
	};

	__interface __declspec( novtable, uuid( "abefb4c9-e8d8-46a3-8747-5afbadef1adb" ) ) iModel : public IUnknown
	{
		HRESULT __stdcall createContext( iContext** pp );

		// NOTE(review): the answer is presumably encoded as S_OK / S_FALSE; confirm in the implementation
		HRESULT __stdcall isMultilingual();

		HRESULT __stdcall getSpecialTokens( SpecialTokens& rdi );

		// Token Id -> String
		const char* __stdcall stringFromToken( whisper_token token );
	};

	HRESULT __stdcall setupLogger( const sLoggerSetup& setup );
	HRESULT __stdcall loadModel( const wchar_t* path, eModelImplementation impl, const sLoadModelCallbacks* callbacks, iModel** pp );

	// NOTE(review): these presumably return the same keys as makeLanguageKey in sFullParams.h; confirm
	uint32_t __stdcall findLanguageKeyW( const wchar_t* lang );
	uint32_t __stdcall findLanguageKeyA( const char* lang );

	HRESULT __stdcall getSupportedLanguages( sLanguageList& rdi );
}

// Included last because it uses iContext and whisper_token, declared above
#include "sFullParams.h"
// Whisper/API/iMediaFoundation.cl.h
// ComLight flavor of the audio-related interfaces.
// iMediaFoundation.h declares the same interfaces, with the same interface IDs, for the Win32 COM flavor;
// the order and the signatures of the methods need to stay in sync between the two headers.
#pragma once
#include "../../ComLightLib/comLightCommon.h"
#include "MfStructs.h"

// Only pointers to this type are used in this header, a forward declaration is enough
struct IMFSourceReader;

namespace Whisper
{
	// Audio samples, passed to iContext::runFull
	struct DECLSPEC_NOVTABLE iAudioBuffer : public ComLight::IUnknown
	{
		DEFINE_INTERFACE_ID( "{013583aa-c9eb-42bc-83db-633c2c317051}" );

		virtual uint32_t COMLIGHTCALL countSamples() const = 0;
		virtual const float* COMLIGHTCALL getPcmMono() const = 0;
		// NOTE(review): presumably returns nullptr when the buffer has no stereo samples; confirm in the implementation
		virtual const float* COMLIGHTCALL getPcmStereo() const = 0;
		virtual HRESULT COMLIGHTCALL getTime( int64_t& rdi ) const = 0;
	};

	// Source of audio passed to iContext::runStreamed
	struct DECLSPEC_NOVTABLE iAudioReader : public ComLight::IUnknown
	{
		DEFINE_INTERFACE_ID( "{35b988da-04a6-476a-a193-d8891d5dc390}" );

		virtual HRESULT COMLIGHTCALL getDuration( int64_t& rdi ) const = 0;
		virtual HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const = 0;
		// NOTE(review): the answer is presumably encoded as S_OK / S_FALSE; confirm in the implementation
		virtual HRESULT COMLIGHTCALL requestedStereo() const = 0;
	};

	// Source of audio passed to iContext::runCapture
	struct DECLSPEC_NOVTABLE iAudioCapture : public ComLight::IUnknown
	{
		DEFINE_INTERFACE_ID( "{747752c2-d9fd-40df-8847-583c781bf013}" );

		virtual HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const = 0;
		virtual const sCaptureParams& COMLIGHTCALL getParams() const = 0;
	};

	// Factory of the objects above, created with initMediaFoundation
	struct DECLSPEC_NOVTABLE iMediaFoundation : public ComLight::IUnknown
	{
		DEFINE_INTERFACE_ID( "{fb9763a5-d77d-4b6e-aff8-f494813cebd8}" );

		virtual HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const = 0;
		virtual HRESULT COMLIGHTCALL openAudioFile( LPCTSTR path, bool stereo, iAudioReader** pp ) = 0;

		// `pv` is passed to the callback as the last argument; see pfnFoundCaptureDevices in MfStructs.h
		virtual HRESULT COMLIGHTCALL listCaptureDevices( pfnFoundCaptureDevices pfn, void* pv ) = 0;
		virtual HRESULT COMLIGHTCALL openCaptureDevice( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) = 0;
	};

	HRESULT COMLIGHTCALL initMediaFoundation( iMediaFoundation** pp );
}
// Whisper/API/iMediaFoundation.h
// Win32 COM flavor of the audio-related interfaces.
// iMediaFoundation.cl.h declares the same interfaces, with the same interface IDs, for the ComLight flavor;
// the order and the signatures of the methods need to stay in sync between the two headers.
#pragma once
#include <stdint.h>
#include "MfStructs.h"
// Only pointers to this type are used in this header, a forward declaration is enough
struct IMFSourceReader;

namespace Whisper
{
	// Audio samples, passed to iContext::runFull
	__interface __declspec( novtable, uuid( "013583aa-c9eb-42bc-83db-633c2c317051" ) ) iAudioBuffer : public IUnknown
	{
		uint32_t __stdcall countSamples() const;
		const float* __stdcall getPcmMono() const;
		// NOTE(review): presumably returns nullptr when the buffer has no stereo samples; confirm in the implementation
		const float* __stdcall getPcmStereo() const;
		HRESULT __stdcall getTime( int64_t& rdi ) const;
	};

	// Source of audio passed to iContext::runStreamed
	__interface __declspec( novtable, uuid( "35b988da-04a6-476a-a193-d8891d5dc390" ) ) iAudioReader : public IUnknown
	{
		HRESULT __stdcall getDuration( int64_t& rdi ) const;
		HRESULT __stdcall getReader( IMFSourceReader** pp ) const;
		// NOTE(review): the answer is presumably encoded as S_OK / S_FALSE; confirm in the implementation
		HRESULT __stdcall requestedStereo() const;
	};

	// Source of audio passed to iContext::runCapture
	__interface __declspec( novtable, uuid( "747752c2-d9fd-40df-8847-583c781bf013" ) ) iAudioCapture : public IUnknown
	{
		HRESULT __stdcall getReader( IMFSourceReader** pp ) const;
		const sCaptureParams& __stdcall getParams() const;
	};

	// Factory of the objects above, created with initMediaFoundation
	__interface __declspec( novtable, uuid( "fb9763a5-d77d-4b6e-aff8-f494813cebd8" ) ) iMediaFoundation : public IUnknown
	{
		HRESULT __stdcall loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const;
		HRESULT __stdcall openAudioFile( LPCTSTR path, bool stereo, iAudioReader** pp );

		// `pv` is passed to the callback as the last argument; see pfnFoundCaptureDevices in MfStructs.h
		HRESULT __stdcall listCaptureDevices( pfnFoundCaptureDevices pfn, void* pv );
		HRESULT __stdcall openCaptureDevice( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp );
	};

	HRESULT __stdcall initMediaFoundation( iMediaFoundation** pp );
}
\ No newline at end of file diff --git a/Whisper/API/iTranscribeResult.cl.h b/Whisper/API/iTranscribeResult.cl.h new file mode 100644 index 0000000..ab65178 --- /dev/null +++ b/Whisper/API/iTranscribeResult.cl.h @@ -0,0 +1,15 @@ +#pragma once +#include "TranscribeStructs.h" +#include "../../ComLightLib/comLightCommon.h" + +namespace Whisper +{ + struct iTranscribeResult : public ComLight::IUnknown + { + DEFINE_INTERFACE_ID( "{2871a73f-5ce3-48f8-8779-6582ee11935e}" ); + + virtual HRESULT COMLIGHTCALL getSize( sTranscribeLength& rdi ) const = 0; + virtual const sSegment* COMLIGHTCALL getSegments() const = 0; + virtual const sToken* COMLIGHTCALL getTokens() const = 0; + }; +}
// Whisper/API/iTranscribeResult.h
// Win32 COM flavor of the iTranscribeResult interface.
// iTranscribeResult.cl.h declares the same interface, with the same interface ID, for the ComLight flavor.
#pragma once
#include "TranscribeStructs.h"

namespace Whisper
{
	// Results of a transcription, returned by iContext::getResults
	__interface __declspec( novtable, uuid( "2871a73f-5ce3-48f8-8779-6582ee11935e" ) ) iTranscribeResult : public IUnknown
	{
		// Get the counts of segments and tokens
		HRESULT __stdcall getSize( sTranscribeLength& rdi ) const;
		// NOTE(review): the two arrays presumably have the lengths reported by getSize, and belong to this object; confirm in the implementation
		const sSegment* __stdcall getSegments() const;
		const sToken* __stdcall getTokens() const;
	};
}
// Whisper/API/loggerApi.h
// Types used to configure the logging of the library, see the setupLogger function.
#pragma once
#include <stdint.h>

namespace Whisper
{
	// Log level for messages
	enum struct eLogLevel : uint8_t
	{
		Error = 0,
		Warning = 1,
		Info = 2,
		Debug = 3
	};
	// Flags about the logger; the values are distinct bits
	enum struct eLoggerFlags : uint8_t
	{
		UseStandardError = 1,
		SkipFormatMessage = 2,
	};

	// C function pointer to receive log messages from the library. The messages are encoded in UTF-8.
	using pfnLoggerSink = void( __stdcall* )( void* context, eLogLevel lvl, const char* message );

	// A sink to receive log messages produced by Whisper.dll
	struct sLoggerSetup
	{
		// C function pointer to receive log messages from the library
		pfnLoggerSink sink = nullptr;
		// Optional context parameter for the sink function; when consuming from C# you don't need that, pass IntPtr.Zero, delegates can capture things.
		void* context = nullptr;
		// Maximum log level to produce
		// Defaults to the zero value, same as every other field, so a default-constructed structure never holds an indeterminate level.
		eLogLevel level = eLogLevel::Error;
		// Flags about the logger
		eLoggerFlags flags = (eLoggerFlags)0;
	};
}
\ No newline at end of file diff --git a/Whisper/API/sFullParams.h b/Whisper/API/sFullParams.h new file mode 100644 index 0000000..0a1d352 --- /dev/null +++ b/Whisper/API/sFullParams.h @@ -0,0 +1,136 @@ +#pragma once +#include <stdint.h> +#include <assert.h> + +namespace Whisper +{ + // Available sampling strategies + enum struct eSamplingStrategy : int + { + // Always select the most probable token + Greedy, + // TODO: not implemented yet! + BeamSearch, + }; + + using pfnNewSegment = HRESULT( __cdecl* )( iContext* ctx, uint32_t n_new, void* user_data ) noexcept; + using pfnEncoderBegin = HRESULT( __cdecl* )( iContext* ctx, void* user_data ) noexcept; + + enum struct eFullParamsFlags : uint32_t + { + Translate = 1, + NoContext = 2, + SingleSegment = 4, + PrintSpecial = 8, + PrintProgress = 0x10, + PrintRealtime = 0x20, + PrintTimestamps = 0x40, + + // Experimental + TokenTimestamps = 0x100, + SpeedupAudio = 0x200, + }; + + inline eFullParamsFlags operator | ( eFullParamsFlags a, eFullParamsFlags b ) + { + return (eFullParamsFlags)( (uint32_t)a | (uint32_t)b ); + } + inline void operator |= ( eFullParamsFlags& a, eFullParamsFlags b ) + { + a = a | b; + } + + struct sFullParams + { + eSamplingStrategy strategy; + // Count of CPU threads + int cpuThreads; + int n_max_text_ctx; + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + eFullParamsFlags flags; + uint32_t language; + + // [EXPERIMENTAL] token-level timestamps + float thold_pt; // timestamp token probability threshold (~0.01) + float thold_ptsum; // timestamp token sum probability threshold (~0.01) + int max_len; // max segment length in characters + int max_tokens; // max tokens per segment (0 = no limit) + + struct + { + int n_past; + } greedy; + + struct + { + int n_past; + int beam_width; + int n_best; + } beam_search; + + // [EXPERIMENTAL] speed-up techniques + int audio_ctx; // overwrite the audio context size (0 = use default) + + // tokens to provide the whisper 
model as initial prompt + // these are prepended to any existing text context from a previous call + const whisper_token* prompt_tokens; + int prompt_n_tokens; + + pfnNewSegment new_segment_callback; + void* new_segment_callback_user_data; + + pfnEncoderBegin encoder_begin_callback; + void* encoder_begin_callback_user_data; + + // Couple utility methods, they workaround the lack of bit fields in C++ + inline bool flag( eFullParamsFlags f ) const + { + return 0 != ( (uint32_t)flags & (uint32_t)f ); + } + inline void resetFlag( eFullParamsFlags bit ) + { + uint32_t f = (uint32_t)flags; + f &= ~(uint32_t)bit; + flags = (eFullParamsFlags)f; + } + inline void setFlag( eFullParamsFlags bit, bool set = true ) + { + uint32_t f = (uint32_t)flags; + if( set ) + f |= (uint32_t)bit; + else + f &= ~(uint32_t)bit; + flags = (eFullParamsFlags)f; + } + }; + + struct sSegmentTime + { + int64_t begin, end; + }; + + inline uint32_t makeLanguageKey( const char* code ) + { + assert( strlen( code ) <= 4 ); + uint32_t res = 0; + uint32_t shift = 0; + for( size_t i = 0; i < 4; i++, code++, shift += 8 ) + { + const char c = *code; + if( c == '\0' ) + return res; + uint32_t u32 = (uint8_t)c; + u32 = u32 << shift; + res |= u32; + } + return res; + } + + using pfnReportProgress = HRESULT( __stdcall* )( double val, iContext* ctx, void* pv ) noexcept; + struct sProgressSink + { + pfnReportProgress pfn; + void* pv; + }; +}
// Whisper/API/sLanguageList.h
#pragma once
#include <stdint.h>

namespace Whisper
{
	// One record of the language table
	struct sLanguageEntry
	{
		// 32-bit key of the language
		// NOTE(review): presumably the value produced by makeLanguageKey() - confirm against the code which fills the table
		uint32_t key;
		// Integer ID of the language
		int id;
		// Name of the language, a C string
		const char* name;
	};

	// Count + pointer pair describing an array of language records
	struct sLanguageList
	{
		// Count of records available at `pointer`
		uint32_t length;
		// The first record of the array
		const sLanguageEntry* pointer;
	};
}
\ No newline at end of file diff --git a/Whisper/API/sLoadModelCallbacks.h b/Whisper/API/sLoadModelCallbacks.h new file mode 100644 index 0000000..f5248c6 --- /dev/null +++ b/Whisper/API/sLoadModelCallbacks.h @@ -0,0 +1,14 @@ +#pragma once + +namespace Whisper +{ + using pfnLoadProgress = HRESULT( __stdcall* )( double val, void* pv ) noexcept; + using pfnCancel = HRESULT( __stdcall* )( void* pv ) noexcept; + + struct sLoadModelCallbacks + { + pfnLoadProgress progress; + pfnCancel cancel; + void* pv; + }; +}
\ No newline at end of file diff --git a/Whisper/API/whisperComLight.h b/Whisper/API/whisperComLight.h new file mode 100644 index 0000000..c7f0b93 --- /dev/null +++ b/Whisper/API/whisperComLight.h @@ -0,0 +1,4 @@ +#pragma once +#include "iMediaFoundation.cl.h" +#include "iContext.cl.h" +#include "iTranscribeResult.cl.h"
\ No newline at end of file diff --git a/Whisper/API/whisperWindows.h b/Whisper/API/whisperWindows.h new file mode 100644 index 0000000..925e307 --- /dev/null +++ b/Whisper/API/whisperWindows.h @@ -0,0 +1,4 @@ +#pragma once +#include "iMediaFoundation.h" +#include "iContext.h" +#include "iTranscribeResult.h"
\ No newline at end of file diff --git a/Whisper/CPU/BufferAllocator.cpp b/Whisper/CPU/BufferAllocator.cpp new file mode 100644 index 0000000..156382c --- /dev/null +++ b/Whisper/CPU/BufferAllocator.cpp @@ -0,0 +1,145 @@ +#include <stdafx.h> +#include "BufferAllocator.h" +#include <immintrin.h> +#include <ammintrin.h> +using namespace CpuCompute; + +HRESULT BufferAllocator::create( size_t cb ) +{ + CHECK( buffer.allocate( cb ) ); + head = 0; + size = cb; + dbgMarkUninitializedMemory( buffer.pointer(), cb ); + return S_OK; +} + +namespace +{ + // Round up the integer by 32 bytes + __forceinline size_t roundUpAlloc( size_t cb ) + { + const size_t mask = 31; + cb += mask; + // We require AVX1+FMA3 support, might as well use BMI1 + return _andn_u64( mask, cb ); + } +} + +void* BufferAllocator::allocate( size_t cb, size_t align ) noexcept +{ + assert( align <= 32 ); + cb = roundUpAlloc( cb ); + + uint8_t* pointer = buffer.pointer(); + if( head + cb > size || nullptr == pointer ) + { + logError( u8"BufferAllocator.allocate, not enough capacity" ); + return nullptr; + } + + void* const res = pointer + head; + head += cb; + assert( head <= size ); + dbgMarkUninitializedMemory( res, cb ); + return res; +} + +namespace +{ + // 2 MB of memory, we hope the OS kernel will then be smart enough to give us large pages. 
+ constexpr size_t virtualAllocGranularityExp2 = 21; + + constexpr size_t virtualAllocGranularityMask = ( ( (size_t)1 ) << virtualAllocGranularityExp2 ) - 1; + + // Round up the integer by 2 megabytes + __forceinline size_t roundUpVirtualAlloc( size_t cb ) + { + const size_t mask = virtualAllocGranularityMask; + cb += mask; + return _andn_u64( mask, cb ); + } +} + +HRESULT VirtualAllocator::create( size_t cb ) +{ + if( nullptr != pointer ) + return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); + cb = roundUpVirtualAlloc( cb ); + pointer = (uint8_t*)VirtualAlloc( NULL, cb, MEM_RESERVE, PAGE_READWRITE ); + if( nullptr != pointer ) + { + head = 0; + sizeAllocated = 0; + sizeVirtual = cb; + return S_OK; + } + + const HRESULT hr = getLastHr(); + logErrorHr( hr, u8"VirtualAlloc failed" ); + return hr; +} + +void* VirtualAllocator::allocate( size_t cb, size_t align ) noexcept +{ + assert( align <= 32 ); + cb = roundUpAlloc( cb ); + + const size_t newHead = head + cb; + if( newHead <= sizeAllocated ) + { + void* const res = pointer + head; + head = newHead; + dbgMarkUninitializedMemory( res, cb ); + return res; + } + + if( newHead <= sizeVirtual ) + { + uint8_t* const ptrCommit = pointer + sizeAllocated; + const size_t cbCommit = roundUpVirtualAlloc( newHead ) - sizeAllocated; + void* const res = VirtualAlloc( ptrCommit, cbCommit, MEM_COMMIT, PAGE_READWRITE ); + if( nullptr != res ) + { + sizeAllocated += cbCommit; + assert( sizeAllocated <= sizeVirtual ); + void* const res = pointer + head; + head = newHead; + dbgMarkUninitializedMemory( res, cb ); + return res; + } + + const HRESULT hr = getLastHr(); + logErrorHr( hr, u8"VirtualAllocator.allocate, VirtualAlloc failed" ); + return nullptr; + } + + logError( u8"VirtualAllocator.allocate, not enough arena capacity" ); + return nullptr; +} + +VirtualAllocator::~VirtualAllocator() +{ + if( nullptr == pointer ) + return; + + if( VirtualFree( pointer, 0, MEM_RELEASE ) ) + { + pointer = nullptr; + return; + } + + const 
HRESULT hr = getLastHr(); + logErrorHr( hr, u8"VirtualFree failed" ); +} + +#ifndef NDEBUG +// Reusing Microsoft's magic numbers: https://asawicki.info/news_1292_magic_numbers_in_visual_c +void CpuCompute::dbgMarkUninitializedMemory( void* pv, size_t cb ) +{ + __stosb( (uint8_t*)pv, 0xCD, cb ); +} +void CpuCompute::dbgMarkFreedMemory( void* pv, size_t cb ) +{ + __stosd( (DWORD*)pv, 0xFEEEFEEEu, cb / 4 ); +} +#endif
\ No newline at end of file diff --git a/Whisper/CPU/BufferAllocator.h b/Whisper/CPU/BufferAllocator.h new file mode 100644 index 0000000..750a565 --- /dev/null +++ b/Whisper/CPU/BufferAllocator.h @@ -0,0 +1,64 @@ +#pragma once +#include "LargeBuffer.h" +#include "Tensor.h" + +namespace CpuCompute +{ +#ifdef NDEBUG + inline void dbgMarkUninitializedMemory( void* pv, size_t cb ) { } + inline void dbgMarkFreedMemory( void* pv, size_t cb ) { } +#else + void dbgMarkUninitializedMemory( void* pv, size_t cb ); + void dbgMarkFreedMemory( void* pv, size_t cb ); +#endif + + // An implementation of arena allocator which slices pieces of a large buffer allocated in advance + class BufferAllocator : public iArenaAllocator + { + LargeBuffer buffer; + size_t head = 0; + size_t size = 0; + + void resetArena() noexcept override final + { + head = 0; + dbgMarkFreedMemory( buffer.pointer(), size ); + } + + void* allocate( size_t cb, size_t align ) noexcept override final; + + public: + BufferAllocator() = default; + BufferAllocator( const BufferAllocator& ) = delete; + ~BufferAllocator() = default; + + // Allocate a large buffer with the specified count of bytes + HRESULT create( size_t cb ); + }; + + // An implementation of arena allocator which allocates a large chunk of virtual memory, and maps new physical pages into that memory region as needed. + class VirtualAllocator : public iArenaAllocator + { + uint8_t* pointer = nullptr; + size_t head = 0; + size_t sizeAllocated = 0; + size_t sizeVirtual = 0; + + void resetArena() noexcept override final + { + head = 0; + dbgMarkFreedMemory( pointer, sizeAllocated ); + } + + void* allocate( size_t cb, size_t align ) noexcept override final; + + public: + + VirtualAllocator() = default; + VirtualAllocator( const VirtualAllocator& ) = delete; + ~VirtualAllocator(); + + // Reserve virtual memory space for the specified count of bytes in the arena, but don't allocate any pages + HRESULT create( size_t cb ); + }; +}
\ No newline at end of file diff --git a/Whisper/CPU/DecoderTensors.cpp b/Whisper/CPU/DecoderTensors.cpp new file mode 100644 index 0000000..22de476 --- /dev/null +++ b/Whisper/CPU/DecoderTensors.cpp @@ -0,0 +1,68 @@ +#include "stdafx.h" +#include "DecoderTensors.h" +using namespace CpuCompute; + +#if TENSOR_GGML_COMPAT +namespace +{ + class CompatContext + { + std::vector<ggml_tensor>& vec; + size_t index; + + public: + CompatContext( std::vector<ggml_tensor>& dest, size_t layers ) : + vec( dest ) + { + constexpr size_t tensorsPerLayer = 21; + const size_t count = tensorsPerLayer * layers + 4; + vec.resize( count ); + index = 0; + } + + void add( const Tensor& rsi, ggml_tensor*& res ) + { + ggml_tensor& ten = vec[ index ]; + index++; + ten = rsi.ggml(); + res = &ten; + } + + void add2( const TensorPair& rsi, ggml_tensor*& w, ggml_tensor*& b ) + { + add( rsi.w, w ); + add( rsi.b, b ); + } + + bool isComplete() const + { + return index == vec.size(); + } + }; +} + +void DecoderTensors::makeCompatTensors() +{ + CompatContext ctx( ggml, layers.size() ); + + ctx.add( positionalEmbedding, d_pe ); + ctx.add( tokenEmbedding, d_te ); + ctx.add2( ln, d_ln_w, d_ln_b ); + + for( auto& i : layers ) + { + ctx.add2( i.attnLn0, i.attn_ln_0_w, i.attn_ln_0_b ); + ctx.add2( i.attnLn1, i.attn_ln_1_w, i.attn_ln_1_b ); + ctx.add2( i.attnQuery, i.attn_q_w, i.attn_q_b ); + ctx.add( i.attnKey, i.attn_k_w ); + ctx.add2( i.attnValue, i.attn_v_w, i.attn_v_b ); + ctx.add2( i.crossAttnLn0, i.cross_attn_ln_0_w, i.cross_attn_ln_0_b ); + ctx.add2( i.crossAttnLn1, i.cross_attn_ln_1_w, i.cross_attn_ln_1_b ); + ctx.add2( i.crossAttnQuery, i.cross_attn_q_w, i.cross_attn_q_b ); + ctx.add2( i.mlpLn, i.mlp_ln_w, i.mlp_ln_b ); + ctx.add2( i.mlp0, i.mlp_0_w, i.mlp_0_b ); + ctx.add2( i.mlp1, i.mlp_1_w, i.mlp_1_b ); + } + assert( ctx.isComplete() ); +} +#endif
\ No newline at end of file diff --git a/Whisper/CPU/DecoderTensors.h b/Whisper/CPU/DecoderTensors.h new file mode 100644 index 0000000..2efa519 --- /dev/null +++ b/Whisper/CPU/DecoderTensors.h @@ -0,0 +1,131 @@ +#pragma once +#include <vector> +#include "Tensor.h" +#include "LargeBuffer.h" +#if TENSOR_GGML_COMPAT +#include "../source/ggml.h" +#endif + +namespace CpuCompute +{ + // A set of tensors for one decoder's layer + struct LayerDecoder + { + // decoder.blocks.*.attn_ln + TensorPair attnLn0; + // decoder.blocks.*.attn.out + TensorPair attnLn1; + // decoder.blocks.*.attn.query + TensorPair attnQuery; + // decoder.blocks.*.attn.key + Tensor attnKey; + // decoder.blocks.*.attn.value + TensorPair attnValue; + // decoder.blocks.*.cross_attn_ln + TensorPair crossAttnLn0; + // decoder.blocks.*.cross_attn.out + TensorPair crossAttnLn1; + // decoder.blocks.*.cross_attn.query + TensorPair crossAttnQuery; + + // decoder.blocks.*.cross_attn.key + // Tensor crossAttnKey; + // decoder.blocks.*.cross_attn.value + // TensorPair crossAttnValue; + + // decoder.blocks.*.mlp_ln + TensorPair mlpLn; + // decoder.blocks.*.mlp.0 + TensorPair mlp0; + // decoder.blocks.*.mlp.2 + TensorPair mlp1; + +#if TENSOR_GGML_COMPAT + // decoder.blocks.*.attn_ln + ggml_tensor* attn_ln_0_w; + ggml_tensor* attn_ln_0_b; + + // decoder.blocks.*.attn.out + ggml_tensor* attn_ln_1_w; + ggml_tensor* attn_ln_1_b; + + // decoder.blocks.*.attn.query + ggml_tensor* attn_q_w; + ggml_tensor* attn_q_b; + + // decoder.blocks.*.attn.key + ggml_tensor* attn_k_w; + + // decoder.blocks.*.attn.value + ggml_tensor* attn_v_w; + ggml_tensor* attn_v_b; + + // decoder.blocks.*.cross_attn_ln + ggml_tensor* cross_attn_ln_0_w; + ggml_tensor* cross_attn_ln_0_b; + + // decoder.blocks.*.cross_attn.out + ggml_tensor* cross_attn_ln_1_w; + ggml_tensor* cross_attn_ln_1_b; + + // decoder.blocks.*.cross_attn.query + ggml_tensor* cross_attn_q_w; + ggml_tensor* cross_attn_q_b; + + // decoder.blocks.*.mlp_ln + ggml_tensor* mlp_ln_w; + 
ggml_tensor* mlp_ln_b; + + // decoder.blocks.*.mlp.0 + ggml_tensor* mlp_0_w; + ggml_tensor* mlp_0_b; + + // decoder.blocks.*.mlp.2 + ggml_tensor* mlp_1_w; + ggml_tensor* mlp_1_b; +#endif + }; + + struct DecoderTensors + { + // decoder.positional_embedding + Tensor positionalEmbedding; + + // decoder.token_embedding + Tensor tokenEmbedding; + + // decoder.ln + TensorPair ln; + // A vector of layers + std::vector<LayerDecoder> layers; + + void setMemoryBuffer( LargeBuffer&& mem ) noexcept + { + memory = std::move( mem ); +#if TENSOR_GGML_COMPAT + makeCompatTensors(); +#endif + } + +#if TENSOR_GGML_COMPAT + void makeCompatTensors(); + + // decoder.positional_embedding + ggml_tensor* d_pe; // DD + + // decoder.token_embedding + ggml_tensor* d_te; // DD + + // decoder.ln + ggml_tensor* d_ln_w; // DD + ggml_tensor* d_ln_b; // DD +#endif + + private: + // A smart pointer which owns the memory for all the above tensors + LargeBuffer memory; +#if TENSOR_GGML_COMPAT + std::vector<ggml_tensor> ggml; +#endif + }; +}
\ No newline at end of file diff --git a/Whisper/CPU/HybridLoader.cpp b/Whisper/CPU/HybridLoader.cpp new file mode 100644 index 0000000..e96e3ac --- /dev/null +++ b/Whisper/CPU/HybridLoader.cpp @@ -0,0 +1,140 @@ +#include "stdafx.h" +#include "HybridLoader.h" +using namespace CpuCompute; +using namespace ComLight; + +static void populateDecodeTensorsMap( CAtlMap<CStringA, Tensor*>& map, int layersDec, DecoderTensors& dec ) +{ + dec.layers.resize( layersDec ); + + map[ "decoder.positional_embedding" ] = &dec.positionalEmbedding; + map[ "decoder.token_embedding.weight" ] = &dec.tokenEmbedding; + map[ "decoder.ln.weight" ] = &dec.ln.w; + map[ "decoder.ln.bias" ] = &dec.ln.b; + + CStringA tempString; + auto add = [ & ]( const char* name, int i, Tensor& t ) + { + tempString.Format( "decoder.blocks.%i.%s", i, name ); + map[ tempString ] = &t; + }; + + auto add2 = [ & ]( const char* name, int i, TensorPair& tensors ) + { + tempString.Format( "decoder.blocks.%i.%s.weight", i, name ); + map[ tempString ] = &tensors.w; + tempString.Format( "decoder.blocks.%i.%s.bias", i, name ); + map[ tempString ] = &tensors.b; + }; + + for( int i = 0; i < layersDec; i++ ) + { + auto& gpu = dec.layers[ i ]; + add2( "mlp_ln", i, gpu.mlpLn ); + add2( "mlp.0", i, gpu.mlp0 ); + add2( "mlp.2", i, gpu.mlp1 ); + add2( "attn_ln", i, gpu.attnLn0 ); + add2( "attn.query", i, gpu.attnQuery ); + add( "attn.key.weight", i, gpu.attnKey ); + + add2( "attn.value", i, gpu.attnValue ); + add2( "attn.out", i, gpu.attnLn1 ); + + add2( "cross_attn_ln", i, gpu.crossAttnLn0 ); + add2( "cross_attn.query", i, gpu.crossAttnQuery ); + + // These 3 tensors are used by the encode() method, to compute cross-attention buffers + // Need them in VRAM even for the hybrid model + // add( "cross_attn.key.weight", i, gpu.cross_attn_k_w ); + // add2( "cross_attn.value", i, gpu.cross_attn_v_w, gpu.cross_attn_v_b ); + add2( "cross_attn.out", i, gpu.crossAttnLn1 ); + } +} + +HybridLoader::HybridLoader( DecoderTensors& m, int 
countLayers ) : + destination( m ) +{ + populateDecodeTensorsMap( map, countLayers, destination ); + pending.reserve( map.GetCount() ); +} + +HRESULT HybridLoader::setupTensor( const CStringA& name, int n_dims, int ftype, const std::array<int, 4>& ne, ComLight::iReadStream* stream, int64_t& postponedBytes ) +{ + auto p = map.Lookup( name ); + if( nullptr == p ) + return S_FALSE; + + Tensor& rdi = *p->m_value; + PendingTensor& pt = pending.emplace_back(); + + __m128i vec = load16( ne.data() ); + vec = _mm_insert_epi32( vec, 1, 3 ); + store16( &rdi.ne, vec ); + rdi.setDenseStrides(); + + pt.destPointer = p->m_value; + CHECK( stream->getPosition( pt.streamOffset ) ); + pt.bufferOffset = bufferBytes; + + size_t cbElement; + if( ftype == 0 ) + { + rdi.setType( eDataType::FP32 ); + cbElement = 4; + } + else + { + rdi.setType( eDataType::FP16 ); + cbElement = 2; + } + + const size_t totalElts = (size_t)(uint32_t)ne[ 0 ] * (uint32_t)ne[ 1 ] * (uint32_t)ne[ 2 ]; + if( totalElts * cbElement > UINT_MAX ) + return DISP_E_OVERFLOW; + + size_t payloadBytes = cbElement * totalElts; + pt.payloadBytes = payloadBytes; + CHECK( stream->seek( payloadBytes, eSeekOrigin::Current ) ); + postponedBytes += (int64_t)payloadBytes; + + payloadBytes = ( payloadBytes + 31 ) & ( ~( (size_t)31 ) ); + bufferBytes += payloadBytes; + return S_OK; +} + +HRESULT HybridLoader::completeLoad( ComLight::iReadStream* stream, iLoaderProgressSink& progressSink ) +{ + if( pending.size() != map.GetCount() ) + { + logError( u8"Not all tensors loaded from model file - expected %zu, got %zu", map.GetCount(), pending.size() ); + return E_INVALIDARG; + } + + LargeBuffer buffer; + CHECK( buffer.allocate( bufferBytes ) ); + + uint8_t* rdi = buffer.pointer(); + + for( const auto& pt : pending ) + { + if( pt.payloadBytes > INT_MAX ) + return DISP_E_OVERFLOW; + CHECK( stream->seek( pt.streamOffset, eSeekOrigin::Begin ) ); + + int written = 0; + CHECK( stream->read( rdi, (int)pt.payloadBytes, written ) ); + CHECK( 
progressSink.gotBytes( (int64_t)pt.payloadBytes ) ); + + pt.destPointer->setDataPointer( rdi ); + + const size_t cb = ( pt.payloadBytes + 31 ) & ( ~( (size_t)31 ) ); + rdi += cb; + } + + CHECK( buffer.setReadOnly( bufferBytes ) ); + destination.setMemoryBuffer( std::move( buffer ) ); + + constexpr double mulMb = 1.0 / ( 1 << 20 ); + logDebug( u8"Loaded %zu decoder tensors, %g MB RAM", pending.size(), mulMb * (double)(int64_t)bufferBytes ); + return S_OK; +}
\ No newline at end of file diff --git a/Whisper/CPU/HybridLoader.h b/Whisper/CPU/HybridLoader.h new file mode 100644 index 0000000..8b12804 --- /dev/null +++ b/Whisper/CPU/HybridLoader.h @@ -0,0 +1,37 @@ +#pragma once +#include "DecoderTensors.h" +#include <atlstr.h> +#include <atlcoll.h> +#include "../../ComLightLib/streams.h" + +namespace CpuCompute +{ + __interface iLoaderProgressSink + { + HRESULT gotBytes( int64_t cb ); + }; + + class HybridLoader + { + DecoderTensors& destination; + CAtlMap<CStringA, Tensor*> map; + size_t bufferBytes = 0; + + struct alignas( 32 ) PendingTensor + { + Tensor* destPointer = nullptr; + int64_t streamOffset = 0; + size_t bufferOffset = 0; + size_t payloadBytes = 0; + }; + std::vector<PendingTensor> pending; + + public: + + HybridLoader( DecoderTensors& m, int countLayers ); + + HRESULT setupTensor( const CStringA& name, int n_dims, int ftype, const std::array<int, 4>& ne, ComLight::iReadStream* stream, int64_t& postponedBytes ); + + HRESULT completeLoad( ComLight::iReadStream* stream, iLoaderProgressSink& progressSink ); + }; +}
\ No newline at end of file diff --git a/Whisper/CPU/KvTensors.h b/Whisper/CPU/KvTensors.h new file mode 100644 index 0000000..a7897d3 --- /dev/null +++ b/Whisper/CPU/KvTensors.h @@ -0,0 +1,36 @@ +#pragma once +#include "Tensor.h" +#include "LargeBuffer.h" +#include "../Whisper/sModelParams.h" + +namespace CpuCompute +{ + class KvTensors + { + uint16_t* keys = nullptr; + uint16_t* values = nullptr; + uint32_t size = 0; + + CpuCompute::LargeBuffer memory; + + public: + // Create these two large tensors, FP16 precision + HRESULT create( const Whisper::sModelParams& mp ); + + // A slice of model.memory_cross_k tensor + Tensor keysView( uint32_t len, uint32_t off ) const + { + if( len + off <= size ) + return Tensor::fromData( keys + off, eDataType::FP16, len ); + throw E_BOUNDS; + } + + // A slice of model.memory_cross_v tensor + Tensor valuesView( uint32_t len, uint32_t off ) const + { + if( len + off <= size ) + return Tensor::fromData( values + off, eDataType::FP16, len ); + throw E_BOUNDS; + } + }; +}
\ No newline at end of file diff --git a/Whisper/CPU/KvTensorsCpu.cpp b/Whisper/CPU/KvTensorsCpu.cpp new file mode 100644 index 0000000..88a70ff --- /dev/null +++ b/Whisper/CPU/KvTensorsCpu.cpp @@ -0,0 +1,19 @@ +#include "stdafx.h" +#include "KvTensors.h" +using namespace CpuCompute; + +// Create these two large tensors, FP16 precision +HRESULT KvTensors::create( const Whisper::sModelParams& mp ) +{ + const uint32_t n_mem = mp.n_text_layer * mp.n_text_ctx; + const uint32_t n_elements = mp.n_text_state * n_mem; + + const size_t cb = sizeof( uint16_t ) * (size_t)n_elements * 2; + CHECK( memory.allocate( cb ) ); + + uint16_t* pointer = (uint16_t*)memory.pointer(); + keys = pointer; + values = pointer + n_elements; + size = n_elements; + return S_OK; +}
\ No newline at end of file diff --git a/Whisper/CPU/LargeBuffer.cpp b/Whisper/CPU/LargeBuffer.cpp new file mode 100644 index 0000000..e124686 --- /dev/null +++ b/Whisper/CPU/LargeBuffer.cpp @@ -0,0 +1,34 @@ +#include "stdafx.h" +#include "LargeBuffer.h" +using namespace CpuCompute; + +void LargeBuffer::deallocate() +{ + if( nullptr == pv ) + return; + VirtualFree( pv, 0, MEM_RELEASE ); + pv = nullptr; +} + +HRESULT LargeBuffer::allocate( size_t cb ) +{ + deallocate(); + + pv = VirtualAlloc( nullptr, cb, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE ); + if( nullptr != pv ) + return S_OK; + return HRESULT_FROM_WIN32( GetLastError() ); +} + +HRESULT LargeBuffer::setReadOnly( size_t cb ) +{ + if( nullptr != pv ) + { + DWORD op = 0; + if( VirtualProtect( pv, cb, PAGE_READONLY, &op ) ) + return S_OK; + return HRESULT_FROM_WIN32( GetLastError() ); + } + else + return OLE_E_BLANK; +}
\ No newline at end of file diff --git a/Whisper/CPU/LargeBuffer.h b/Whisper/CPU/LargeBuffer.h new file mode 100644 index 0000000..d9a9a22 --- /dev/null +++ b/Whisper/CPU/LargeBuffer.h @@ -0,0 +1,44 @@ +#pragma once + +namespace CpuCompute +{ + // A large memory buffer allocated with VirtualAlloc kernel API, bypassing the heap. + class LargeBuffer + { + void* pv = nullptr; + public: + LargeBuffer() = default; + LargeBuffer( const LargeBuffer& ) = delete; + LargeBuffer( LargeBuffer&& that ) noexcept + { + pv = that.pv; + that.pv = nullptr; + } + ~LargeBuffer() + { + deallocate(); + } + void operator=( LargeBuffer&& that ) noexcept + { + std::swap( pv, that.pv ); + } + void operator=( const LargeBuffer& that ) = delete; + + // Allocate buffer with specified count of bytes, and read+write memory protection + // The OS kernel guarantees zero-initialization of that memory. + HRESULT allocate( size_t cb ); + + // Change memory protection of the buffer to read only + HRESULT setReadOnly( size_t cb ); + + // Unless the pointer is nullptr, deallocate the buffer + void deallocate(); + + // Pointer to the start of the buffer, aligned by memory page = 4 kilobytes + uint8_t* pointer() const + { + assert( nullptr != pv ); + return (uint8_t*)pv; + } + }; +}
\ No newline at end of file diff --git a/Whisper/CPU/MlContext.h b/Whisper/CPU/MlContext.h new file mode 100644 index 0000000..062a7b8 --- /dev/null +++ b/Whisper/CPU/MlContext.h @@ -0,0 +1,71 @@ +#pragma once +#include "Tensor.h" +#include "ParallelForRunner.h" + +namespace CpuCompute +{ + class MlContext + { + ParallelForRunner pfor; + iMemoryAllocator* allocator = nullptr; + + public: + MlContext( int threads ); + MlContext( const MlContext& ) = delete; + ~MlContext() = default; + + HRESULT setThreadsCount( int threads ) + { + return pfor.setThreadsCount( threads ); + } + + iMemoryAllocator* setAllocator( iMemoryAllocator* alloc ) + { + iMemoryAllocator* const ret = allocator; + allocator = alloc; + return ret; + } + + Tensor createTensor( eDataType type, const std::array<uint32_t, 4>& size ); + Tensor createTensor( eDataType type, std::initializer_list<uint32_t> size ); + + Tensor addRows( const Tensor& d_te, const Tensor& d_pe, const int* tokens, const int n_tokens, const int n_past ); + + Tensor norm( const Tensor& arg ); + + // cur = add( mul( repeat( w, cur ), cur ), repeat( b, cur ) ); + void fmaRepeat( Tensor& cur, const Tensor& w, const Tensor& b ); + + inline void fmaRepeat( Tensor& cur, const TensorPair wb ) + { + fmaRepeat( cur, wb.w, wb.b ); + } + + // Multiply two matrices + Tensor mulMat( const Tensor& a, const Tensor& b ); + + // cur = add( repeat( b, cur ), cur ); cur = scale(cur, scaling) + void addRepeatScale( Tensor& cur, const Tensor& b, float scaling ); + + void addRepeat( Tensor& cur, const Tensor& b ); + + Tensor add( const Tensor& a, const Tensor& b ); + void addInPlace( Tensor& a, const Tensor& b ); + void addRepeatGelu( Tensor& cur, const Tensor& b ); + + // cur = scale(cur, scaling) + void scale( Tensor& cur, float scaling ); + + void diagMaskInf( Tensor& cur, uint32_t n_past ); + + void softMax( Tensor& cur, float inputScale = 1.0f ); + + Tensor copy( const Tensor& a, eDataType type, std::initializer_list<uint32_t> size ); + + HRESULT 
copyImpl( Tensor& result, const Tensor& source ); + + Tensor permute( const Tensor& a, uint8_t axis0, uint8_t axis1, uint8_t axis2, uint8_t axis3 ); + + void copyInPlace( Tensor& dest, const Tensor& a, eDataType type, std::initializer_list<uint32_t> size ); + }; +}
\ No newline at end of file diff --git a/Whisper/CPU/MlContextCpu.cpp b/Whisper/CPU/MlContextCpu.cpp new file mode 100644 index 0000000..e34a823 --- /dev/null +++ b/Whisper/CPU/MlContextCpu.cpp @@ -0,0 +1,597 @@ +#include "stdafx.h" +#include "MlContext.h" +#include "simdUtils.h" +#include "mulMat.h" +using namespace CpuCompute; + +MlContext::MlContext( int threads ) : pfor( threads ) +{ +} + +Tensor MlContext::createTensor( eDataType type, const std::array<uint32_t, 4>& size ) +{ + Tensor res; + check( res.create( type, size, allocator ) ); + return res; +} + +Tensor MlContext::createTensor( eDataType type, std::initializer_list<uint32_t> size ) +{ + Tensor res; + check( res.create( type, size, allocator ) ); + return res; +} + +namespace +{ + inline const uint16_t* getRow16( const Tensor& t, size_t index ) + { + const uint16_t* rsi = t.fp16(); + rsi += index * t.nb[ 1 ]; + return rsi; + } + inline const float* getRow32( const Tensor& t, size_t index ) + { + const float* rsi = t.fp32(); + rsi += index * t.nb[ 1 ]; + return rsi; + } +} + +Tensor MlContext::addRows( const Tensor& d_te, const Tensor& d_pe, const int* tokens, const int n_tokens, const int n_past ) +{ + if( d_te.type() != eDataType::FP16 || d_pe.type() != eDataType::FP32 ) + throw E_INVALIDARG; + if( d_te.ne[ 0 ] != d_pe.ne[ 0 ] ) + throw E_INVALIDARG; + if( n_tokens <= 0 ) + throw E_BOUNDS; + + Tensor res = createTensor( eDataType::FP32, { d_te.ne[ 0 ], (uint32_t)n_tokens } ); + + const size_t inner = (size_t)d_te.ne[ 0 ]; + const size_t outer = (size_t)n_tokens; + float* rdi = res.fp32(); + for( size_t i = 0; i < outer; i++, rdi += inner, tokens++ ) + { + const uint16_t* const source1 = getRow16( d_te, *(const uint32_t*)tokens ); + const float* const source2 = getRow32( d_pe, i + (size_t)n_past ); + addF16to32( rdi, source1, source2, inner ); + } + return res; +} + +namespace +{ + class DispatchHelper3 + { + std::array<uint32_t, 3> ne; + + public: + DispatchHelper3() = default; + DispatchHelper3( 
uint32_t x, uint32_t y, uint32_t z ) + { + assert( x > 0 && y > 0 && z > 0 ); + ne[ 0 ] = x; + ne[ 1 ] = y; + ne[ 2 ] = z; + } + size_t groupsCount() const + { + size_t res = ne[ 0 ]; + res *= ne[ 1 ]; + res *= ne[ 2 ]; + return res; + } + std::array<uint32_t, 3> unpack( size_t idx ) const + { + assert( idx < groupsCount() ); + std::array<uint32_t, 3> res; + res[ 0 ] = (uint32_t)( idx % ne[ 0 ] ); + idx = idx / ne[ 0 ]; + res[ 1 ] = (uint32_t)( idx % ne[ 1 ] ); + res[ 2 ] = (uint32_t)( idx / ne[ 1 ] ); + return res; + } + void next( std::array<uint32_t, 3>& i ) const + { + i[ 0 ]++; + if( i[ 0 ] < ne[ 0 ] ) + return; + i[ 0 ] = 0; + i[ 1 ]++; + if( i[ 1 ] < ne[ 1 ] ) + return; + i[ 1 ] = 0; + i[ 2 ]++; + } + }; + + inline const float* sourceRow( const float* rsi, const std::array<uint32_t, 3>& idx, size_t nb0, size_t nb1, size_t nb2 ) + { + const size_t r0 = idx[ 0 ] * nb0; + const size_t r1 = idx[ 1 ] * nb1; + const size_t r2 = idx[ 2 ] * nb2; + rsi = rsi + r0 + r1 + r2; + return rsi; + } + + struct NormContext : public iComputeRange + { + const float* source; + float* result; + size_t inner; + DispatchHelper3 threads; + std::array<uint32_t, 3> nbInput; + + HRESULT __stdcall compute( size_t i, size_t end ) const override final + { + ALIGNED_SPAN( temp, inner ); + + std::array<uint32_t, 3> idx = threads.unpack( i ); + float* rdi = result + i * inner; + for( ; i < end; i++, rdi += inner, threads.next( idx ) ) + { + const float* rsi = sourceRow( source, idx, nbInput[ 0 ], nbInput[ 1 ], nbInput[ 2 ] ); + norm( rdi, temp, rsi, inner ); + } + return S_OK; + } + }; +} + +Tensor MlContext::norm( const Tensor& arg ) +{ + if( arg.type() != eDataType::FP32 || arg.nb[ 0 ] != 1 ) + throw E_INVALIDARG; + Tensor res = createTensor( eDataType::FP32, arg.ne ); + + NormContext context; + context.source = arg.fp32(); + context.result = res.fp32(); + context.inner = arg.ne[ 0 ]; + context.threads = DispatchHelper3( arg.ne[ 1 ], arg.ne[ 2 ], arg.ne[ 3 ] ); + context.nbInput = { 
arg.nb[ 1 ], arg.nb[ 2 ], arg.nb[ 3 ] }; + + check( pfor.parallelFor( context, context.threads.groupsCount() ) ); + return res; +} + +void MlContext::fmaRepeat( Tensor& cur, const Tensor& w, const Tensor& b ) +{ + if( !( cur.isContinuous() && w.isContinuous() && b.isContinuous() ) ) + throw E_INVALIDARG; + + if( !( cur.type() == eDataType::FP32 && w.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) ) + throw E_INVALIDARG; + + if( !isSameShape( w, b ) ) + throw E_INVALIDARG; + + DispatchHelper3 helper{ cur.ne[ 1 ], cur.ne[ 2 ], cur.ne[ 3 ] }; + std::array<uint32_t, 3> idx = { 0, 0, 0 }; + const size_t countRows = helper.groupsCount(); + + const size_t innerRes = cur.ne[ 0 ]; + const size_t innerPattern = w.ne[ 0 ]; + + float* rdi = cur.fp32(); + for( size_t i = 0; i < countRows; i++, helper.next( idx ), rdi += innerRes ) + { + std::array<uint32_t, 3> idxPattern; + idxPattern[ 0 ] = idx[ 0 ] % w.ne[ 1 ]; + idxPattern[ 1 ] = idx[ 1 ] % w.ne[ 2 ]; + idxPattern[ 2 ] = idx[ 2 ] % w.ne[ 3 ]; + + const float* s1 = sourceRow( w.fp32(), idxPattern, w.nb[ 1 ], w.nb[ 2 ], w.nb[ 3 ] ); + const float* s2 = sourceRow( b.fp32(), idxPattern, b.nb[ 1 ], b.nb[ 2 ], b.nb[ 3 ] ); + fmaRepeatRow( rdi, innerRes, s1, s2, innerPattern ); + } +} + +Tensor MlContext::mulMat( const Tensor& a, const Tensor& b ) +{ + if( !DirectCompute::canMulMat( a, b ) ) + throw E_INVALIDARG; + + std::array<uint32_t, 4> ne{ a.ne[ 1 ], b.ne[ 1 ], a.ne[ 2 ], b.ne[ 3 ] }; + Tensor result = createTensor( eDataType::FP32, ne ); + + check( CpuCompute::mulMat( result, a, b, pfor ) ); + return result; +} + +// cur = add( repeat( b, cur ), cur ); cur = scale(cur, scaling) +void MlContext::addRepeatScale( Tensor& cur, const Tensor& b, float scaling ) +{ + if( !( cur.isContinuous() && b.isContinuous() ) ) + throw E_INVALIDARG; + if( !( cur.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) ) + throw E_INVALIDARG; + + DispatchHelper3 helper{ cur.ne[ 1 ], cur.ne[ 2 ], cur.ne[ 3 ] }; + std::array<uint32_t, 
3> idx = { 0, 0, 0 }; + const size_t countRows = helper.groupsCount(); + + const size_t innerRes = (uint32_t)cur.ne[ 0 ]; + const size_t innerPattern = (uint32_t)b.ne[ 0 ]; + + float* rdi = cur.fp32(); + const __m256 scale = _mm256_set1_ps( scaling ); + for( size_t i = 0; i < countRows; i++, helper.next( idx ), rdi += innerRes ) + { + std::array<uint32_t, 3> idxPattern; + idxPattern[ 0 ] = idx[ 0 ] % (uint32_t)b.ne[ 1 ]; + idxPattern[ 1 ] = idx[ 1 ] % (uint32_t)b.ne[ 2 ]; + idxPattern[ 2 ] = idx[ 2 ] % (uint32_t)b.ne[ 3 ]; + + const float* source = sourceRow( b.fp32(), idxPattern, b.nb[ 1 ], b.nb[ 2 ], b.nb[ 3 ] ); + addRepeatScaleRow( rdi, innerRes, source, innerPattern, scale ); + } +} + +void MlContext::addRepeat( Tensor& cur, const Tensor& b ) +{ + if( !( cur.isContinuous() && b.isContinuous() ) ) + throw E_INVALIDARG; + if( !( cur.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) ) + throw E_INVALIDARG; + + DispatchHelper3 helper{ cur.ne[ 1 ], cur.ne[ 2 ], cur.ne[ 3 ] }; + std::array<uint32_t, 3> idx = { 0, 0, 0 }; + const size_t countRows = helper.groupsCount(); + + const size_t innerRes = (uint32_t)cur.ne[ 0 ]; + const size_t innerPattern = (uint32_t)b.ne[ 0 ]; + + float* rdi = cur.fp32(); + for( size_t i = 0; i < countRows; i++, helper.next( idx ), rdi += innerRes ) + { + std::array<uint32_t, 3> idxPattern; + idxPattern[ 0 ] = idx[ 0 ] % (uint32_t)b.ne[ 1 ]; + idxPattern[ 1 ] = idx[ 1 ] % (uint32_t)b.ne[ 2 ]; + idxPattern[ 2 ] = idx[ 2 ] % (uint32_t)b.ne[ 3 ]; + + const float* source = sourceRow( b.fp32(), idxPattern, b.nb[ 1 ], b.nb[ 2 ], b.nb[ 3 ] ); + addRepeatRow( rdi, innerRes, source, innerPattern ); + } +} + +// cur = scale(cur, scaling) +void MlContext::scale( Tensor& cur, float scaling ) +{ + if( !( cur.isContinuous() && cur.type() == eDataType::FP32 ) ) + throw E_INVALIDARG; + + const size_t len = cur.countElements(); + const __m256 scale = _mm256_set1_ps( scaling ); + scaleRow( cur.fp32(), len, scale ); +} + +void MlContext::diagMaskInf( 
Tensor& cur, uint32_t n_past ) +{ + if( !( cur.isContinuous() && cur.type() == eDataType::FP32 ) ) + throw E_INVALIDARG; + + const size_t n = cur.countRows(); + const size_t nc = cur.ne[ 0 ]; + const size_t nr = cur.ne[ 1 ]; + const size_t nz = n / nr; + + for( size_t k = 0; k < nz; k++ ) + { + for( size_t j = 0; j < nr; j++ ) + { + float* const rdi = cur.fp32() + k * cur.nb[ 2 ] + j * cur.nb[ 1 ]; + // +1 because the original code checked for `if( i > n_past + j )` + // That's why the first index to write is ( n_past + j + 1 ) + const size_t start = n_past + j + 1; + const ptrdiff_t len = (ptrdiff_t)nc - (ptrdiff_t)start; + if( len <= 0 ) + continue; + + // Generates a store string instruction (rep stosd). + // The magic number is negative infinity in FP32: https://www.h-schmidt.net/FloatConverter/IEEE754.html + __stosd( (DWORD*)( rdi + start ), 0xff800000u, (size_t)len ); + } + } +} + +void MlContext::softMax( Tensor& cur, float inputScale ) +{ + if( !( cur.isContinuous() && cur.type() == eDataType::FP32 ) ) + throw E_INVALIDARG; + + struct SoftMaxContext : public iComputeRange + { + float* data; + float inputScale; + size_t length, stride; + + HRESULT __stdcall compute( size_t i, size_t end ) const override final + { + float* rdi = data + stride * i; + for( ; i < end; i++, rdi += stride ) + ::softMax( rdi, length, inputScale ); + return S_OK; + } + }; + + SoftMaxContext context; + context.data = cur.fp32(); + context.inputScale = inputScale; + context.length = cur.ne[ 0 ]; + context.stride = cur.nb[ 1 ]; + + const size_t n = cur.countRows(); + pfor.parallelFor( context, n ); +} + +namespace +{ + template<class R, class S> + __forceinline void copyElement( R* rdi, const S* rsi ) + { + static_assert( std::is_same<R, S>() ); + *rdi = *rsi; + } + template<> + __forceinline void copyElement<float, uint16_t>( float* rdi, const uint16_t* rsi ) + { + __m128i iv = _mm_cvtsi32_si128( *rsi ); + __m128 fv = _mm_cvtph_ps( iv ); + _mm_store_ss( rdi, fv ); + } + template<> + 
__forceinline void copyElement<uint16_t, float>( uint16_t* rdi, const float* rsi ) + { + __m128 fv = _mm_load_ss( rsi ); + __m128i iv = _mm_cvtps_ph( fv, 0 ); + *rdi = (uint16_t)(uint32_t)_mm_cvtsi128_si32( iv ); + } + + template<class R, class S> + __forceinline void copyRow( R* rdi, const S* rsi, size_t length ) + { + static_assert( std::is_same<R, S>() ); + memcpy( rdi, rsi, length * sizeof( R ) ); + } + template<> + __forceinline void copyRow<uint16_t, float>( uint16_t* rdi, const float* rsi, size_t length ) + { + floatsDowncast( rdi, rsi, length ); + } + template<> + __forceinline void copyRow<float, uint16_t>( float* rdi, const uint16_t* rsi, size_t length ) + { + floatsUpcast( rdi, rsi, length ); + } + + template<class R, class S> + static void __declspec( noinline ) copyImpl( R* rdi, const S* rsi, const TensorShape& shape ) + { + const bool continuousRows = shape.nb[ 0 ] == 1; + + for( size_t i03 = 0; i03 < shape.ne[ 3 ]; i03++, rsi += shape.nb[ 3 ] ) + { + const S* source2 = rsi; + for( size_t i02 = 0; i02 < shape.ne[ 2 ]; i02++, source2 += shape.nb[ 2 ] ) + { + const S* source1 = source2; + for( size_t i01 = 0; i01 < shape.ne[ 1 ]; i01++, source1 += shape.nb[ 1 ] ) + { + // Performance optimization here: when the rows are dense, we can copy them much faster with memcpy() + // Or at least with AVX, when we need to convert between numeric types + if( continuousRows ) + { + // This branch is very predictable, same outcome for all loop iterations + copyRow( rdi, source1, shape.ne[ 0 ] ); + rdi += shape.ne[ 0 ]; + } + else + { + const S* source0 = source1; + for( size_t i00 = 0; i00 < shape.ne[ 0 ]; i00++, source0 += shape.nb[ 0 ] ) + { + copyElement( rdi, source0 ); + rdi++; + } + } + } + } + } + } +} + +HRESULT MlContext::copyImpl( Tensor& result, const Tensor& source ) +{ + if( !( result.isContinuous() && ( result.countElements() == source.countElements() ) ) ) + return E_INVALIDARG; + + const eDataType typeResult = result.type(); + const eDataType 
typeSource = source.type(); + if( source.isContinuous() ) + { + const size_t elts = result.countElements(); + if( typeResult == typeSource ) + { + const size_t bytes = elts * elementSize( typeResult ); + memcpy( result.data(), source.data(), bytes ); + return S_OK; + } + if( typeSource == eDataType::FP16 && typeResult == eDataType::FP32 ) + { + floatsUpcast( result.fp32(), source.fp16(), elts ); + return S_OK; + } + if( typeSource == eDataType::FP32 && typeResult == eDataType::FP16 ) + { + floatsDowncast( result.fp16(), source.fp32(), elts ); + return S_OK; + } + return E_UNEXPECTED; + } + else + { + if( typeSource == eDataType::FP16 && typeResult == eDataType::FP16 ) + { + ::copyImpl( result.fp16(), source.fp16(), source ); + return S_OK; + } + if( typeSource == eDataType::FP32 && typeResult == eDataType::FP32 ) + { + ::copyImpl( result.fp32(), source.fp32(), source ); + return S_OK; + } + if( typeSource == eDataType::FP16 && typeResult == eDataType::FP32 ) + { + ::copyImpl( result.fp32(), source.fp16(), source ); + return S_OK; + } + if( typeSource == eDataType::FP32 && typeResult == eDataType::FP16 ) + { + ::copyImpl( result.fp16(), source.fp32(), source ); + return S_OK; + } + return E_UNEXPECTED; + } +} + +Tensor MlContext::copy( const Tensor& a, eDataType type, std::initializer_list<uint32_t> size ) +{ + const size_t dims = size.size(); + if( 0 == dims || dims > 4 ) + throw E_BOUNDS; + + size_t nRequested = 1; + for( size_t i = 0; i < dims; i++ ) + { + uint32_t n = size.begin()[ i ]; + nRequested *= n; + } + if( nRequested != a.countElements() ) + throw E_INVALIDARG; + + if( a.type() == type && a.isContinuous() ) + { + // Same type, and it's dense - no need to move data, equal to reshape + Tensor res{ a }; + for( size_t i = 0; i < dims; i++ ) + res.ne[ i ] = size.begin()[ i ];; + for( size_t i = dims; i < 4; i++ ) + res.ne[ i ] = 1; + res.setDenseStrides(); + return res; + } + else + { + // Need to convert types, and/or transpose the tensor. 
Make another tensor for the output + Tensor res = createTensor( type, size ); + check( copyImpl( res, a ) ); + return res; + } +} + +Tensor MlContext::permute( const Tensor& a, uint8_t axis0, uint8_t axis1, uint8_t axis2, uint8_t axis3 ) +{ + assert( axis0 < 4 ); + assert( axis1 < 4 ); + assert( axis2 < 4 ); + assert( axis3 < 4 ); + + assert( axis0 != axis1 ); + assert( axis0 != axis2 ); + assert( axis0 != axis3 ); + assert( axis1 != axis2 ); + assert( axis1 != axis3 ); + assert( axis2 != axis3 ); + + Tensor res = a; + res.ne[ axis0 ] = a.ne[ 0 ]; + res.ne[ axis1 ] = a.ne[ 1 ]; + res.ne[ axis2 ] = a.ne[ 2 ]; + res.ne[ axis3 ] = a.ne[ 3 ]; + + res.nb[ axis0 ] = a.nb[ 0 ]; + res.nb[ axis1 ] = a.nb[ 1 ]; + res.nb[ axis2 ] = a.nb[ 2 ]; + res.nb[ axis3 ] = a.nb[ 3 ]; + + return res; +} + +void MlContext::copyInPlace( Tensor& dest, const Tensor& a, eDataType type, std::initializer_list<uint32_t> size ) +{ + assert( type == dest.type() ); + + const size_t dims = size.size(); + if( 0 == dims || dims > 4 ) + throw E_BOUNDS; + + size_t nRequested = 1; + for( size_t i = 0; i < dims; i++ ) + { + uint32_t n = size.begin()[ i ]; + nRequested *= n; + } + if( nRequested != a.countElements() || nRequested != dest.countElements() ) + throw E_INVALIDARG; + + // Reshape the destination + for( size_t i = 0; i < dims; i++ ) + dest.ne[ i ] = size.begin()[ i ]; + for( size_t i = dims; i < 4; i++ ) + dest.ne[ i ] = 1; + dest.setDenseStrides(); + + // Copy the data + check( copyImpl( dest, a ) ); +} + +void MlContext::addInPlace( Tensor& a, const Tensor& b ) +{ + if( !( a.isContinuous() && b.isContinuous() && a.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) ) + throw E_NOTIMPL; + + const size_t length = a.countElements(); + addRowInPlace( a.fp32(), b.fp32(), length ); +} + +Tensor MlContext::add( const Tensor& a, const Tensor& b ) +{ + if( !( a.isContinuous() && b.isContinuous() && a.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) ) + throw E_NOTIMPL; + + Tensor res = 
createTensor( eDataType::FP32, a.ne ); + const size_t length = a.countElements(); + addRow( res.fp32(), a.fp32(), b.fp32(), length ); + return res; +} + +void MlContext::addRepeatGelu( Tensor& cur, const Tensor& b ) +{ + if( !( cur.isContinuous() && b.isContinuous() ) ) + throw E_INVALIDARG; + if( !( cur.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) ) + throw E_INVALIDARG; + + DispatchHelper3 helper{ cur.ne[ 1 ], cur.ne[ 2 ], cur.ne[ 3 ] }; + std::array<uint32_t, 3> idx = { 0, 0, 0 }; + const size_t countRows = helper.groupsCount(); + + const size_t innerRes = (uint32_t)cur.ne[ 0 ]; + const size_t innerPattern = (uint32_t)b.ne[ 0 ]; + float* rdi = cur.fp32(); + auto& lookupTables = getLookupTables(); + for( size_t i = 0; i < countRows; i++, helper.next( idx ), rdi += innerRes ) + { + std::array<uint32_t, 3> idxPattern; + idxPattern[ 0 ] = idx[ 0 ] % (uint32_t)b.ne[ 1 ]; + idxPattern[ 1 ] = idx[ 1 ] % (uint32_t)b.ne[ 2 ]; + idxPattern[ 2 ] = idx[ 2 ] % (uint32_t)b.ne[ 3 ]; + + const float* source = sourceRow( b.fp32(), idxPattern, b.nb[ 1 ], b.nb[ 2 ], b.nb[ 3 ] ); + addRepeatGeluRow( rdi, innerRes, source, innerPattern, lookupTables ); + } + return; +}
\ No newline at end of file diff --git a/Whisper/CPU/ParallelForRunner.cpp b/Whisper/CPU/ParallelForRunner.cpp new file mode 100644 index 0000000..7151a23 --- /dev/null +++ b/Whisper/CPU/ParallelForRunner.cpp @@ -0,0 +1,149 @@ +#include "stdafx.h" +#include "ParallelForRunner.h" +using namespace CpuCompute; + +ParallelForRunner::ParallelForRunner( int threads ) : + maxThreads( threads ) +{ + if( maxThreads <= 1 ) + { + threadBuffers.resize( 1 ); + return; + } + + work = CreateThreadpoolWork( &workCallbackStatic, this, nullptr ); + if( nullptr == work ) + throw getLastHr(); + threadBuffers.resize( maxThreads ); +} + +HRESULT ParallelForRunner::setThreadsCount( int threads ) +{ + maxThreads = threads; + if( threads <= 1 ) + { + threadBuffers.resize( 1 ); + return S_OK; + } + + threadBuffers.resize( maxThreads ); + if( nullptr == work ) + { + work = CreateThreadpoolWork( &workCallbackStatic, this, nullptr ); + if( nullptr == work ) + return getLastHr(); + } + return S_OK; +} + +ParallelForRunner::~ParallelForRunner() +{ + if( nullptr != work ) + { + if( S_FALSE == status ) + WaitForThreadpoolWorkCallbacks( work, FALSE ); + CloseThreadpoolWork( work ); + } +} + +namespace +{ + thread_local uint32_t currentThreadIndex = UINT_MAX; +} + +void ParallelForRunner::runBatch( size_t ith ) noexcept +{ + currentThreadIndex = (uint32_t)ith; + const size_t begin = ( ith * countItems ) / countThreads; + const size_t end = ( ( ith + 1 ) * countItems ) / countThreads; + + HRESULT hr = E_UNEXPECTED; + try + { + hr = computeRange->compute( begin, end ); + } + catch( HRESULT code ) + { + hr = code; + } + catch( const std::bad_alloc& ) + { + hr = E_OUTOFMEMORY; + } + catch( const std::exception& ) + { + hr = E_FAIL; + } + currentThreadIndex = UINT_MAX; + if( SUCCEEDED( hr ) ) + return; + InterlockedCompareExchange( &status, hr, S_FALSE ); +} + +void* ParallelForRunner::threadLocalBuffer( size_t cb ) +{ + const uint32_t idx = currentThreadIndex; + if( idx < threadBuffers.size() ) + { + 
ThreadBuffer& tb = threadBuffers[ idx ]; + if( tb.cb >= cb ) + { + // We already have large enough buffer for the current thread + return tb.memory.pointer(); + } + tb.memory.deallocate(); + check( tb.memory.allocate( cb ) ); + tb.cb = cb; + return tb.memory.pointer(); + } + if( idx != UINT_MAX ) + throw E_BOUNDS; + else + { + logError( u8"threadLocalBuffer() method only works from inside a pool callback" ); + throw E_UNEXPECTED; + } +} + +void __stdcall ParallelForRunner::workCallbackStatic( PTP_CALLBACK_INSTANCE Instance, void* pv, PTP_WORK Work ) noexcept +{ + ParallelForRunner& context = *(ParallelForRunner*)pv; + const size_t ith = (uint32_t)( InterlockedIncrement( &context.threadIndex ) ); + context.runBatch( ith ); +} + +HRESULT ParallelForRunner::parallelFor( iComputeRange& compute, size_t length, size_t minBatch ) +{ + if( maxThreads <= 1 ) + { + currentThreadIndex = 0; + const HRESULT hr1 = compute.compute( 0, length ); + currentThreadIndex = UINT_MAX; + return hr1; + } + assert( minBatch > 0 ); + + size_t nth = length / minBatch; + nth = std::min( nth, (size_t)(uint32_t)maxThreads ); + + computeRange = &compute; + countItems = length; + countThreads = nth; + threadIndex = 0; + status = S_FALSE; + + for( size_t i = 1; i < nth; i++ ) + SubmitThreadpoolWork( work ); + runBatch( 0 ); + + if( nth > 1 ) + WaitForThreadpoolWorkCallbacks( work, FALSE ); + + computeRange = nullptr; + const HRESULT hr = status; + status = S_OK; + if( SUCCEEDED( hr ) ) + return S_OK; + + return hr; +}
// Whisper/CPU/ParallelForRunner.h
#pragma once
#include "LargeBuffer.h"

namespace CpuCompute
{
	// Callback interface for the parallel `for`
	__interface iComputeRange
	{
		// The implementation calls this method on multiple thread pool threads in parallel, and aggregates status codes.
		// Each call receives a half-open slice [ begin .. end ) of the complete range passed to parallelFor()
		HRESULT __stdcall compute( size_t begin, size_t end ) const;
	};

	// Similar to ThreadPoolWork in parallelFor.h, optimized to be used as a direct replacement of OpenMP pool.
	// NOTE(review): the class owns a PTP_WORK handle closed in the destructor, yet the copy constructor isn't deleted — don't copy these objects
	class alignas( 64 ) ParallelForRunner
	{
	public:
		// With threads <= 1 everything runs on the calling thread, and no thread pool work object is created
		ParallelForRunner( int threads );
		~ParallelForRunner();

		// Change the maximum count of threads; creates the thread pool work object when needed
		HRESULT setThreadsCount( int threads );

		// Split [ 0 .. length ) into batches, and call compute.compute() for every batch.
		// The count of batches is limited by both length / minBatch, and the maximum count of threads.
		// The calling thread computes the first batch, then waits for the rest of them.
		HRESULT parallelFor( iComputeRange& compute, size_t length, size_t minBatch = 1 );

		// Allocate a temporary buffer for the calling thread.
		// The pointer is guaranteed to be aligned by page size = 4kb
		// Only works from inside the compute() callback, throws otherwise.
		void* threadLocalBuffer( size_t cb );

	private:

		// Maximum count of threads to use
		int maxThreads;
		// Thread pool work object; stays nullptr while the runner is single-threaded
		PTP_WORK work = nullptr;
		// The callback of the parallelFor() call in progress
		iComputeRange* computeRange = nullptr;
		// Total length of the range being computed
		size_t countItems = 0;
		// Count of batches the range is split into
		size_t countThreads = 0;

		// Aligning by cache lines.
		// Avoiding cache line sharing between CPU cores improves performance, despite wasting a few bytes of memory.
		struct alignas( 64 ) ThreadBuffer
		{
			LargeBuffer memory;
			// Size of the allocated buffer, in bytes
			size_t cb = 0;
		};
		// One entry per batch index
		std::vector<ThreadBuffer> threadBuffers;

		// Incremented with an interlocked instruction by the pool callbacks, to assign batch indices
		alignas( 64 ) volatile long threadIndex = 0;
		// S_FALSE while parallelFor() is running and no batch has failed so far; the first failed batch replaces it with the error code
		volatile HRESULT status = S_OK;

		void runBatch( size_t ith ) noexcept;

		static void __stdcall workCallbackStatic( PTP_CALLBACK_INSTANCE Instance, void* pv, PTP_WORK Work ) noexcept;
	};
}
\ No newline at end of file diff --git a/Whisper/CPU/Readme.txt b/Whisper/CPU/Readme.txt new file mode 100644 index 0000000..702813d --- /dev/null +++ b/Whisper/CPU/Readme.txt @@ -0,0 +1 @@ +The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change the BUILD_HYBRID_VERSION macro in stdafx.h.
// Whisper/CPU/Tensor.h
#pragma once
#include "../D3D/enums.h"
#include "../ML/TensorShape.h"
// 1 = new tensors can be allocated with `nullptr` iMemoryAllocator, by allocating memory internally and counting these references
// 0 = memory allocator is mandatory, create() methods will fail with E_POINTER if the allocator is `nullptr`
#define TENSOR_INTERNAL_ALLOC 0

// 1 = expose compatibility API for GGML interop
#define TENSOR_GGML_COMPAT 0

#if TENSOR_GGML_COMPAT
#include "../source/ggml.h"
#endif

namespace CpuCompute
{
	using DirectCompute::TensorShape;
	using DirectCompute::eDataType;

	// Source of memory for the tensors
	__interface iMemoryAllocator
	{
		// Allocate `cb` bytes aligned by `align` bytes; Tensor::create() treats nullptr result as out of memory
		void* allocate( size_t cb, size_t align );
	};
	// NOTE(review): presumably resetArena() releases all blocks allocated so far — confirm with the implementations
	__interface iArenaAllocator : public iMemoryAllocator
	{
		void resetArena();
	};

#if TENSOR_GGML_COMPAT
	class Tensor;
	// A temporary ggml_tensor structure which describes the same data as the source Tensor
	class GgmlTensorView
	{
		ggml_tensor tensor;
	public:
		GgmlTensorView( const Tensor& t );
		operator ggml_tensor* ( ) { return &tensor; }
	};
#endif

	// A functional equivalent of ggml_tensor structure, designed for use from C++
	class Tensor : public TensorShape
	{
		// Pointer to the first element
		void* m_data = nullptr;

		// 0xFF is the placeholder for "no type yet"
		eDataType m_type = (eDataType)0xFF;

#if TENSOR_INTERNAL_ALLOC
		// True when the memory block was allocated internally by this class
		// In this case, this class does reference counting to support cheap copies.
		// False when it's owned by someone else, such as iMemoryAllocator object, or a GGML's tensor
		bool ownsMemory = false;
		void deallocate();
#endif

		// Private constructors for fromData() methods
		Tensor( void* pointer, eDataType type, std::initializer_list<uint32_t> size );
		Tensor( void* pointer, eDataType type, uint32_t length ) noexcept;
	public:
		// Trivial constructors
		Tensor() = default;
#if TENSOR_INTERNAL_ALLOC
		~Tensor()
		{
			deallocate();
		}
#else
		~Tensor() = default;
#endif
		// Copies are shallow: the new tensor points to the same memory as the source.
		// Moving additionally resets the data pointer of the source tensor.
		Tensor( Tensor&& that ) noexcept;
		void operator=( Tensor&& that ) noexcept;
		Tensor( const Tensor& that );
		void operator=( const Tensor& that );

		// Allocate a new tensor
		HRESULT create( eDataType type, const std::array<uint32_t, 4>& sizeElements, iMemoryAllocator* alloc = nullptr );
		// Allocate a new tensor
		HRESULT create( eDataType type, std::initializer_list<uint32_t> sizeElements, iMemoryAllocator* alloc = nullptr );
		// Attach to pre-existing block of memory, interpreting the data as a dense tensor of the specified type and size
		HRESULT attach( void* pointer, eDataType type, std::initializer_list<uint32_t> sizeElements );
		// Attach to pre-existing block of memory, interpret the data as a dense vector of the specified type and length
		static Tensor fromData( void* pointer, eDataType type, uint32_t length );

		eDataType type() const { return m_type; }
		void* data() const { return m_data; }

		// Typed accessors for the data pointer.
		// The element type and the pointer are only verified with assert(), i.e. in debug builds.
		uint16_t* fp16()
		{
			assert( m_type == eDataType::FP16 );
			assert( nullptr != m_data );
			return (uint16_t*)m_data;
		}
		const uint16_t* fp16() const
		{
			assert( m_type == eDataType::FP16 );
			assert( nullptr != m_data );
			return (uint16_t*)m_data;
		}
		float* fp32()
		{
			assert( m_type == eDataType::FP32 );
			assert( nullptr != m_data );
			return (float*)m_data;
		}
		const float* fp32() const
		{
			assert( m_type == eDataType::FP32 );
			assert( nullptr != m_data );
			return (float*)m_data;
		}

		// Make another tensor with the same data pointer, and a dense 3D shape
		Tensor reshape3d( uint32_t ne0, uint32_t ne1, uint32_t ne2 ) const;

		// Replace the element type, without touching the data or the shape
		void setType( eDataType dt )
		{
			m_type = dt;
		}
		// Replace the data pointer, without touching the type or the shape
		void setDataPointer( void* pv )
		{
			m_data = pv;
		}

#if TENSOR_GGML_COMPAT
		// Compatibility with GGML's tensors, for testing and lulz
		Tensor( const ggml_tensor* ggml );
		ggml_tensor ggml() const;

		operator GgmlTensorView() const
		{
			return GgmlTensorView( *this );
		}
#endif
	};

	// A pair of tensors containing weights and biases; apparently, both tensors are of the same shape
	struct TensorPair
	{
		Tensor w, b;
	};
}
\ No newline at end of file diff --git a/Whisper/CPU/TensorCpu.cpp b/Whisper/CPU/TensorCpu.cpp new file mode 100644 index 0000000..dc34464 --- /dev/null +++ b/Whisper/CPU/TensorCpu.cpp @@ -0,0 +1,401 @@ +#include <stdafx.h> +#include <atomic> +#include "Tensor.h" +using namespace CpuCompute; + +#if TENSOR_INTERNAL_ALLOC +namespace +{ + // This structure is immediately before the payload of every tensor which has an internally-allocated memory buffer + class alignas( 32 ) sTensorMemoryHeader + { + std::atomic_ptrdiff_t refCounter; + public: + // Reset the counter to the specified value + void reset( ptrdiff_t rc ) + { + refCounter = rc; + } + // Increment the ref.counter + void increment() + { + refCounter++; + } + // Decrement the ref.counter, and return true if it reached zero as the result + bool decrement() + { + ptrdiff_t val = --refCounter; + assert( val >= 0 ); + return 0 == val; + } + }; + + inline sTensorMemoryHeader* getMemBlockHeader( void* pv ) + { + assert( nullptr != pv ); + uint8_t* pb = (uint8_t*)pv; + static_assert( sizeof( sTensorMemoryHeader ) == 32 ); + return (sTensorMemoryHeader*)( pb - sizeof( sTensorMemoryHeader ) ); + } + + inline void releaseBlock( sTensorMemoryHeader* pointer ) + { + assert( nullptr != pointer ); + _aligned_free( pointer ); + } + + inline void* allocateBlock( size_t cb, ptrdiff_t initialRefCounter = 1 ) + { + cb += sizeof( sTensorMemoryHeader ); + void* pv = _aligned_malloc( cb, 32 ); + if( nullptr == pv ) + return nullptr; + + sTensorMemoryHeader* header = (sTensorMemoryHeader*)pv; + header->reset( initialRefCounter ); + return ( (uint8_t*)pv ) + sizeof( sTensorMemoryHeader ); + } +} + +void Tensor::deallocate() +{ + if( ownsMemory && nullptr != m_data ) + { + sTensorMemoryHeader* const header = getMemBlockHeader( m_data ); + if( header->decrement() ) + { + // This tensor is the last one which had a reference to that block of memory + // Release the memory back to the heap + releaseBlock( header ); + } + } + ownsMemory = 
false; + + TensorShape::setZero(); + m_data = nullptr; + m_type = (eDataType)0xFF; +} +#endif + +Tensor::Tensor( const Tensor& that ) +{ + store( ne, that.sizeVec() ); + store( nb, that.stridesVec() ); + m_data = that.m_data; + m_type = that.m_type; +#if TENSOR_INTERNAL_ALLOC + if( that.ownsMemory && nullptr != m_data ) + { + getMemBlockHeader( m_data )->increment(); + ownsMemory = true; + } + else + ownsMemory = false; +#endif +} + +Tensor::Tensor( Tensor&& that ) noexcept +{ + store( ne, that.sizeVec() ); + store( nb, that.stridesVec() ); + m_data = that.m_data; + m_type = that.m_type; +#if TENSOR_INTERNAL_ALLOC + ownsMemory = that.ownsMemory; + that.ownsMemory = false; +#endif + that.m_data = nullptr; +} + +void Tensor::operator=( const Tensor& that ) +{ + assert( this != &that ); +#if TENSOR_INTERNAL_ALLOC + deallocate(); +#endif + + store( ne, that.sizeVec() ); + store( nb, that.stridesVec() ); + m_data = that.m_data; + m_type = that.m_type; +#if TENSOR_INTERNAL_ALLOC + if( that.ownsMemory && nullptr != m_data ) + { + getMemBlockHeader( m_data )->increment(); + ownsMemory = true; + } + else + ownsMemory = false; +#endif +} + +void Tensor::operator=( Tensor&& that ) noexcept +{ + assert( this != &that ); +#if TENSOR_INTERNAL_ALLOC + deallocate(); +#endif + store( ne, that.sizeVec() ); + store( nb, that.stridesVec() ); + m_data = that.m_data; + m_type = that.m_type; + that.m_data = nullptr; +#if TENSOR_INTERNAL_ALLOC + ownsMemory = that.ownsMemory; + that.ownsMemory = false; +#endif +} + +HRESULT Tensor::create( eDataType type, const std::array<uint32_t, 4>& sizeElements, iMemoryAllocator* alloc ) +{ + const size_t len = (size_t)sizeElements[ 0 ] * sizeElements[ 1 ] * sizeElements[ 2 ] * sizeElements[ 3 ]; + const size_t cbElement = DirectCompute::elementSize( type ); + const size_t cb = len * cbElement; + +#if TENSOR_INTERNAL_ALLOC + deallocate(); +#endif + + store( ne, load( sizeElements ) ); + TensorShape::setDenseStrides(); + this->m_type = type; + + if( 
nullptr != alloc ) + { +#if TENSOR_INTERNAL_ALLOC + ownsMemory = false; +#endif + m_data = alloc->allocate( cb, 32 ); + if( nullptr == m_data ) + return E_OUTOFMEMORY; + return S_OK; + } + else + { +#if TENSOR_INTERNAL_ALLOC + m_data = allocateBlock( cb, 1 ); + if( nullptr == m_data ) + return E_OUTOFMEMORY; + ownsMemory = true; + return S_OK; +#else + return E_POINTER; +#endif + } +} + +namespace +{ + static HRESULT arrayFromList( std::array<uint32_t, 4>& arr, std::initializer_list<uint32_t> list ) + { + const size_t dims = list.size(); + if( dims == 0 || dims > 4 ) + return E_INVALIDARG; + + for( size_t i = 0; i < dims; i++ ) + { + uint32_t u = list.begin()[ i ]; + if( u == 0 ) + return E_INVALIDARG; + arr[ i ] = u; + } + + for( size_t i = dims; i < 4; i++ ) + arr[ i ] = 1; + + return S_OK; + } +} + +HRESULT Tensor::create( eDataType type, std::initializer_list<uint32_t> sizeElements, iMemoryAllocator* alloc ) +{ + std::array<uint32_t, 4> arr; + CHECK( arrayFromList( arr, sizeElements ) ); + + return create( type, arr, alloc ); +} + +Tensor::Tensor( void* pointer, eDataType type, std::initializer_list<uint32_t> size ) +{ + if( nullptr == pointer ) + throw E_POINTER; + check( arrayFromList( ne, size ) ); + TensorShape::setDenseStrides(); + m_data = pointer; + m_type = type; +#if TENSOR_INTERNAL_ALLOC + ownsMemory = false; +#endif +} + +Tensor::Tensor( void* pointer, eDataType type, uint32_t length ) noexcept +{ + // size = [ length, 1, 1, 1 ] + const __m128i one = _mm_set1_epi32( 1 ); + __m128i v = _mm_insert_epi32( one, (int)length, 0 ); + store( ne, v ); + // stride = [ 1, length, length, length ] + v = _mm_shuffle_epi32( v, _MM_SHUFFLE( 0, 0, 0, 1 ) ); + store( nb, v ); + + m_data = pointer; + m_type = type; +#if TENSOR_INTERNAL_ALLOC + ownsMemory = false; +#endif +} + +Tensor Tensor::fromData( void* pointer, eDataType type, uint32_t length ) +{ + HRESULT hr = E_UNEXPECTED; + if( nullptr != pointer ) + { + if( 0 != length ) + return Tensor{ pointer, type, 
length }; + else + hr = E_INVALIDARG; + } + else + hr = E_POINTER; + throw hr; +} + +HRESULT Tensor::attach( void* pointer, eDataType type, std::initializer_list<uint32_t> sizeElements ) +{ + if( nullptr == pointer ) + return E_POINTER; + + std::array<uint32_t, 4> arr; + CHECK( arrayFromList( arr, sizeElements ) ); + +#if TENSOR_INTERNAL_ALLOC + deallocate(); +#endif + store( ne, load( arr ) ); + TensorShape::setDenseStrides(); + + m_data = pointer; + this->m_type = type; +#if TENSOR_INTERNAL_ALLOC + ownsMemory = false; +#endif + return S_OK; +} + +Tensor Tensor::reshape3d( uint32_t ne0, uint32_t ne1, uint32_t ne2 ) const +{ + if( !isContinuous() ) + throw E_NOTIMPL; + if( countElements() != ne0 * ne1 * ne2 ) + throw E_INVALIDARG; + + Tensor res = *this; + res.ne = { ne0, ne1, ne2, 1 }; + res.setDenseStrides(); + return res; +} + +#if TENSOR_GGML_COMPAT +static const __m128i s_maskAlignment16 = _mm_set1_epi64x( 1 ); +static const __m128i s_maskAlignment32 = _mm_set1_epi64x( 3 ); + +bool isAlignedProperly( __m128i r0, __m128i r1, __m128i mask ) +{ + __m128i test = _mm_or_si128( r0, r1 ); + return (bool)_mm_testz_si128( test, mask ); +} + +Tensor::Tensor( const ggml_tensor* ggml ) +{ + store( ne, load16( ggml->ne ) ); + + __m128i r0 = load16( (const int*)&ggml->nb[ 0 ] ); + __m128i r1 = load16( (const int*)&ggml->nb[ 2 ] ); + // Divide from bytes into elements by right-shifting the 64-bit integers in these vectors + switch( ggml->type ) + { + case GGML_TYPE_F16: + assert( isAlignedProperly( r0, r1, s_maskAlignment16 ) ); + r0 = _mm_srli_epi64( r0, 1 ); + r1 = _mm_srli_epi64( r1, 1 ); + m_type = eDataType::FP16; + break; + + case GGML_TYPE_F32: + assert( isAlignedProperly( r0, r1, s_maskAlignment32 ) ); + r0 = _mm_srli_epi64( r0, 2 ); + r1 = _mm_srli_epi64( r1, 2 ); + m_type = eDataType::FP32; + break; + + case GGML_TYPE_I32: + assert( isAlignedProperly( r0, r1, s_maskAlignment32 ) ); + r0 = _mm_srli_epi64( r0, 2 ); + r1 = _mm_srli_epi64( r1, 2 ); + m_type = 
eDataType::U32; + break; + + default: + throw E_INVALIDARG; + } + // downcast uint64_t into uint32_t in a single vector + r0 = _mm_shuffle_epi32( r0, _MM_SHUFFLE( 3, 3, 2, 0 ) ); + r1 = _mm_shuffle_epi32( r1, _MM_SHUFFLE( 2, 0, 3, 3 ) ); + store( nb, _mm_blend_epi16( r0, r1, 0b11110000 ) ); + + m_data = ggml->data; +} + +ggml_tensor Tensor::ggml() const +{ + ggml_tensor res; + memset( &res, 0, sizeof( ggml_tensor ) ); + + const __m128i size = sizeVec(); + store16( res.ne, size ); + + const __m128i one = _mm_set1_epi32( 1 ); + const uint32_t maskOnes = (uint32_t)_mm_movemask_ps( _mm_castsi128_ps( _mm_cmpeq_epi32( size, one ) ) ); + const uint32_t maskNotOnes = maskOnes ^ 0b1111; + unsigned long idx; + if( _BitScanReverse( &idx, maskNotOnes ) ) + res.n_dims = (int)idx + 1; + else + res.n_dims = 0; + + const __m128i strides = stridesVec(); + // Upcast strides from u32 to u64 + const __m128i zero = _mm_setzero_si128(); + __m128i r0 = _mm_unpacklo_epi32( strides, zero ); + __m128i r1 = _mm_unpackhi_epi32( strides, zero ); + // Scale from elements into bytes with left shift vector instructions + switch( m_type ) + { + case eDataType::FP16: + r0 = _mm_slli_epi64( r0, 1 ); + r1 = _mm_slli_epi64( r1, 1 ); + res.type = GGML_TYPE_F16; + break; + case eDataType::FP32: + r0 = _mm_slli_epi64( r0, 2 ); + r1 = _mm_slli_epi64( r1, 2 ); + res.type = GGML_TYPE_F32; + break; + case eDataType::U32: + r0 = _mm_slli_epi64( r0, 2 ); + r1 = _mm_slli_epi64( r1, 2 ); + res.type = GGML_TYPE_I32; + break; + default: + throw OLE_E_BLANK; + } + + store16( &res.nb[ 0 ], r0 ); + store16( &res.nb[ 2 ], r1 ); + + res.data = m_data; + return res; +} + +GgmlTensorView::GgmlTensorView( const Tensor& t ) : tensor( t.ggml() ) {} +#endif
\ No newline at end of file diff --git a/Whisper/CPU/mulMat.cpp b/Whisper/CPU/mulMat.cpp new file mode 100644 index 0000000..1b6ed25 --- /dev/null +++ b/Whisper/CPU/mulMat.cpp @@ -0,0 +1,54 @@ +#include "stdafx.h" +#include "mulMat.h" +#include "mulMatImpl.h" +using namespace CpuCompute; + +namespace +{ + template<uint8_t panelHeightRegs, uint8_t tileWidthFloats> + static HRESULT mulMatImpl( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ) + { + MulMatImpl<panelHeightRegs, tileWidthFloats> impl{ result, a, b, pfor }; + return impl.run( pfor ); + } +} + +HRESULT CpuCompute::mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ) +{ + if( a.type() != eDataType::FP16 ) + return E_NOTIMPL; + if( b.type() != eDataType::FP32 ) + return E_NOTIMPL; + + // return mulMatImpl<1, 1>( result, a, b, pfor ); + + if( b.ne[ 1 ] == 1 ) + { + // Multiplying by a single row + if( a.ne[ 1 ] >= 32 ) + return mulMatImpl<4, 1>( result, a, b, pfor ); + else + return mulMatImpl<1, 1>( result, a, b, pfor ); + } + else if( b.ne[ 1 ] == 2 ) + { + if( a.ne[ 1 ] >= 32 ) + return mulMatImpl<4, 2>( result, a, b, pfor ); + else + return mulMatImpl<1, 2>( result, a, b, pfor ); + } + else if( b.ne[ 1 ] == 3 ) + { + if( a.ne[ 1 ] >= 16 ) + return mulMatImpl<2, 3>( result, a, b, pfor ); + else + return mulMatImpl<1, 3>( result, a, b, pfor ); + } + else + { + if( a.ne[ 1 ] >= 16 ) + return mulMatImpl<2, 4>( result, a, b, pfor ); + else + return mulMatImpl<1, 4>( result, a, b, pfor ); + } +}
// Whisper/CPU/mulMat.h
#pragma once
#include "ParallelForRunner.h"
#include "Tensor.h"

namespace CpuCompute
{
	// Multiply two matrices on CPU, using the supplied runner for the parallelism.
	// Only implemented for FP16 `a` and FP32 `b`, returns E_NOTIMPL otherwise.
	HRESULT mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor );
}

#if TENSOR_GGML_COMPAT
#include "../source/ggml.h"
// Same as above for GGML tensors; the tensors are wrapped without copying the data
inline HRESULT mulMat( ggml_tensor* result, const ggml_tensor* a, const ggml_tensor* b, CpuCompute::ParallelForRunner& pfor )
{
	CpuCompute::Tensor r{ result }, lhs{ a }, rhs{ b };
	return CpuCompute::mulMat( r, lhs, rhs, pfor );
}
#endif
\ No newline at end of file diff --git a/Whisper/CPU/mulMat.kernel.hpp b/Whisper/CPU/mulMat.kernel.hpp new file mode 100644 index 0000000..80b51a0 --- /dev/null +++ b/Whisper/CPU/mulMat.kernel.hpp @@ -0,0 +1,742 @@ +#pragma once +#include <stdint.h> +#include <immintrin.h> +#include "simdUtils.h" + +template<uint8_t panelHeightRegs, uint8_t tileWidthFloats> +struct ResultTile +{ + static constexpr size_t totalRegs = (size_t)(tileWidthFloats)*panelHeightRegs; + std::array<__m256, totalRegs> arr; + + template<size_t idx> + __forceinline void fmadd( __m256 a, __m256 b ) + { + arr[ idx ] = _mm256_fmadd_ps( a, b, arr[ idx ] ); + } + __forceinline void kernel( const std::array<__m256, panelHeightRegs>& panel, const float* rsi, size_t stride ); + __forceinline void kernelPartial( const std::array<__m256, panelHeightRegs>& panel, const float* rsi, size_t stride, size_t rem ) + { + throw E_UNEXPECTED; + } + __forceinline void store( float* rdi, size_t w, size_t h, size_t stride ) const; +}; + +#pragma region setZero functions +__forceinline void setZero( std::array<__m256, 1>& dest ) +{ + dest[ 0 ] = _mm256_setzero_ps(); +} +__forceinline void setZero( std::array<__m256, 2>& dest ) +{ + dest[ 0 ] = _mm256_setzero_ps(); + dest[ 1 ] = _mm256_setzero_ps(); +} +__forceinline void setZero( std::array<__m256, 3>& dest ) +{ + dest[ 0 ] = _mm256_setzero_ps(); + dest[ 1 ] = _mm256_setzero_ps(); + dest[ 2 ] = _mm256_setzero_ps(); +} +__forceinline void setZero( std::array<__m256, 4>& dest ) +{ + dest[ 0 ] = _mm256_setzero_ps(); + dest[ 1 ] = _mm256_setzero_ps(); + dest[ 2 ] = _mm256_setzero_ps(); + dest[ 3 ] = _mm256_setzero_ps(); +} +__forceinline void setZero( std::array<__m256, 6>& dest ) +{ + dest[ 0 ] = _mm256_setzero_ps(); + dest[ 1 ] = _mm256_setzero_ps(); + dest[ 2 ] = _mm256_setzero_ps(); + dest[ 3 ] = _mm256_setzero_ps(); + dest[ 4 ] = _mm256_setzero_ps(); + dest[ 5 ] = _mm256_setzero_ps(); +} +__forceinline void setZero( std::array<__m256, 8>& dest ) +{ + dest[ 0 ] = 
_mm256_setzero_ps(); + dest[ 1 ] = _mm256_setzero_ps(); + dest[ 2 ] = _mm256_setzero_ps(); + dest[ 3 ] = _mm256_setzero_ps(); + dest[ 4 ] = _mm256_setzero_ps(); + dest[ 5 ] = _mm256_setzero_ps(); + dest[ 6 ] = _mm256_setzero_ps(); + dest[ 7 ] = _mm256_setzero_ps(); +} +#pragma endregion + +#pragma region Micro-kernels +__forceinline void ResultTile<1, 1>::kernel( const std::array<__m256, 1>& panel, const float* rsi, size_t stride ) +{ + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); +} +__forceinline void ResultTile<1, 2>::kernel( const std::array<__m256, 1>& panel, const float* rsi, size_t stride ) +{ + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + b = _mm256_broadcast_ss( rsi + stride ); + fmadd<1>( panel[ 0 ], b ); +} +__forceinline void ResultTile<1, 2>::kernelPartial( const std::array<__m256, 1>& panel, const float* rsi, size_t stride, size_t rem ) +{ + assert( 1 == rem ); + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); +} +__forceinline void ResultTile<1, 3>::kernel( const std::array<__m256, 1>& panel, const float* rsi, size_t stride ) +{ + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + b = _mm256_broadcast_ss( rsi + stride ); + fmadd<1>( panel[ 0 ], b ); + b = _mm256_broadcast_ss( rsi + stride * 2 ); + fmadd<2>( panel[ 0 ], b ); +} +__forceinline void ResultTile<1, 3>::kernelPartial( const std::array<__m256, 1>& panel, const float* rsi, size_t stride, size_t rem ) +{ + assert( rem > 0 && rem < 3 ); + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + if( rem > 1 ) + { + b = _mm256_broadcast_ss( rsi + stride ); + fmadd<1>( panel[ 0 ], b ); + } +} + +__forceinline void ResultTile<1, 4>::kernel( const std::array<__m256, 1>& panel, const float* rsi, size_t stride ) +{ + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + b = _mm256_broadcast_ss( rsi + stride ); + fmadd<1>( panel[ 0 ], b ); + b = _mm256_broadcast_ss( rsi + stride * 2 ); + 
fmadd<2>( panel[ 0 ], b ); + b = _mm256_broadcast_ss( rsi + stride * 3 ); + fmadd<3>( panel[ 0 ], b ); +} +__forceinline void ResultTile<1, 4>::kernelPartial( const std::array<__m256, 1>& panel, const float* rsi, size_t stride, size_t rem ) +{ + assert( rem > 0 && rem < 4 ); + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + + switch( rem ) + { + case 3: + b = _mm256_broadcast_ss( rsi + stride * 2 ); + fmadd<2>( panel[ 0 ], b ); + case 2: + b = _mm256_broadcast_ss( rsi + stride ); + fmadd<1>( panel[ 0 ], b ); + } +} +__forceinline void ResultTile<4, 1>::kernel( const std::array<__m256, 4>& panel, const float* rsi, size_t stride ) +{ + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + fmadd<1>( panel[ 1 ], b ); + fmadd<2>( panel[ 2 ], b ); + fmadd<3>( panel[ 3 ], b ); +} +__forceinline void ResultTile<2, 4>::kernel( const std::array<__m256, 2>& panel, const float* rsi, size_t stride ) +{ + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + fmadd<1>( panel[ 1 ], b ); + + b = _mm256_broadcast_ss( rsi + stride ); + fmadd<2>( panel[ 0 ], b ); + fmadd<3>( panel[ 1 ], b ); + + b = _mm256_broadcast_ss( rsi + stride * 2 ); + fmadd<4>( panel[ 0 ], b ); + fmadd<5>( panel[ 1 ], b ); + + b = _mm256_broadcast_ss( rsi + stride * 3 ); + fmadd<6>( panel[ 0 ], b ); + fmadd<7>( panel[ 1 ], b ); +} + +__forceinline void ResultTile<2, 4>::kernelPartial( const std::array<__m256, 2>& panel, const float* rsi, size_t stride, size_t rem ) +{ + assert( rem > 0 && rem < 4 ); + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + fmadd<1>( panel[ 1 ], b ); + + switch( rem ) + { + case 3: + b = _mm256_broadcast_ss( rsi + stride * 2 ); + fmadd<4>( panel[ 0 ], b ); + fmadd<5>( panel[ 1 ], b ); + case 2: + b = _mm256_broadcast_ss( rsi + stride ); + fmadd<2>( panel[ 0 ], b ); + fmadd<3>( panel[ 1 ], b ); + } +} + +__forceinline void ResultTile<2, 3>::kernel( const std::array<__m256, 2>& panel, const float* rsi, size_t 
stride ) +{ + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + fmadd<1>( panel[ 1 ], b ); + + b = _mm256_broadcast_ss( rsi + stride ); + fmadd<2>( panel[ 0 ], b ); + fmadd<3>( panel[ 1 ], b ); + + b = _mm256_broadcast_ss( rsi + stride * 2 ); + fmadd<4>( panel[ 0 ], b ); + fmadd<5>( panel[ 1 ], b ); +} +__forceinline void ResultTile<2, 3>::kernelPartial( const std::array<__m256, 2>& panel, const float* rsi, size_t stride, size_t rem ) +{ + assert( rem > 0 && rem < 3 ); + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + fmadd<1>( panel[ 1 ], b ); + if( rem > 1 ) + { + b = _mm256_broadcast_ss( rsi + stride ); + fmadd<2>( panel[ 0 ], b ); + fmadd<3>( panel[ 1 ], b ); + } +} + +__forceinline void ResultTile<4, 2>::kernel( const std::array<__m256, 4>& panel, const float* rsi, size_t stride ) +{ + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + fmadd<1>( panel[ 1 ], b ); + fmadd<2>( panel[ 2 ], b ); + fmadd<3>( panel[ 3 ], b ); + + b = _mm256_broadcast_ss( rsi + stride ); + fmadd<4>( panel[ 0 ], b ); + fmadd<5>( panel[ 1 ], b ); + fmadd<6>( panel[ 2 ], b ); + fmadd<7>( panel[ 3 ], b ); +} +__forceinline void ResultTile<4, 2>::kernelPartial( const std::array<__m256, 4>& panel, const float* rsi, size_t stride, size_t rem ) +{ + assert( 1 == rem ); + __m256 b = _mm256_broadcast_ss( rsi ); + fmadd<0>( panel[ 0 ], b ); + fmadd<1>( panel[ 1 ], b ); + fmadd<2>( panel[ 2 ], b ); + fmadd<3>( panel[ 3 ], b ); +} +#pragma endregion + +#pragma region Loads +// This function should compile into a single `vcvtph2ps` instruction, with memory operand +__forceinline __m256 loadUpcasted( const uint16_t* rsi ) +{ + __m128i i = _mm_load_si128( ( const __m128i* )rsi ); + return _mm256_cvtph_ps( i ); +} + +// We loading the panel from the temporary buffer. +// For this reason, we don't need to handle remainders, the code which made the buffer wrote zeros into the remainder elements +// We can even use aligned load instructions. 
// loadPanel overloads: fill `dest` with 1, 2, 3 or 4 AVX vectors read from consecutive groups of 8 uint16_t values at `rsi`.
// Each group goes through loadUpcasted(), so register #k comes from rsi[ 8 * k .. 8 * k + 7 ].
__forceinline void loadPanel( const uint16_t* rsi, std::array<__m256, 1>& dest )
{
	dest[ 0 ] = loadUpcasted( rsi );
}
__forceinline void loadPanel( const uint16_t* rsi, std::array<__m256, 2>& dest )
{
	dest[ 0 ] = loadUpcasted( rsi );
	dest[ 1 ] = loadUpcasted( rsi + 8 );
}
__forceinline void loadPanel( const uint16_t* rsi, std::array<__m256, 3>& dest )
{
	dest[ 0 ] = loadUpcasted( rsi );
	dest[ 1 ] = loadUpcasted( rsi + 8 );
	dest[ 2 ] = loadUpcasted( rsi + 8 * 2 );
}
__forceinline void loadPanel( const uint16_t* rsi, std::array<__m256, 4>& dest )
{
	dest[ 0 ] = loadUpcasted( rsi );
	dest[ 1 ] = loadUpcasted( rsi + 8 );
	dest[ 2 ] = loadUpcasted( rsi + 8 * 2 );
	dest[ 3 ] = loadUpcasted( rsi + 8 * 3 );
}
#pragma endregion

#pragma region Stores
// All store() methods in this region share the same parameters:
//   rdi    - destination of the first output row
//   w      - count of floats to write in every row; a row is made of 8-float vectors, the last one may be partial
//   h      - count of rows to write
//   stride - distance between consecutive output rows, in floats (row #r starts at rdi + stride * r)
// Complete vectors are written with unaligned stores, a partial vector is written with a masked store.
// The mask comes from loadTailMaskInt(), which is not defined in this part of the file;
// presumably it enables the first `w` ( or `w % 8` ) lanes - TODO confirm against its definition.

// Store a tile of 1 vector by 1 row: arr[ 0 ] is the only accumulator.
__forceinline void ResultTile<1, 1>::store( float* rdi, size_t w, size_t h, size_t stride ) const
{
	assert( h == 1 && w > 0 && w <= 8 );
	if( w == 8 )
		_mm256_storeu_ps( rdi, arr[ 0 ] );
	else
	{
		const __m256i mask = loadTailMaskInt( w );
		_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
	}
}

// Store a tile of 1 vector by up to 2 rows: arr[ r ] holds row #r.
// The switch statements fall through on purpose, so rows are written from the last one down to row 0.
__forceinline void ResultTile<1, 2>::store( float* rdi, size_t w, size_t h, size_t stride ) const
{
	assert( h > 0 && w > 0 && h <= 2 && w <= 8 );
	if( w == 8 )
	{
		switch( h )
		{
		case 2:
			_mm256_storeu_ps( rdi + stride, arr[ 1 ] );
		case 1:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
		}
	}
	else
	{
		const __m256i mask = loadTailMaskInt( w );
		switch( h )
		{
		case 2:
			_mm256_maskstore_ps( rdi + stride, mask, arr[ 1 ] );
		case 1:
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
		}
	}
}

// Store a tile of 1 vector by up to 3 rows: arr[ r ] holds row #r.
// Same intentional fall-through as the <1, 2> version.
__forceinline void ResultTile<1, 3>::store( float* rdi, size_t w, size_t h, size_t stride ) const
{
	assert( h > 0 && w > 0 && h <= 3 && w <= 8 );
	if( w == 8 )
	{
		switch( h )
		{
		case 3:
			_mm256_storeu_ps( rdi + stride * 2, arr[ 2 ] );
		case 2:
			_mm256_storeu_ps( rdi + stride, arr[ 1 ] );
		case 1:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
		}
	}
	else
	{
		const __m256i mask = loadTailMaskInt( w );
		switch( h )
		{
		case 3:
			_mm256_maskstore_ps( rdi + stride * 2, mask, arr[ 2 ] );
		case 2:
			_mm256_maskstore_ps( rdi + stride, mask, arr[ 1 ] );
		case 1:
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
		}
	}
}

// Store a tile of 1 vector by up to 4 rows: arr[ r ] holds row #r.
// Same intentional fall-through as the <1, 2> version.
__forceinline void ResultTile<1, 4>::store( float* rdi, size_t w, size_t h, size_t stride ) const
{
	assert( h > 0 && w > 0 && h <= 4 && w <= 8 );

	if( w == 8 )
	{
		switch( h )
		{
		case 4:
			_mm256_storeu_ps( rdi + stride * 3, arr[ 3 ] );
		case 3:
			_mm256_storeu_ps( rdi + stride * 2, arr[ 2 ] );
		case 2:
			_mm256_storeu_ps( rdi + stride, arr[ 1 ] );
		case 1:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
		}
	}
	else
	{
		const __m256i mask = loadTailMaskInt( w );
		switch( h )
		{
		case 4:
			_mm256_maskstore_ps( rdi + stride * 3, mask, arr[ 3 ] );
		case 3:
			_mm256_maskstore_ps( rdi + stride * 2, mask, arr[ 2 ] );
		case 2:
			_mm256_maskstore_ps( rdi + stride, mask, arr[ 1 ] );
		case 1:
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
		}
	}
}

// Store a tile of up to 4 vectors by 1 row: arr[ v ] holds floats [ 8 * v .. 8 * v + 7 ] of the row.
// `stride` is not used, there is only one row.
// For w < 32 the switch key packs ( count of complete vectors ) * 2 + ( 1 when there is a partial vector ).
// Key 0 means w == 0 and keys above 7 mean w > 32; both violate the assert and end up in `default`, which throws the HRESULT.
__forceinline void ResultTile<4, 1>::store( float* rdi, size_t w, size_t h, size_t stride ) const
{
	assert( h == 1 && w > 0 && w <= 32 );
	if( w == 32 )
	{
		// 4 complete vectors, this branch is very likely to be taken
		_mm256_storeu_ps( rdi, arr[ 0 ] );
		_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
		_mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
		_mm256_storeu_ps( rdi + 8 * 3, arr[ 3 ] );
	}
	else
	{
		const size_t rem = w % 8;
		// `rem` may be 0 here; in that case the mask is computed but none of the taken cases uses it
		const __m256i mask = loadTailMaskInt<false>( rem );
		const size_t completeVectors = w / 8;
		const size_t key = ( completeVectors << 1 ) | ( ( 0 == rem ) ? 0 : 1 );
		switch( key )
		{
		case 1: // 0 complete vectors + remainder
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
			break;
		case 2: // 1 complete vector
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			break;
		case 3: // 1 complete vector + remainder
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
			break;
		case 4: // 2 complete vectors
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			break;
		case 5: // 2 complete vectors + remainder
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			_mm256_maskstore_ps( rdi + 8 * 2, mask, arr[ 2 ] );
			break;
		case 6: // 3 complete vectors
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			_mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
			break;
		case 7: // 3 complete vectors + remainder
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			_mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
			_mm256_maskstore_ps( rdi + 8 * 3, mask, arr[ 3 ] );
			break;
		default:
			throw E_UNEXPECTED;
		}
	}
}
// Store a tile of up to 4 vectors by up to 2 rows.
// Accumulators arr[ 0..3 ] hold the first row, arr[ 4..7 ] hold the second row; the second row is written at rdi + stride.
// Only h == 2 selects the two-row paths, any other value of `h` is handled as a single row.
__forceinline void ResultTile<4, 2>::store( float* rdi, size_t w, size_t h, size_t stride ) const
{
	assert( h > 0 && w > 0 && h <= 2 && w <= 32 );
	const bool twoRows = h == 2;
	float* const rdi1 = rdi + stride;
	if( w == 32 )
	{
		_mm256_storeu_ps( rdi, arr[ 0 ] );
		_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
		_mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
		_mm256_storeu_ps( rdi + 8 * 3, arr[ 3 ] );

		if( twoRows )
		{
			_mm256_storeu_ps( rdi1, arr[ 4 ] );
			_mm256_storeu_ps( rdi1 + 8, arr[ 5 ] );
			_mm256_storeu_ps( rdi1 + 8 * 2, arr[ 6 ] );
			_mm256_storeu_ps( rdi1 + 8 * 3, arr[ 7 ] );
		}
	}
	else
	{
		const size_t rem = w % 8;
		const __m256i mask = loadTailMaskInt<false>( rem );
		const size_t completeVectors = w / 8;
		// Lowest bit: remainder
		// Next bit: set when storing 2 rows
		// Next 2 bits: count of complete vectors in X direction, [ 0..3 ]
		// Keys 0 and 2 correspond to w == 0, keys above 15 to w > 32; they are not listed below and reach `default`
		const size_t key = ( completeVectors << 2 ) | ( ( 0 == rem ) ? 0 : 1 ) | ( twoRows ? 2 : 0 );
		switch( key )
		{
		case 1: // 0 complete vectors + remainder, 1 row
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
			break;
		case 3: // 0 complete vectors + remainder, 2 rows
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
			_mm256_maskstore_ps( rdi1, mask, arr[ 4 ] );
			break;
		case 4: // 1 complete vector, 1 row
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			break;
		case 5: // 1 complete vector + remainder, 1 row
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
			break;
		case 6: // 1 complete vector, 2 rows
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi1, arr[ 4 ] );
			break;
		case 7: // 1 complete vector + remainder, 2 rows
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );

			_mm256_storeu_ps( rdi1, arr[ 4 ] );
			_mm256_maskstore_ps( rdi1 + 8, mask, arr[ 5 ] );
			break;
		case 8: // 2 complete vectors, 1 row
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			break;
		case 9: // 2 complete vectors + remainder, 1 row
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			_mm256_maskstore_ps( rdi + 8 * 2, mask, arr[ 2 ] );
			break;
		case 10: // 2 complete vectors, 2 rows
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );

			_mm256_storeu_ps( rdi1, arr[ 4 ] );
			_mm256_storeu_ps( rdi1 + 8, arr[ 5 ] );
			break;
		case 11: // 2 complete vectors + remainder, 2 rows
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			_mm256_maskstore_ps( rdi + 8 * 2, mask, arr[ 2 ] );

			_mm256_storeu_ps( rdi1, arr[ 4 ] );
			_mm256_storeu_ps( rdi1 + 8, arr[ 5 ] );
			_mm256_maskstore_ps( rdi1 + 8 * 2, mask, arr[ 6 ] );
			break;
		case 12: // 3 complete vectors, 1 row
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			_mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
			break;
		case 13: // 3 complete vectors + remainder, 1 row
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			_mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
			_mm256_maskstore_ps( rdi + 8 * 3, mask, arr[ 3 ] );
			break;
		case 14: // 3 complete vectors, 2 rows
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			_mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );

			_mm256_storeu_ps( rdi1, arr[ 4 ] );
			_mm256_storeu_ps( rdi1 + 8, arr[ 5 ] );
			_mm256_storeu_ps( rdi1 + 8 * 2, arr[ 6 ] );
			break;
		case 15: // 3 complete vectors + remainder, 2 rows
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
			_mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
			_mm256_maskstore_ps( rdi + 8 * 3, mask, arr[ 3 ] );

			_mm256_storeu_ps( rdi1, arr[ 4 ] );
			_mm256_storeu_ps( rdi1 + 8, arr[ 5 ] );
			_mm256_storeu_ps( rdi1 + 8 * 2, arr[ 6 ] );
			_mm256_maskstore_ps( rdi1 + 8 * 3, mask, arr[ 7 ] );
			break;
		default:
			throw E_UNEXPECTED;
		}
	}
}

// Store a tile of up to 2 vectors by up to 4 rows.
// Row #r lives in arr[ r * 2 ] ( floats 0..7 ) and arr[ r * 2 + 1 ] ( floats 8..15 ), and is written at rdi + stride * r.
__forceinline void ResultTile<2, 4>::store( float* rdi, size_t w, size_t h, size_t stride ) const
{
	assert( h > 0 && w > 0 && h <= 4 && w <= 16 );
	// From now on `h` is the zero-based index of the last row to store
	h--;
	float* const rdi1 = rdi + stride;
	float* const rdi2 = rdi + stride * 2;
	float* const rdi3 = rdi + stride * 3;

	if( w == 16 )
	{
		// Intentional fall-through: start at the last row and continue down to row 0
		switch( h )
		{
		case 3:
			_mm256_storeu_ps( rdi3, arr[ 6 ] );
			_mm256_storeu_ps( rdi3 + 8, arr[ 7 ] );
		case 2:
			_mm256_storeu_ps( rdi2, arr[ 4 ] );
			_mm256_storeu_ps( rdi2 + 8, arr[ 5 ] );
		case 1:
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			_mm256_storeu_ps( rdi1 + 8, arr[ 3 ] );
		case 0:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
		}
	}
	else
	{
		const size_t rem = w % 8;
		const __m256i mask = loadTailMaskInt<false>( rem );
		// 0 for partial first vector, 1 for exactly 1 complete vector, 2 for 1 complete vector with remainder
		const size_t partialCase = ( w < 8 ) ? 0 : ( ( w == 8 ) ? 1 : 2 );
		// Merge into a single integer for the switch statement
		// The "h = N" comments below count rows from 1, i.e. they refer to the value of `h` before the decrement
		const size_t key = partialCase + h * 3;

		switch( key )
		{
			// h = 1
		case 0:
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
			break;
		case 1:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			break;
		case 2:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
			break;
			// h = 2
		case 3:
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
			_mm256_maskstore_ps( rdi1, mask, arr[ 2 ] );
			break;
		case 4:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			break;
		case 5:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			_mm256_maskstore_ps( rdi1 + 8, mask, arr[ 3 ] );
			break;
			// h = 3
		case 6:
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
			_mm256_maskstore_ps( rdi1, mask, arr[ 2 ] );
			_mm256_maskstore_ps( rdi2, mask, arr[ 4 ] );
			break;
		case 7:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			_mm256_storeu_ps( rdi2, arr[ 4 ] );
			break;
		case 8:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			_mm256_maskstore_ps( rdi1 + 8, mask, arr[ 3 ] );
			_mm256_storeu_ps( rdi2, arr[ 4 ] );
			_mm256_maskstore_ps( rdi2 + 8, mask, arr[ 5 ] );
			break;
			// h = 4
		case 9:
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
			_mm256_maskstore_ps( rdi1, mask, arr[ 2 ] );
			_mm256_maskstore_ps( rdi2, mask, arr[ 4 ] );
			_mm256_maskstore_ps( rdi3, mask, arr[ 6 ] );
			break;
		case 10:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			_mm256_storeu_ps( rdi2, arr[ 4 ] );
			_mm256_storeu_ps( rdi3, arr[ 6 ] );
			break;
		case 11:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			_mm256_maskstore_ps( rdi1 + 8, mask, arr[ 3 ] );
			_mm256_storeu_ps( rdi2, arr[ 4 ] );
			_mm256_maskstore_ps( rdi2 + 8, mask, arr[ 5 ] );
			_mm256_storeu_ps( rdi3, arr[ 6 ] );
			_mm256_maskstore_ps( rdi3 + 8, mask, arr[ 7 ] );
			break;
		default:
			throw E_UNEXPECTED;
		}
	}
}

// Store a tile of up to 2 vectors by up to 3 rows.
// Same accumulator layout as the <2, 4> version: row #r lives in arr[ r * 2 ] and arr[ r * 2 + 1 ].
__forceinline void ResultTile<2, 3>::store( float* rdi, size_t w, size_t h, size_t stride ) const
{
	assert( h > 0 && w > 0 && h <= 3 && w <= 16 );
	float* const rdi1 = rdi + stride;
	float* const rdi2 = rdi + stride * 2;
	// From now on `h` is the zero-based index of the last row to store
	h--;

	if( w == 16 )
	{
		// Intentional fall-through: start at the last row and continue down to row 0
		switch( h )
		{
		case 2:
			_mm256_storeu_ps( rdi2, arr[ 4 ] );
			_mm256_storeu_ps( rdi2 + 8, arr[ 5 ] );
		case 1:
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			_mm256_storeu_ps( rdi1 + 8, arr[ 3 ] );
		case 0:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi + 8, arr[ 1 ] );
		}
	}
	else
	{
		const size_t rem = w % 8;
		const __m256i mask = loadTailMaskInt<false>( rem );
		// 0 for partial first vector, 1 for exactly 1 complete vector, 2 for 1 complete vector with remainder
		const size_t partialCase = ( w < 8 ) ? 0 : ( ( w == 8 ) ? 1 : 2 );
		// Merge into a single integer for the switch statement
		// The "h = N" comments below count rows from 1, i.e. they refer to the value of `h` before the decrement
		const size_t key = partialCase + h * 3;

		switch( key )
		{
			// h = 1
		case 0:
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
			break;
		case 1:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			break;
		case 2:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
			break;
			// h = 2
		case 3:
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
			_mm256_maskstore_ps( rdi1, mask, arr[ 2 ] );
			break;
		case 4:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			break;
		case 5:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			_mm256_maskstore_ps( rdi1 + 8, mask, arr[ 3 ] );
			break;
			// h = 3
		case 6:
			_mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
			_mm256_maskstore_ps( rdi1, mask, arr[ 2 ] );
			_mm256_maskstore_ps( rdi2, mask, arr[ 4 ] );
			break;
		case 7:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			_mm256_storeu_ps( rdi2, arr[ 4 ] );
			break;
		case 8:
			_mm256_storeu_ps( rdi, arr[ 0 ] );
			_mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
			_mm256_storeu_ps( rdi1, arr[ 2 ] );
			_mm256_maskstore_ps( rdi1 + 8, mask, arr[ 3 ] );
			_mm256_storeu_ps( rdi2, arr[ 4 ] );
			_mm256_maskstore_ps( rdi2 + 8, mask, arr[ 5 ] );
			break;
		default:
			throw E_UNEXPECTED;
		}
	}
}
#pragma endregion
\ No newline at end of file diff --git a/Whisper/CPU/mulMatImpl.avx2.cpp b/Whisper/CPU/mulMatImpl.avx2.cpp new file mode 100644 index 0000000..b15ae63 --- /dev/null +++ b/Whisper/CPU/mulMatImpl.avx2.cpp @@ -0,0 +1,362 @@ +#include "stdafx.h" +#include "mulMatImpl.h" +#include <immintrin.h> +#include "mulMatUtils.hpp" +using namespace CpuCompute; + +namespace +{ + constexpr size_t prefetchBytes = 96; + constexpr int prefetchHint = _MM_HINT_T0; + + constexpr size_t maskAlign16 = ~(size_t)15; + + __forceinline __m256i load( const void* rsi ) + { + return _mm256_loadu_si256( ( const __m256i* )rsi ); + } + +#define TRANSPOSE_8X16() \ + \ + __m256i t0 = _mm256_unpacklo_epi16( r0, r1 ); \ + __m256i t1 = _mm256_unpackhi_epi16( r0, r1 ); \ + __m256i t2 = _mm256_unpacklo_epi16( r2, r3 ); \ + __m256i t3 = _mm256_unpackhi_epi16( r2, r3 ); \ + __m256i t4 = _mm256_unpacklo_epi16( r4, r5 ); \ + __m256i t5 = _mm256_unpackhi_epi16( r4, r5 ); \ + __m256i t6 = _mm256_unpacklo_epi16( r6, r7 ); \ + __m256i t7 = _mm256_unpackhi_epi16( r6, r7 ); \ + \ + r0 = _mm256_unpacklo_epi32( t0, t2 ); \ + r1 = _mm256_unpackhi_epi32( t0, t2 ); \ + r2 = _mm256_unpacklo_epi32( t1, t3 ); \ + r3 = _mm256_unpackhi_epi32( t1, t3 ); \ + r4 = _mm256_unpacklo_epi32( t4, t6 ); \ + r5 = _mm256_unpackhi_epi32( t4, t6 ); \ + r6 = _mm256_unpacklo_epi32( t5, t7 ); \ + r7 = _mm256_unpackhi_epi32( t5, t7 ); \ + \ + t0 = _mm256_unpacklo_epi64( r0, r4 ); \ + t1 = _mm256_unpackhi_epi64( r0, r4 ); \ + t2 = _mm256_unpacklo_epi64( r1, r5 ); \ + t3 = _mm256_unpackhi_epi64( r1, r5 ); \ + t4 = _mm256_unpacklo_epi64( r2, r6 ); \ + t5 = _mm256_unpackhi_epi64( r2, r6 ); \ + t6 = _mm256_unpacklo_epi64( r3, r7 ); \ + t7 = _mm256_unpackhi_epi64( r3, r7 ) + + __forceinline void storeLow( void* rdi, __m256i v ) + { + __m128i i = _mm256_castsi256_si128( v ); + _mm_store_si128( ( __m128i* )rdi, i ); + } + +#define STORE_8X16_LOW() \ + storeLow( rdi, t0 ); \ + storeLow( rdi + destStride, t1 ); \ + storeLow( rdi + destStride * 2, t2 ); 
\ + rdi += destStride * 8; \ + storeLow( rdiMid, t3 ); \ + storeLow( rdiMid + destStride, t4 ); \ + storeLow( rdiMid + destStride * 2, t5 ); \ + rdiMid += destStride * 8; \ + storeLow( rdiLast, t6 ); \ + storeLow( rdiLast + destStride, t7 ); \ + rdiLast += destStride * 8 + + __forceinline void storeHigh( void* rdi, __m256i v ) + { + __m128i i = _mm256_extracti128_si256( v, 1 ); + _mm_store_si128( ( __m128i* )rdi, i ); + } + +#define STORE_8X16_HIGH() \ + storeHigh( rdi, t0 ); \ + storeHigh( rdi + destStride, t1 ); \ + storeHigh( rdi + destStride * 2, t2 ); \ + rdi += destStride * 8; \ + storeHigh( rdiMid, t3 ); \ + storeHigh( rdiMid + destStride, t4 ); \ + storeHigh( rdiMid + destStride * 2, t5 ); \ + rdiMid += destStride * 8; \ + storeHigh( rdiLast, t6 ); \ + storeHigh( rdiLast + destStride, t7 ); \ + rdiLast += destStride * 8 + + __forceinline void prefetch( const uint8_t* p ) + { + _mm_prefetch( (const char*)p, prefetchHint ); + } + + __forceinline void transpose8Avx2( uint16_t* rdiWords, size_t w, const uint16_t* rsiWords, size_t sourceStride, size_t destStride ) + { + assert( 0 == ( (size_t)rdiWords ) % 16 ); + assert( 0 == destStride % 8 ); + assert( w <= sourceStride ); + + // Scale strides to bytes, and cast the pointers + sourceStride *= 2; + destStride *= 2; + uint8_t* rdi = (uint8_t*)rdiWords; + const uint8_t* rsi = (const uint8_t*)rsiWords; + + const uint8_t* const rsiEndAligned = rsi + ( w & maskAlign16 ) * 2; + const uint8_t* const rsiEnd = rsi + w * 2; + const uint8_t* rsiMid = rsi + sourceStride * 3; + const uint8_t* rsiLast = rsi + sourceStride * 6; + uint8_t* rdiMid = rdi + destStride * 3; + uint8_t* rdiLast = rdi + destStride * 6; + + while( rsi < rsiEndAligned ) + { + // Load 16x8 block into 8 registers + __m256i r0 = load( rsi ); + __m256i r1 = load( rsi + sourceStride ); + __m256i r2 = load( rsi + sourceStride * 2 ); + rsi += 32; + __m256i r3 = load( rsiMid ); + __m256i r4 = load( rsiMid + sourceStride ); + __m256i r5 = load( rsiMid + 
sourceStride * 2 ); + rsiMid += 32; + __m256i r6 = load( rsiLast ); + __m256i r7 = load( rsiLast + sourceStride ); + rsiLast += 32; + + // Transpose FP16 values in registers + TRANSPOSE_8X16(); + + // Store + STORE_8X16_LOW(); + STORE_8X16_HIGH(); + + if constexpr( prefetchBytes > 0 ) + { + if( rsi + prefetchBytes < rsiEnd ) + { + prefetch( rsi + prefetchBytes ); + prefetch( rsi + sourceStride + prefetchBytes ); + prefetch( rsi + sourceStride * 2 + prefetchBytes ); + prefetch( rsiMid + prefetchBytes ); + prefetch( rsiMid + sourceStride + prefetchBytes ); + prefetch( rsiMid + sourceStride * 2 + prefetchBytes ); + prefetch( rsiLast + prefetchBytes ); + prefetch( rsiLast + sourceStride + prefetchBytes ); + } + } + } + + if( rsi < rsiEnd ) + { + // Loading 8 elements into corresponding lanes of 8 vectors + // This way there's no data dependencies between these load instructions + // Out of order execution should hopefully do it's magic in the CPU, running all these loads in parallel. + __m128i r0; + __m128i r1 = _mm_setzero_si128(); + __m128i r2 = _mm_setzero_si128(); + __m128i r3 = _mm_setzero_si128(); + __m128i r4 = _mm_setzero_si128(); + __m128i r5 = _mm_setzero_si128(); + __m128i r6 = _mm_setzero_si128(); + __m128i r7 = _mm_setzero_si128(); + + __m128i t0, t1, t2, t3, t4, t5, t6; + +#pragma loop( no_vector ) + while( rsi < rsiEnd ) + { + r0 = _mm_cvtsi32_si128( *(const uint16_t*)rsi ); + r1 = _mm_insert_epi16( r1, *(const int16_t*)( rsi + sourceStride ), 1 ); + r2 = _mm_insert_epi16( r2, *(const int16_t*)( rsi + sourceStride * 2 ), 2 ); + rsi += 2; + r3 = _mm_insert_epi16( r3, *(const int16_t*)( rsiMid ), 3 ); + r4 = _mm_insert_epi16( r4, *(const int16_t*)( rsiMid + sourceStride ), 4 ); + r5 = _mm_insert_epi16( r5, *(const int16_t*)( rsiMid + sourceStride * 2 ), 5 ); + rsiMid += 2; + r6 = _mm_insert_epi16( r6, *(const int16_t*)( rsiLast ), 6 ); + r7 = _mm_insert_epi16( r7, *(const int16_t*)( rsiLast + sourceStride ), 7 ); + rsiLast += 2; + + // Bitwise operations 
are pretty fast, AMD Zen3 CPU can run 4 of them every clock cycle + // Combine 8 vectors into one + t0 = _mm_or_si128( r0, r1 ); + t1 = _mm_or_si128( r2, r3 ); + t2 = _mm_or_si128( r4, r5 ); + t3 = _mm_or_si128( r6, r7 ); + + t4 = _mm_or_si128( t0, t1 ); + t5 = _mm_or_si128( t2, t3 ); + + t6 = _mm_or_si128( t4, t5 ); + // Store 8 FP16 values, the destination is aligned + _mm_store_si128( ( __m128i* )rdi, t6 ); + rdi += destStride; + } + } + } + + __forceinline void transpose8PartialAvx2( uint16_t* rdiWords, size_t w, size_t h, const uint16_t* rsiWords, size_t sourceStride, size_t destStride ) + { + assert( 0 == ( (size_t)rdiWords ) % 16 ); + assert( 0 == destStride % 8 ); + assert( w <= sourceStride ); + assert( h > 0 && h < 8 ); + + // Scale strides to bytes, and cast the pointers + sourceStride *= 2; + destStride *= 2; + uint8_t* rdi = (uint8_t*)rdiWords; + const uint8_t* rsi = (const uint8_t*)rsiWords; + + const uint8_t* const rsiEndAligned = rsi + ( w & maskAlign16 ) * 2; + const uint8_t* const rsiEnd = rsi + w * 2; + const uint8_t* rsiMid = rsi + sourceStride * 3; + const uint8_t* rsiLast = rsi + sourceStride * 6; + uint8_t* rdiMid = rdi + destStride * 3; + uint8_t* rdiLast = rdi + destStride * 6; + + while( rsi < rsiEndAligned ) + { + // Load the block into 8 registers, set unused rows to zero + __m256i r0 = load( rsi ); + __m256i r1 = _mm256_setzero_si256(); + __m256i r2 = _mm256_setzero_si256(); + __m256i r3 = _mm256_setzero_si256(); + __m256i r4 = _mm256_setzero_si256(); + __m256i r5 = _mm256_setzero_si256(); + __m256i r6 = _mm256_setzero_si256(); + // These branches, whether direct or indirect, are very predictable: same outcome for all iterations of the outer loop + switch( h ) + { + case 7: + r6 = load( rsiLast ); + case 6: + r5 = load( rsiMid + sourceStride * 2 ); + case 5: + r4 = load( rsiMid + sourceStride ); + case 4: + r3 = load( rsiMid ); + case 3: + r2 = load( rsi + sourceStride * 2 ); + case 2: + r1 = load( rsi + sourceStride ); + } + rsi += 32; 
+ rsiMid += 32; + rsiLast += 32; + + __m256i r7 = _mm256_setzero_si256(); + + // Transpose FP16 values in registers + TRANSPOSE_8X16(); + + // Store + STORE_8X16_LOW(); + + STORE_8X16_HIGH(); + } + + if( rsi < rsiEnd ) + { + // Loading 8 elements into corresponding lanes of 8 vectors + // This way there's no data dependencies between these load instructions + // Out of order execution should hopefully do it's magic in the CPU, running all these loads in parallel. + __m128i r0; + __m128i r1 = _mm_setzero_si128(); + __m128i r2 = _mm_setzero_si128(); + __m128i r3 = _mm_setzero_si128(); + __m128i r4 = _mm_setzero_si128(); + __m128i r5 = _mm_setzero_si128(); + __m128i r6 = _mm_setzero_si128(); + + __m128i t0, t1, t2, t3, t4, t5; + +#pragma loop( no_vector ) + while( rsi < rsiEnd ) + { + r0 = _mm_cvtsi32_si128( *(const uint16_t*)rsi ); + + switch( h ) + { + case 7: + r6 = _mm_insert_epi16( r6, *(const int16_t*)( rsiLast ), 6 ); + case 6: + r5 = _mm_insert_epi16( r5, *(const int16_t*)( rsiMid + sourceStride * 2 ), 5 ); + case 5: + r4 = _mm_insert_epi16( r4, *(const int16_t*)( rsiMid + sourceStride ), 4 ); + case 4: + r3 = _mm_insert_epi16( r3, *(const int16_t*)( rsiMid ), 3 ); + case 3: + r2 = _mm_insert_epi16( r2, *(const int16_t*)( rsi + sourceStride * 2 ), 2 ); + case 2: + r1 = _mm_insert_epi16( r1, *(const int16_t*)( rsi + sourceStride ), 1 ); + } + rsi += 2; + rsiMid += 2; + rsiLast += 2; + + // Bitwise operations are pretty fast, AMD Zen3 CPU can run 4 of them every clock cycle + // Combine 7 vectors into one + t0 = _mm_or_si128( r0, r1 ); + t1 = _mm_or_si128( r2, r3 ); + t2 = _mm_or_si128( r4, r5 ); + + t3 = _mm_or_si128( t0, t1 ); + t4 = _mm_or_si128( t2, r6 ); + + t5 = _mm_or_si128( t3, t4 ); + // Store 8 FP16 values, the destination is aligned + _mm_store_si128( ( __m128i* )rdi, t5 ); + rdi += destStride; + } + } + } +} + +// At least for the hybrid decoder, this method absolutely dominates the CPU time. 
+// And not due to the integer shuffles - the bottleneck is loading data from the source matrix. +HRESULT MulMatBase::transposePanelAvx2( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const +{ + assert( stridesA[ 0 ] == 1 ); + + const size_t heightFloats = (size_t)panelHeightRegisters * 8; + i *= heightFloats; + + const uint16_t* rsi = (const uint16_t*)pa; + rsi += m3 * stridesA[ 3 ]; + rsi += m2 * stridesA[ 2 ]; + rsi += i * stridesA[ 1 ]; + + const size_t resultStride = heightFloats; + + if( i + heightFloats <= resultSize[ 0 ] ) + { + // A complete panel + for( size_t i = 0; i < panelHeightRegisters; i++ ) + { + transpose8Avx2( rdi, length, rsi, stridesA[ 1 ], resultStride ); + // Advance by 8 floats in the output buffer + rdi += 8; + // Advance by 8 rows in the source matrix + rsi += 8 * stridesA[ 1 ]; + } + } + else + { + // A partial panel, at the bottom of the first argument matrix + const size_t remainder = resultSize[ 0 ] - i; + assert( remainder > 0 && remainder < heightFloats ); + zeroAlignedMemory( rdi, resultStride * length * sizeof( uint16_t ) ); + + const size_t completePanels = remainder / 8; + for( size_t i = 0; i < completePanels; i++ ) + { + transpose8Avx2( rdi, length, rsi, stridesA[ 1 ], resultStride ); + rdi += 8; + rsi += 8 * stridesA[ 1 ]; + } + const size_t lastPanel = remainder % 8; + if( 0 != lastPanel ) + transpose8PartialAvx2( rdi, length, lastPanel, rsi, stridesA[ 1 ], resultStride ); + } + return S_OK; +}
\ No newline at end of file diff --git a/Whisper/CPU/mulMatImpl.cpp b/Whisper/CPU/mulMatImpl.cpp new file mode 100644 index 0000000..fc50b03 --- /dev/null +++ b/Whisper/CPU/mulMatImpl.cpp @@ -0,0 +1,213 @@ +#include "stdafx.h" +#include <intrin.h> +#include "mulMatImpl.h" +#include "mulMat.kernel.hpp" + +#define DBG_TRACK_TEMPLATE_INSTANTIATION 0 + +#if DBG_TRACK_TEMPLATE_INSTANTIATION +#include <unordered_set> +static std::unordered_set<uint16_t> g_mulMatTemplates; +#endif + +namespace +{ + using namespace CpuCompute; + + bool checkAvx2Support() + { + int cpuInfo[ 4 ]; + __cpuid( cpuInfo, 7 ); + return ( cpuInfo[ 1 ] & ( 1 << 5 ) ) != 0; + } + + // a / b, rounded up to the next integer + inline uint32_t divRoundUp( uint32_t a, uint32_t b ) + { + assert( b != 0 ); + return ( a + ( b - 1 ) ) / b; + } +} + +const bool MulMatBase::haveAvx2 = checkAvx2Support(); + +MulMatBase::MulMatBase( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor, uint8_t panelHeightRegs, uint8_t tileWidthFloats ) : + resultPointer( result.fp32() ), + pa( a.data() ), + pb( b.data() ), + runner( pfor ) +{ + length = a.ne[ 0 ]; + resultStrides[ 0 ] = result.nb[ 1 ]; + resultStrides[ 1 ] = result.nb[ 2 ]; + resultStrides[ 2 ] = result.nb[ 3 ]; + store( resultSize, result.sizeVec() ); + store( stridesA, a.stridesVec() ); + store( stridesB, b.stridesVec() ); + + countPanels = divRoundUp( resultSize[ 0 ], panelHeightRegs * 8 ); + completeTilesPerPanel = resultSize[ 1 ] / tileWidthFloats; + lastColumnsInPanel = (uint8_t)( resultSize[ 1 ] % tileWidthFloats ); + this->panelHeightRegisters = panelHeightRegs; + this->tileWidth = tileWidthFloats; + + // Pick a method which reshapes a panel of the matrix A into the shape we need to compute the product + // Store the pointer to that method in the field of this class + if( a.nb[ 0 ] == 1 ) + { + if( haveAvx2 ) + pfnMakePanel = &MulMatBase::transposePanelAvx2; + else + pfnMakePanel = &MulMatBase::transposePanel; + } + else if( a.nb[ 1 ] 
== 1 ) + { + switch( panelHeightRegs ) + { + case 1: + pfnMakePanel = &MulMatBase::copyPanelColumnMajor8; + break; + case 2: + pfnMakePanel = &MulMatBase::copyPanelColumnMajor16; + break; + case 4: + pfnMakePanel = &MulMatBase::copyPanelColumnMajor32; + break; + default: + throw E_NOTIMPL; + } + } + else + pfnMakePanel = &MulMatBase::gatherPanel; + + // That last version is generic and very simple, unlikely to have weird bugs + // pfnMakePanel = &MulMatBase::gatherPanel; + +#if DBG_TRACK_TEMPLATE_INSTANTIATION + uint16_t key = panelHeightRegs; + key = key << 8; + key |= tileWidthFloats; + if( !g_mulMatTemplates.emplace( key ).second ) + return; + logDebug( u8"MulMatImpl<panelHeightRegs = %i, tileWidthFloats = %i>", (int)panelHeightRegs, (int)tileWidthFloats ); +#endif +} + +HRESULT MulMatBase::run( ParallelForRunner& pfor ) +{ + size_t length = (size_t)countPanels * resultSize[ 2 ] * resultSize[ 3 ]; + return pfor.parallelFor( *this, length ); +} + +const float* MulMatBase::getLayerB( size_t m2, size_t m3 ) const +{ + const float* rsi = (const float*)this->pb; + rsi += m2 * stridesB[ 2 ]; + rsi += m3 * stridesB[ 3 ]; + return rsi; +} + +// This method is the main one, it’s called by the thread pool +template<uint8_t panelHeightRegs, uint8_t tileWidthFloats> +HRESULT __stdcall MulMatImpl<panelHeightRegs, tileWidthFloats>::compute( size_t i, size_t end ) const noexcept +{ + // Allocate a thread-local buffer for the transposed panel + constexpr size_t panelHeightFloats = panelHeightRegs * 8; + uint16_t* const panel = (uint16_t*)runner.threadLocalBuffer( floatsPerPanel() * 2 ); + const size_t resultStride = resultStrides[ 0 ]; + + // Load a few numbers from this class into local variables, while upcasting from DWORD into size_t + const size_t length = this->length; + const std::array<size_t, 2> stridesB{ this->stridesB[ 0 ], this->stridesB[ 1 ] }; + + // This outer loop iterates over the panels assigned to the current thread + // For example, matrix A of size [ 1024, 
1024 ] may be split into panels of size [ 1024, 16 ] + // Each iteration of that loop computes matrix product of that panel, with the complete matrix B + for( ; i < end; i++ ) + { + const size_t iPanel = i % countPanels; + size_t j = i / countPanels; + const size_t m2 = j % (size_t)resultSize[ 2 ]; + const size_t m3 = j / (size_t)resultSize[ 2 ]; + + CHECK( ( this->*pfnMakePanel )( panel, iPanel, m2, m3 ) ); + // We got a column-major panel in the thread local buffer, of size [ length, panelHeightRegs * 8 ] + // Hopefully, these buffers should all fit at least in L3 cache + // The longest matrix I saw in the debugger had 4096 elements, with panelHeightRegs = 4 that's 256 kb of data in the panel + const float* pb = getLayerB( m2, m3 ); + float* rdi = getPanelDest( iPanel, m2, m3 ); + + const size_t storeWidth = std::min( panelHeightFloats, (size_t)resultSize[ 0 ] - iPanel * panelHeightFloats ); + std::array<__m256, panelHeightRegs> vecPanel; +#if 1 + ResultTile<panelHeightRegs, tileWidthFloats> tile; + + // This loop iterates over tiles within the panel. + // Each iteration of the loop computes an output tile of the result matrix. 
+ for( j = 0; j < completeTilesPerPanel; j++, pb += tileWidthFloats * stridesB[ 1 ], rdi += resultStride * tileWidthFloats ) + { + setZero( tile.arr ); + const uint16_t* rsiA = panel; + const uint16_t* const rsiAEnd = panel + length * panelHeightFloats; + const float* rsiB = pb; + // This loop runs for `length` iterations, iterates over the first dimensions of both matrices, accumulating these dot products we're after + for( ; rsiA < rsiAEnd; rsiA += panelHeightFloats, rsiB += stridesB[ 0 ] ) + { + loadPanel( rsiA, vecPanel ); + tile.kernel( vecPanel, rsiB, stridesB[ 1 ] ); + } + tile.store( rdi, storeWidth, tileWidthFloats, resultStride ); + } + + if( 0 != lastColumnsInPanel ) + { + setZero( tile.arr ); + const uint16_t* rsiA = panel; + const uint16_t* rsiAEnd = panel + length * panelHeightFloats; + const float* rsiB = pb; + for( ; rsiA < rsiAEnd; rsiA += panelHeightFloats, rsiB += stridesB[ 0 ] ) + { + loadPanel( rsiA, vecPanel ); + tile.kernelPartial( vecPanel, rsiB, stridesB[ 1 ], lastColumnsInPanel ); + } + tile.store( rdi, storeWidth, lastColumnsInPanel, resultStride ); + } +#else + // This version bypasses horizontal tiling, instead implements a brute force algorithm to multiply the current panel by the complete B matrix + // Not terribly efficient, only implemented for debugging purposes + const size_t resHeight = resultSize[ 1 ]; + std::array<__m256, panelHeightRegs> tile; + for( size_t j = 0; j < resHeight; j++, pb += stridesB[ 1 ], rdi += resultStride ) + { + setZero( tile ); + + const uint16_t* rsiA = panel; + const uint16_t* const rsiAEnd = panel + length * panelHeightFloats; + const float* rsiB = pb; + for( size_t k = 0; k < length; k++, rsiA += panelHeightFloats, rsiB += stridesB[ 0 ] ) + { + loadPanel( rsiA, vecPanel ); + const __m256 b = _mm256_broadcast_ss( rsiB ); + for( size_t r = 0; r < panelHeightRegs; r++ ) + tile[ r ] = _mm256_fmadd_ps( vecPanel[ r ], b, tile[ r ] ); + } + + alignas( 32 ) std::array<float, panelHeightFloats> arr; + for( 
size_t k = 0; k < panelHeightRegs; k++ ) + _mm256_store_ps( &arr[ k * 8 ], tile[ k ] ); + memcpy( rdi, arr.data(), storeWidth * 4 ); + } +#endif + } + return S_OK; +} + +// Instantiate the templates we need +template class MulMatImpl<4, 1>; +template class MulMatImpl<1, 1>; +template class MulMatImpl<4, 2>; +template class MulMatImpl<1, 2>; +template class MulMatImpl<2, 3>; +template class MulMatImpl<1, 3>; +template class MulMatImpl<2, 4>; +template class MulMatImpl<1, 4>;
// Whisper/CPU/mulMatImpl.h
#pragma once
// Matrix*matrix multiplication is the most expensive algorithm in the model, by far.
// For this reason, the code in this source file, and in the mulMat.kernel.hpp header, is optimized for performance. Readability suffers.
// The implementation is inspired by the following two articles:
// https://gist.github.com/nadavrot/5b35d44e8ba3dd718e595e40184d03f0
// https://link.springer.com/article/10.1007/s11227-022-05003-3
#include "ParallelForRunner.h"
#include "Tensor.h"

namespace CpuCompute
{
	// Abstract base class for all implementations, to reduce binary size.
	// Owns everything which does not depend on the template arguments: tensor pointers, sizes, strides, and the panel reshaping methods.
	class MulMatBase : public iComputeRange
	{
	protected:
		// Pointer to the payload of the output matrix
		float* const resultPointer;

		// Length of the dot products to compute, equal to the width of both source matrices.
		// The constructor sets this to the first size component of the matrix A.
		uint32_t length;

		// Last 3 strides of the output matrix, expressed as count of elements. The first one is always 1 because the output matrix is continuous.
		std::array<uint32_t, 3> resultStrides;

		// Size of the output matrix
		std::array<uint32_t, 4> resultSize;

		// Pointers to the payload of the source matrices
		const void* const pa;
		const void* const pb;

		// Matrix strides, expressed as count of elements
		std::array<uint32_t, 4> stridesA, stridesB;

		// Total count of panels in the layer of the output matrix: resultSize[ 0 ] / ( panelHeightRegisters * 8 ), rounded up.
		// The last panel might be incomplete, with smaller height.
		// The thread-local buffer however is always complete, unused elements will be zeros.
		uint32_t countPanels;

		// Count of complete tiles along the length of the panel: resultSize[ 1 ] / tileWidth
		uint32_t completeTilesPerPanel;

		// Count of the last remainder columns in the panel: resultSize[ 1 ] % tileWidth, can be 0
		uint8_t lastColumnsInPanel;

		// Same as panelHeightRegs template argument - height of the panels, in AVX vectors of 8 floats
		uint8_t panelHeightRegisters;

		// Same as tileWidthFloats template argument - width of the tile, in floats
		uint8_t tileWidth;

		// Method pointer to reshape a panel from the source matrix into a thread-local buffer.
		// `rdi` is the output buffer, `i` is the index of the panel, `m2` and `m3` are the two outer coordinates of the layer.
		using pfnTransposePanel = HRESULT( MulMatBase::* )( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		// The reshaping method selected by the constructor, based on the memory layout of the matrix A
		pfnTransposePanel pfnMakePanel;
		// The object which implements multithreading for this job, and supplies memory for thread-local buffers
		ParallelForRunner& runner;

		// Count of FP16 values in the thread-local panel buffer
		uint32_t floatsPerPanel() const
		{
			return length * panelHeightRegisters * 8;
		}

		// Transpose a horizontal panel of the first matrix, when the rows are continuous in that matrix
		HRESULT transposePanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		// Same as above; the constructor picks this version when CPUID reports AVX2 support
		HRESULT transposePanelAvx2( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		// Copy a horizontal panel of the first matrix without transpose, for column major layout of that matrix.
		// The numeric suffix is the height of the panel in elements; these 3 versions are for panelHeightRegisters = 1, 2 and 4.
		HRESULT copyPanelColumnMajor8( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		HRESULT copyPanelColumnMajor16( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		HRESULT copyPanelColumnMajor32( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		// Transpose a panel of the first matrix for irregular layout of that matrix, when neither rows nor columns are at sequential addresses.
		// Generic fallback which copies elements one by one, implemented in mulMatImpl.panel.cpp
		HRESULT gatherPanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;

		// Pointer to the element of the first source matrix at row `i` of the layer [ m2, m3 ]; note `i` is the row index here, not the panel index
		const uint16_t* getPanelA( size_t i, size_t m2, size_t m3 ) const;
		// Pointer to the first element of the second source matrix in the specified layer
		const float* getLayerB( size_t m2, size_t m3 ) const;

		// Pointer to the first element of the output tile of the result matrix; here `i` is the panel index
		float* getPanelDest( size_t i, size_t m2, size_t m3 ) const
		{
			float* rdi = resultPointer;
			rdi += m2 * resultStrides[ 1 ];
			rdi += m3 * resultStrides[ 2 ];
			rdi += i * panelHeightRegisters * 8;
			return rdi;
		}

		// True when the CPU supports AVX2, detected once with CPUID
		static const bool haveAvx2;
	public:
		// Prepare the job; may throw E_NOTIMPL for column major matrix A with unsupported height of the panel
		MulMatBase( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor, uint8_t panelHeightRegs, uint8_t tileWidthFloats );
		// Compute the complete product, dispatching the panels through pfor.parallelFor
		HRESULT run( ParallelForRunner& pfor );
	};

	// This class actually contains the kernels implementations.
	// The template arguments are the height of the panel in AVX vectors, and the width of the output tile in floats.
	template<uint8_t panelHeightRegs, uint8_t tileWidthFloats>
	class MulMatImpl : public MulMatBase
	{
		// Compute work items [ i .. end ); defined and explicitly instantiated in mulMatImpl.cpp
		HRESULT __stdcall compute( size_t i, size_t end ) const noexcept override final;

	public:
		MulMatImpl( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ) :
			MulMatBase( result, a, b, pfor, panelHeightRegs, tileWidthFloats )
		{ }
	};
}
\ No newline at end of file diff --git a/Whisper/CPU/mulMatImpl.panel.cpp b/Whisper/CPU/mulMatImpl.panel.cpp new file mode 100644 index 0000000..f3baf21 --- /dev/null +++ b/Whisper/CPU/mulMatImpl.panel.cpp @@ -0,0 +1,274 @@ +#include "stdafx.h" +#include <intrin.h> +#include "mulMatImpl.h" +#include "mulMatUtils.hpp" +using namespace CpuCompute; + +// We want to keep code size reasonable, that's why these panel reshaping methods are in the base class +HRESULT MulMatBase::transposePanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const +{ + assert( stridesA[ 0 ] == 1 ); + + const size_t heightFloats = (size_t)panelHeightRegisters * 8; + i *= heightFloats; + + const uint16_t* rsi = (const uint16_t*)pa; + rsi += m3 * stridesA[ 3 ]; + rsi += m2 * stridesA[ 2 ]; + rsi += i * stridesA[ 1 ]; + + const size_t resultStride = heightFloats; + + if( i + heightFloats <= resultSize[ 0 ] ) + { + // A complete panel + for( size_t i = 0; i < panelHeightRegisters; i++ ) + { + transpose8( rdi, length, rsi, stridesA[ 1 ], resultStride ); + // Advance by 8 floats in the output buffer + rdi += 8; + // Advance by 8 rows in the source matrix + rsi += 8 * stridesA[ 1 ]; + } + } + else + { + // A partial panel, at the bottom of the first argument matrix + const size_t remainder = resultSize[ 0 ] - i; + assert( remainder > 0 && remainder < heightFloats ); + zeroAlignedMemory( rdi, resultStride * length * sizeof( uint16_t ) ); + + const size_t completePanels = remainder / 8; + for( size_t i = 0; i < completePanels; i++ ) + { + transpose8( rdi, length, rsi, stridesA[ 1 ], resultStride ); + rdi += 8; + rsi += 8 * stridesA[ 1 ]; + } + const size_t lastPanel = remainder % 8; + if( 0 != lastPanel ) + transpose8Partial( rdi, length, lastPanel, rsi, stridesA[ 1 ], resultStride ); + } + return S_OK; +} + +inline const uint16_t* MulMatBase::getPanelA( size_t i, size_t m2, size_t m3 ) const +{ + const uint16_t* rsi = (const uint16_t*)pa; + rsi += m3 * stridesA[ 3 ]; + rsi += m2 * stridesA[ 2 ]; + 
rsi += i * stridesA[ 1 ]; + return rsi; +} + +HRESULT MulMatBase::copyPanelColumnMajor8( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const +{ + assert( stridesA[ 1 ] == 1 ); + assert( panelHeightRegisters == 1 ); + + constexpr size_t heightFloats = 8; + i *= heightFloats; + const uint16_t* rsi = getPanelA( i, m2, m3 ); + + constexpr size_t resultStride = heightFloats; + + if( i + heightFloats <= resultSize[ 0 ] ) + { + // A complete panel, height = 8 elements + copyColumnMajor( rdi, length, rsi, stridesA[ 0 ], resultStride ); + } + else + { + // A partial panel, at the bottom of the first argument matrix + const size_t remainder = resultSize[ 0 ] - i; + assert( remainder > 0 && remainder < heightFloats ); + copyColumnMajorPartial( rdi, length, remainder, rsi, stridesA[ 0 ], resultStride ); + } + return S_OK; +} + +__forceinline __m128i load8Partial( const uint16_t* x, size_t len ) +{ + assert( len > 0 && len < 8 ); + __m128i ix = _mm_setzero_si128(); + switch( len ) + { + case 1: // load 2 bytes + ix = _mm_cvtsi32_si128( *x ); + break; + case 2: // load 4 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + break; + case 3: // load 6 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + ix = _mm_insert_epi16( ix, x[ 2 ], 2 ); + break; + case 4: // load 8 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + break; + case 5: // load 10 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + ix = _mm_insert_epi16( ix, x[ 4 ], 4 ); + break; + case 6: // load 12 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + break; + case 7: // load 14 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + ix = _mm_insert_epi16( ix, x[ 6 ], 6 ); + break; + } + return ix; +} + +__forceinline __m256i load16Partial( const uint16_t* rsi, size_t len ) +{ + assert( len > 0 && len < 16 ); + + if( len < 8 ) + { + __m128i low = load8Partial( rsi, len ); + return 
_mm256_setr_m128i( low, _mm_setzero_si128() ); + } + else if( len > 8 ) + { + __m128i low = load16( (const int*)rsi ); + __m128i high = load8Partial( rsi + 8, len - 8 ); + return _mm256_setr_m128i( low, high ); + } + else + { + __m128i low = load16( (const int*)rsi ); + return _mm256_setr_m128i( low, _mm_setzero_si128() ); + } +} + +HRESULT MulMatBase::copyPanelColumnMajor16( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const +{ + assert( stridesA[ 1 ] == 1 ); + assert( panelHeightRegisters == 2 ); + + constexpr size_t heightFloats = 16; + i *= heightFloats; + + const uint16_t* rsi = getPanelA( i, m2, m3 ); + uint16_t* const rdiEnd = rdi + 16 * length; + + if( i + heightFloats <= resultSize[ 0 ] ) + { + // A complete panel, height = 16 elements + for( ; rdi < rdiEnd; rdi += 16, rsi += stridesA[ 0 ] ) + { + __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi ); + _mm256_store_si256( ( __m256i* )rdi, v ); + } + } + else + { + // A partial panel, at the bottom of the first argument matrix + const size_t remainder = resultSize[ 0 ] - i; + assert( remainder > 0 && remainder < heightFloats ); + + for( ; rdi < rdiEnd; rdi += 16, rsi += stridesA[ 0 ] ) + { + __m256i v = load16Partial( rsi, remainder ); + _mm256_store_si256( ( __m256i* )rdi, v ); + } + } + return S_OK; +} + +HRESULT MulMatBase::copyPanelColumnMajor32( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const +{ + assert( stridesA[ 1 ] == 1 ); + assert( panelHeightRegisters == 4 ); + + constexpr size_t heightFloats = 32; + i *= heightFloats; + + const uint16_t* rsi = getPanelA( i, m2, m3 ); + uint16_t* const rdiEnd = rdi + 32 * length; + + if( i + heightFloats <= resultSize[ 0 ] ) + { + // A complete panel, height = 32 elements + for( ; rdi < rdiEnd; rdi += 32, rsi += stridesA[ 0 ] ) + { + __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi ); + _mm256_store_si256( ( __m256i* )rdi, v ); + v = _mm256_loadu_si256( ( const __m256i* )( rsi + 16 ) ); + _mm256_store_si256( ( __m256i* )( rdi + 16 ), v ); + } + } 
+ else + { + // A partial panel, at the bottom of the first argument matrix + const size_t remainder = resultSize[ 0 ] - i; + assert( remainder > 0 && remainder < heightFloats ); + + // _mm256_setzero_si256 probably compiles into vpxor, that's AVX2, we don't want that here + const __m256 zero = _mm256_setzero_ps(); + + for( ; rdi < rdiEnd; rdi += 32, rsi += stridesA[ 0 ] ) + { + if( remainder < 16 ) + { + __m256i v = load16Partial( rsi, remainder ); + _mm256_store_si256( ( __m256i* )rdi, v ); + _mm256_store_ps( (float*)( rdi + 16 ), zero ); + } + else if( remainder > 16 ) + { + __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi ); + _mm256_store_si256( ( __m256i* )rdi, v ); + v = load16Partial( rsi + 16, remainder - 16 ); + _mm256_store_si256( ( __m256i* )( rdi + 16 ), v ); + } + else + { + __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi ); + _mm256_store_si256( ( __m256i* )rdi, v ); + _mm256_store_ps( (float*)( rdi + 16 ), zero ); + } + } + } + return S_OK; +} + +HRESULT MulMatBase::gatherPanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const +{ + // BTW, I never saw this method called. 
+ const size_t heightFloats = (size_t)panelHeightRegisters * 8; + const size_t length = this->length; + + zeroAlignedMemory( rdi, length * heightFloats * sizeof( uint16_t ) ); + + const size_t height = std::min( heightFloats, resultSize[ 0 ] - i ); + const size_t strideElement = stridesA[ 0 ]; + const size_t strideRow = stridesA[ 1 ]; + const uint16_t* rsi = getPanelA( i * heightFloats, m2, m3 ); + + if( strideElement < strideRow ) + { + for( size_t r = 0; r < height; r++, rsi += strideRow, rdi++ ) + { + const uint16_t* sourceRow = rsi; + uint16_t* destRow = rdi; + for( size_t c = 0; c < length; c++, sourceRow += strideElement, destRow += heightFloats ) + *destRow = *sourceRow; + } + } + else + { + for( size_t c = 0; c < length; c++, rsi += strideElement, rdi += heightFloats ) + { + const uint16_t* sourceCol = rsi; + uint16_t* destCol = rdi; + for( size_t r = 0; r < height; r++, sourceCol += strideRow, destCol++ ) + *destCol = *sourceCol; + } + } + return S_OK; +}
\ No newline at end of file diff --git a/Whisper/CPU/mulMatUtils.hpp b/Whisper/CPU/mulMatUtils.hpp new file mode 100644 index 0000000..1276323 --- /dev/null +++ b/Whisper/CPU/mulMatUtils.hpp @@ -0,0 +1,301 @@ +#pragma once +#include <immintrin.h> +#include <stdint.h> +#include <assert.h> + +__forceinline __m128i f16Load( const uint16_t* rsi ) +{ + return _mm_loadu_si128( ( const __m128i* )rsi ); +} + +constexpr size_t maskAlign8 = ~(size_t)7; + +__forceinline void transpose8( uint16_t* rdi, size_t w, const uint16_t* rsi, size_t sourceStride, size_t destStride ) +{ + assert( 0 == ( (size_t)rdi ) % 16 ); + assert( 0 == destStride % 8 ); + assert( w <= sourceStride ); + + const uint16_t* const rsiEndAligned = rsi + ( w & maskAlign8 ); + const uint16_t* rsi5 = rsi + sourceStride * 5; + uint16_t* rdi5 = rdi + destStride * 5; + const size_t rem = w % 8; + for( ; rsi < rsiEndAligned; rsi += 8, rsi5 += 8, rdi += 8 * destStride, rdi5 += 8 * destStride ) + { + // Load 8x8 block into 8 registers + __m128i r0 = f16Load( rsi ); // 00, 01, 02, 03, 04, 05, 06, 07 + __m128i r1 = f16Load( rsi + sourceStride ); // 10, 11, 12, 13, 14, 15, 16, 17 + __m128i r2 = f16Load( rsi + sourceStride * 2 ); // 20, 21, 22, 23, 24, 25, 26, 27 + __m128i r3 = f16Load( rsi5 - sourceStride * 2 ); // 30, 31, 32, 33, 34, 35, 36, 37 + __m128i r4 = f16Load( rsi5 - sourceStride ); // 40, 41, 42, 43, 44, 45, 46, 47 + __m128i r5 = f16Load( rsi5 ); // 50, 51, 52, 53, 54, 55, 56, 57 + __m128i r6 = f16Load( rsi5 + sourceStride ); // 60, 61, 62, 63, 64, 65, 66, 67 + __m128i r7 = f16Load( rsi5 + sourceStride * 2 ); // 70, 71, 72, 73, 74, 75, 76, 77 + + // Transpose FP16 values in registers + __m128i t0 = _mm_unpacklo_epi16( r0, r1 ); // 00, 10, 01, 11, 02, 12, 03, 13 + __m128i t1 = _mm_unpackhi_epi16( r0, r1 ); // 04, 14, 05, 15, 06, 16, 07, 17 + __m128i t2 = _mm_unpacklo_epi16( r2, r3 ); // 20, 30, 21, 31, 22, 32, 23, 33 + __m128i t3 = _mm_unpackhi_epi16( r2, r3 ); // 24, 34, 25, 35, 26, 36, 27, 37 + __m128i t4 = 
_mm_unpacklo_epi16( r4, r5 ); // 40, 50, 41, 52, 42, 52, 43, 53 + __m128i t5 = _mm_unpackhi_epi16( r4, r5 ); // 44, 54, 45, 55, 46, 56, 47, 57 + __m128i t6 = _mm_unpacklo_epi16( r6, r7 ); // 60, 70, 61, 71, 62, 72, 63, 73 + __m128i t7 = _mm_unpackhi_epi16( r6, r7 ); // 64, 74, 65, 75, 66, 76, 67, 77 + + r0 = _mm_unpacklo_epi32( t0, t2 ); // 00, 10, 20, 30, 01, 11, 21, 31 + r1 = _mm_unpackhi_epi32( t0, t2 ); // 02, 12, 22, 32, 03, 13, 23, 33 + r2 = _mm_unpacklo_epi32( t1, t3 ); // 04, 14, 24, 34, 05, 15, 25, 35 + r3 = _mm_unpackhi_epi32( t1, t3 ); // 06, 16, 26, 36, 07, 17, 27, 37 + r4 = _mm_unpacklo_epi32( t4, t6 ); // 40, 50, 60, 70, 41, 51, 61, 71 + r5 = _mm_unpackhi_epi32( t4, t6 ); // 42, 52, 62, 72, 43, 53, 63, 73 + r6 = _mm_unpacklo_epi32( t5, t7 ); // 44, 54, 64, 74, 45, 55, 65, 75 + r7 = _mm_unpackhi_epi32( t5, t7 ); // 46, 56, 66, 76, 47, 57, 67, 77 + + t0 = _mm_unpacklo_epi64( r0, r4 ); // 00, 10, 20, 30, 40, 50, 60, 70 + t1 = _mm_unpackhi_epi64( r0, r4 ); // 01, 11, 21, 31, 41, 52, 61, 71 + t2 = _mm_unpacklo_epi64( r1, r5 ); // 02, 12, 22, 32, 42, 52, 62, 72 + t3 = _mm_unpackhi_epi64( r1, r5 ); // 03, 13, 23, 33, 43, 53, 63, 73 + t4 = _mm_unpacklo_epi64( r2, r6 ); + t5 = _mm_unpackhi_epi64( r2, r6 ); + t6 = _mm_unpacklo_epi64( r3, r7 ); + t7 = _mm_unpackhi_epi64( r3, r7 ); + + // Store + store16( rdi, t0 ); + store16( rdi + destStride, t1 ); + store16( rdi + destStride * 2, t2 ); + store16( rdi5 - destStride * 2, t3 ); + store16( rdi5 - destStride, t4 ); + store16( rdi5, t5 ); + store16( rdi5 + destStride, t6 ); + store16( rdi5 + destStride * 2, t7 ); + } + +#pragma loop( no_vector ) + for( size_t i = 0; i < rem; rsi++, rsi5++, rdi += destStride ) + { + const int16_t* p0 = (const int16_t*)rsi; + const int16_t* p5 = (const int16_t*)rsi5; + // Load a complete column into a vector + __m128i v = _mm_cvtsi32_si128( *rsi ); + v = _mm_insert_epi16( v, *( p0 + sourceStride ), 1 ); + v = _mm_insert_epi16( v, *( p0 + sourceStride * 2 ), 2 ); + v = 
_mm_insert_epi16( v, *( p5 - sourceStride * 2 ), 3 ); + v = _mm_insert_epi16( v, *( p5 - sourceStride ), 4 ); + v = _mm_insert_epi16( v, *( p5 ), 5 ); + v = _mm_insert_epi16( v, *( p5 + sourceStride ), 6 ); + v = _mm_insert_epi16( v, *( p5 + sourceStride * 2 ), 7 ); + // Store 8 FP16 values + store16( rdi, v ); + } +} + +inline void transpose8Partial( uint16_t* rdi, size_t w, size_t h, const uint16_t* rsi, size_t sourceStride, size_t destStride ) +{ + assert( 0 == ( (size_t)rdi ) % 16 ); + assert( 0 == destStride % 8 ); + assert( w <= sourceStride ); + assert( h > 0 && h < 8 ); + + const uint16_t* const rsiEndAligned = rsi + ( w & maskAlign8 ); + const uint16_t* rsi5 = rsi + sourceStride * 5; + uint16_t* rdi5 = rdi + destStride * 5; + const size_t rem = w % 8; + for( ; rsi < rsiEndAligned; rsi += 8, rsi5 += 8, rdi += 8 * destStride, rdi5 += 8 * destStride ) + { + // Load the block into 8 registers, set unused rows to zero + __m128i r0 = f16Load( rsi ); + __m128i r1 = _mm_setzero_si128(); + __m128i r2 = _mm_setzero_si128(); + __m128i r3 = _mm_setzero_si128(); + __m128i r4 = _mm_setzero_si128(); + __m128i r5 = _mm_setzero_si128(); + __m128i r6 = _mm_setzero_si128(); + // These branches, whether direct or indirect, are very predictable: same outcome for all iterations of the outer loop + switch( h ) + { + case 7: + r6 = f16Load( rsi5 + sourceStride ); + case 6: + r5 = f16Load( rsi5 ); + case 5: + r4 = f16Load( rsi5 - sourceStride ); + case 4: + r3 = f16Load( rsi5 - sourceStride * 2 ); + case 3: + r2 = f16Load( rsi + sourceStride * 2 ); + case 2: + r1 = f16Load( rsi + sourceStride ); + } + __m128i r7 = _mm_setzero_si128(); + + // Transpose FP16 values in registers + __m128i t0 = _mm_unpacklo_epi16( r0, r1 ); // 00, 10, 01, 11, 02, 12, 03, 13 + __m128i t1 = _mm_unpackhi_epi16( r0, r1 ); // 04, 14, 05, 15, 06, 16, 07, 17 + __m128i t2 = _mm_unpacklo_epi16( r2, r3 ); // 20, 30, 21, 31, 22, 32, 23, 33 + __m128i t3 = _mm_unpackhi_epi16( r2, r3 ); // 24, 34, 25, 35, 26, 36, 
27, 37 + __m128i t4 = _mm_unpacklo_epi16( r4, r5 ); // 40, 50, 41, 52, 42, 52, 43, 53 + __m128i t5 = _mm_unpackhi_epi16( r4, r5 ); // 44, 54, 45, 55, 46, 56, 47, 57 + __m128i t6 = _mm_unpacklo_epi16( r6, r7 ); // 60, 70, 61, 71, 62, 72, 63, 73 + __m128i t7 = _mm_unpackhi_epi16( r6, r7 ); // 64, 74, 65, 75, 66, 76, 67, 77 + + r0 = _mm_unpacklo_epi32( t0, t2 ); // 00, 10, 20, 30, 01, 11, 21, 31 + r1 = _mm_unpackhi_epi32( t0, t2 ); // 02, 12, 22, 32, 03, 13, 23, 33 + r2 = _mm_unpacklo_epi32( t1, t3 ); // 04, 14, 24, 34, 05, 15, 25, 35 + r3 = _mm_unpackhi_epi32( t1, t3 ); // 06, 16, 26, 36, 07, 17, 27, 37 + r4 = _mm_unpacklo_epi32( t4, t6 ); // 40, 50, 60, 70, 41, 51, 61, 71 + r5 = _mm_unpackhi_epi32( t4, t6 ); // 42, 52, 62, 72, 43, 53, 63, 73 + r6 = _mm_unpacklo_epi32( t5, t7 ); // 44, 54, 64, 74, 45, 55, 65, 75 + r7 = _mm_unpackhi_epi32( t5, t7 ); // 46, 56, 66, 76, 47, 57, 67, 77 + + t0 = _mm_unpacklo_epi64( r0, r4 ); // 00, 10, 20, 30, 40, 50, 60, 70 + t1 = _mm_unpackhi_epi64( r0, r4 ); // 01, 11, 21, 31, 41, 52, 61, 71 + t2 = _mm_unpacklo_epi64( r1, r5 ); // 02, 12, 22, 32, 42, 52, 62, 72 + t3 = _mm_unpackhi_epi64( r1, r5 ); // 03, 13, 23, 33, 43, 53, 63, 73 + t4 = _mm_unpacklo_epi64( r2, r6 ); + t5 = _mm_unpackhi_epi64( r2, r6 ); + t6 = _mm_unpacklo_epi64( r3, r7 ); + t7 = _mm_unpackhi_epi64( r3, r7 ); + + // Store + store16( rdi, t0 ); + store16( rdi + destStride, t1 ); + store16( rdi + destStride * 2, t2 ); + store16( rdi5 - destStride * 2, t3 ); + store16( rdi5 - destStride, t4 ); + store16( rdi5, t5 ); + store16( rdi5 + destStride, t6 ); + store16( rdi5 + destStride * 2, t7 ); + } + +#pragma loop( no_vector ) + for( size_t i = 0; i < rem; rsi++, rsi5++, rdi += destStride ) + { + const int16_t* p0 = (const int16_t*)rsi; + const int16_t* p5 = (const int16_t*)rsi5; + // Load a partial column into vector + __m128i v = _mm_cvtsi32_si128( *rsi ); + switch( h ) + { + case 7: + v = _mm_insert_epi16( v, *( p5 + sourceStride ), 6 ); + case 6: + v = _mm_insert_epi16( 
v, *( p5 ), 5 ); + case 5: + v = _mm_insert_epi16( v, *( p5 - sourceStride ), 4 ); + case 4: + v = _mm_insert_epi16( v, *( p5 - sourceStride * 2 ), 3 ); + case 3: + v = _mm_insert_epi16( v, *( p0 + sourceStride * 2 ), 2 ); + case 2: + v = _mm_insert_epi16( v, *( p0 + sourceStride ), 1 ); + } + // Store 8 FP16 values + store16( rdi, v ); + } +} + +// Same as above, but skip the transpose. The source stride is distance between columns of the matrix. +__forceinline void copyColumnMajor( uint16_t* rdi, size_t w, const uint16_t* rsi, size_t sourceStride, size_t destStride ) +{ + assert( 0 == ( (size_t)rdi ) % 16 ); + assert( 0 == destStride % 8 ); + + constexpr size_t maskAlign4 = ~(size_t)3; + + const uint16_t* const rsiEndAligned = rsi + sourceStride * ( w & maskAlign4 ); + const uint16_t* const rsiEnd = rsi + sourceStride * w; + for( ; rsi < rsiEndAligned; rsi += sourceStride * 4, rdi += destStride * 4 ) + { + __m128i c = f16Load( rsi ); + store16( rdi, c ); + + c = f16Load( rsi + sourceStride ); + store16( rdi + destStride, c ); + + c = f16Load( rsi + sourceStride * 2 ); + store16( rdi + destStride * 2, c ); + + c = f16Load( rsi + sourceStride * 3 ); + store16( rdi + destStride * 3, c ); + } + + for( ; rsi < rsiEnd; rsi += sourceStride, rdi += destStride ) + { + __m128i c = f16Load( rsi ); + store16( rdi, c ); + } +} + +__forceinline __m128i loadPartial( const uint16_t* x, size_t count ) +{ + assert( count < 8 ); + __m128i ix; + switch( count ) + { + case 1: // load 2 bytes + ix = _mm_cvtsi32_si128( *x ); + break; + case 2: // load 4 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + break; + case 3: // load 6 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + ix = _mm_insert_epi16( ix, x[ 2 ], 2 ); + break; + case 4: // load 8 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + break; + case 5: // load 10 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + ix = _mm_insert_epi16( ix, x[ 4 ], 4 ); + break; + case 6: // load 12 bytes + ix = 
_mm_cvtsi64_si128( *(const int64_t*)x ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + break; + case 7: // load 14 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + ix = _mm_insert_epi16( ix, x[ 6 ], 6 ); + break; + default: + return _mm_setzero_si128(); + } + return ix; +} + +inline void copyColumnMajorPartial( uint16_t* rdi, size_t w, size_t h, const uint16_t* rsi, size_t sourceStride, size_t destStride ) +{ + assert( 0 == ( (size_t)rdi ) % 32 ); + assert( 0 == destStride % 8 ); + assert( h > 0 && h < 8 ); + + const uint16_t* const rsiEnd = rsi + sourceStride * w; + for( ; rsi < rsiEnd; rsi += sourceStride, rdi += destStride ) + { + // Can't use mask loads because loading 2-byte elements + // Still, that switch() in loadPartial makes a very predictable branch, same outcome for all iterations of this loop. + __m128i c = loadPartial( rsi, h ); + store16( rdi, c ); + } +} + +// Store zeros into block of memory, with aligned AVX store instructions +__forceinline void zeroAlignedMemory( void* pv, size_t cb ) +{ + assert( 0 == cb % 16 ); + assert( 0 == ( (size_t)pv % 32 ) ); + + uint8_t* rdi = (uint8_t*)pv; + constexpr size_t maskAlign32 = ~(size_t)31; + uint8_t* const rdiEndAligned = rdi + ( cb & maskAlign32 ); + uint8_t* const rdiEnd = rdi + cb; + + const __m256 zero = _mm256_setzero_ps(); + for( ; rdi < rdiEndAligned; rdi += 32 ) + _mm256_store_ps( (float*)rdi, zero ); + + if( rdi < rdiEnd ) + _mm_store_ps( (float*)rdi, _mm_setzero_ps() ); +}
\ No newline at end of file diff --git a/Whisper/CPU/simdUtils.cpp b/Whisper/CPU/simdUtils.cpp new file mode 100644 index 0000000..0e5f77d --- /dev/null +++ b/Whisper/CPU/simdUtils.cpp @@ -0,0 +1,738 @@ +#include "stdafx.h" +#include "simdUtils.h" +#include "../ML/LookupTablesData.h" +#include <cmath> +#include <memory> + +namespace +{ + constexpr size_t maskAlign8 = ~(size_t)7; + + __forceinline __m256 load8( const uint16_t* rsi ) + { + __m128i i = _mm_loadu_si128( ( const __m128i* )rsi ); + return _mm256_cvtph_ps( i ); + } + + __forceinline void loadPartial( const uint16_t* x, const uint16_t* y, size_t count, __m256& fx, __m256& fy ) + { + assert( count < 8 ); + + __m128i ix, iy; + switch( count ) + { + case 1: // load 2 bytes + ix = _mm_cvtsi32_si128( *x ); + iy = _mm_cvtsi32_si128( *y ); + break; + case 2: // load 4 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + iy = _mm_cvtsi32_si128( *(const int*)y ); + break; + case 3: // load 6 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + iy = _mm_cvtsi32_si128( *(const int*)y ); + ix = _mm_insert_epi16( ix, x[ 2 ], 2 ); + iy = _mm_insert_epi16( iy, y[ 2 ], 2 ); + break; + case 4: // load 8 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + break; + case 5: // load 10 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + ix = _mm_insert_epi16( ix, x[ 4 ], 4 ); + iy = _mm_insert_epi16( iy, y[ 4 ], 4 ); + break; + case 6: // load 12 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 ); + break; + case 7: // load 14 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 ); + ix = 
_mm_insert_epi16( ix, x[ 6 ], 6 ); + iy = _mm_insert_epi16( iy, y[ 6 ], 6 ); + break; + default: + fx = fy = _mm256_setzero_ps(); + return; + } + + fx = _mm256_cvtph_ps( ix ); + fy = _mm256_cvtph_ps( iy ); + } + + __forceinline __m256 loadPartial( const uint16_t* x, size_t count ) + { + assert( count < 8 ); + __m128i ix; + switch( count ) + { + case 1: // load 2 bytes + ix = _mm_cvtsi32_si128( *x ); + break; + case 2: // load 4 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + break; + case 3: // load 6 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + ix = _mm_insert_epi16( ix, x[ 2 ], 2 ); + break; + case 4: // load 8 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + break; + case 5: // load 10 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + ix = _mm_insert_epi16( ix, x[ 4 ], 4 ); + break; + case 6: // load 12 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + break; + case 7: // load 14 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + ix = _mm_insert_epi16( ix, x[ 6 ], 6 ); + break; + default: + return _mm256_setzero_ps(); + } + return _mm256_cvtph_ps( ix ); + } + + __forceinline __m128 loadFloat2( const float* rsi ) + { + return _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) ); + } + __forceinline __m128 loadFloat3( const float* rsi ) + { + __m128 f = loadFloat2( rsi ); + f = _mm_insert_ps( f, _mm_load_ss( rsi + 2 ), 0x20 ); + return f; + } + + __forceinline __m256 loadPartial( const float* rsi, size_t count ) + { + assert( count < 8 ); + __m128 low = _mm_setzero_ps(); + __m128 high = _mm_setzero_ps(); + switch( count ) + { + case 1: + low = _mm_load_ss( rsi ); + break; + case 2: + low = loadFloat2( rsi ); + break; + case 3: + low = loadFloat3( rsi ); + break; + case 4: + low = _mm_loadu_ps( rsi ); + break; + case 5: + low = _mm_loadu_ps( rsi ); + high = _mm_load_ss( rsi + 4 ); + break; + case 6: + low = 
_mm_loadu_ps( rsi ); + high = loadFloat2( rsi + 4 ); + break; + case 7: + low = _mm_loadu_ps( rsi ); + high = loadFloat3( rsi + 4 ); + break; + } + return _mm256_setr_m128( low, high ); + } + + __forceinline void storeFloat2( float* rdi, __m128 vec ) + { + _mm_store_sd( (double*)rdi, _mm_castps_pd( vec ) ); + } + + __forceinline void storePartial( float* rdi, __m256 vec, size_t count ) + { + assert( count < 8 ); + + __m128 tmp = _mm256_castps256_ps128( vec ); + if( count >= 4 ) + { + _mm_storeu_ps( rdi, tmp ); + if( count == 4 ) + return; + count -= 4; + rdi += 4; + tmp = _mm256_extractf128_ps( vec, 1 ); + } + + switch( count ) + { + case 1: + _mm_store_ss( rdi, tmp ); + return; + case 2: + storeFloat2( rdi, tmp ); + return; + case 3: + storeFloat2( rdi, tmp ); + ( (int*)rdi )[ 2 ] = _mm_extract_ps( tmp, 2 ); + return; + } + } +} + +void addF16to32( float* rdi, const uint16_t* a, const uint16_t* b, size_t length ) +{ + const uint16_t* const endAligned = a + ( length & maskAlign8 ); + const size_t rem = length % 8; + + for( ; a < endAligned; a += 8, b += 8, rdi += 8 ) + { + __m256 f1 = load8( a ); + __m256 f2 = load8( b ); + __m256 res = _mm256_add_ps( f1, f2 ); + _mm256_storeu_ps( rdi, res ); + } + + if( rem != 0 ) + { + __m256 f1, f2; + loadPartial( a, b, rem, f1, f2 ); + __m256 res = _mm256_add_ps( f1, f2 ); + storePartial( rdi, res, rem ); + } +} + +void addF16to32( float* rdi, const uint16_t* a, const float* b, size_t length ) +{ + const uint16_t* const endAligned = a + ( length & maskAlign8 ); + const size_t rem = length % 8; + + for( ; a < endAligned; a += 8, b += 8, rdi += 8 ) + { + __m256 f1 = load8( a ); + __m256 f2 = _mm256_loadu_ps( b ); + __m256 res = _mm256_add_ps( f1, f2 ); + _mm256_storeu_ps( rdi, res ); + } + + if( rem != 0 ) + { + __m256 f1 = loadPartial( a, rem ); + __m256 f2 = loadPartial( b, rem ); + __m256 res = _mm256_add_ps( f1, f2 ); + storePartial( rdi, res, rem ); + } +} + +alignas( 64 ) const std::array<int, 16> s_zeroTailMask = +{ + 
-1,-1,-1,-1,-1,-1,-1,-1, + 0, 0, 0, 0, 0, 0, 0, 0, +}; + +namespace +{ + __forceinline float horizontalSum( __m256 vec ) + { + __m128 v = _mm256_extractf128_ps( vec, 1 ); + v = _mm_add_ps( v, _mm256_castps256_ps128( vec ) ); + v = _mm_add_ps( v, _mm_movehl_ps( v, v ) ); + v = _mm_add_ss( v, _mm_movehdup_ps( v ) ); + return _mm_cvtss_f32( v ); + } +} + +void norm( float* rdi, float* temp, const float* rsi, size_t length ) +{ + assert( (size_t)temp % 32 == 0 ); + const float* rsiEndAligned = rsi + ( length & maskAlign8 ); + const size_t rem = length % 8; + + // First pass: copy to temp buffer, and compute the sum; computeVectorSum() in HLSL + __m256 sum = _mm256_setzero_ps(); + float* t; + for( t = temp; rsi < rsiEndAligned; rsi += 8, t += 8 ) + { + __m256 v = _mm256_loadu_ps( rsi ); + sum = _mm256_add_ps( sum, v ); + _mm256_store_ps( t, v ); + } + float* const tEndAligned = t; + if( 0 != rem ) + { + __m256 v = loadPartial( rsi, rem ); + sum = _mm256_add_ps( sum, v ); + _mm256_store_ps( t, v ); + t += 8; + } + + const float lengthFloat = (float)(int)length; + const float meanScalar = horizontalSum( sum ) / lengthFloat; + const __m256 mean = _mm256_set1_ps( meanScalar ); + + // Second pass, offsetAndComputeSumSquares() in HLSL + sum = _mm256_setzero_ps(); + for( t = temp; t < tEndAligned; t += 8 ) + { + __m256 v = _mm256_load_ps( t ); + v = _mm256_sub_ps( v, mean ); + _mm256_store_ps( t, v ); + sum = _mm256_fmadd_ps( v, v, sum ); + } + if( 0 != rem ) + { + __m256 v = _mm256_load_ps( t ); + v = _mm256_sub_ps( v, mean ); + v = _mm256_and_ps( v, loadTailMaskFloats( rem ) ); + _mm256_store_ps( t, v ); + sum = _mm256_fmadd_ps( v, v, sum ); + } + + // Final pass: scale, and copy from temporary buffer into the destination row + + constexpr float eps = 1e-5f; // TODO: make this a parameter + const float scaleScalar = 1.0f / std::sqrtf( horizontalSum( sum ) / lengthFloat + eps ); + const __m256 scale = _mm256_set1_ps( scaleScalar ); + + for( t = temp; t < tEndAligned; t += 8, 
rdi += 8 ) + { + __m256 v = _mm256_load_ps( t ); + v = _mm256_mul_ps( v, scale ); + _mm256_storeu_ps( rdi, v ); + } + if( 0 != rem ) + { + __m256 v = _mm256_load_ps( t ); + v = _mm256_mul_ps( v, scale ); + storePartial( rdi, v, rem ); + } +} + +void fmaRepeatRow( float* rdi, size_t len, const float* w, const float* b, size_t lenPattern ) +{ + float* rdiEndAligned = rdi + ( len & maskAlign8 ); + const size_t rem = len % 8; + + if( 1 == lenPattern ) + { + const __m256 v1 = _mm256_broadcast_ss( w ); + const __m256 v2 = _mm256_broadcast_ss( b ); + for( ; rdi < rdiEndAligned; rdi += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + v = _mm256_fmadd_ps( v, v1, v2 ); + _mm256_storeu_ps( rdi, v ); + } + if( 0 != rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 v = _mm256_maskload_ps( rdi, mask ); + v = _mm256_fmadd_ps( v, v1, v2 ); + _mm256_maskstore_ps( rdi, mask, v ); + } + } + else if( len == lenPattern ) + { + for( ; rdi < rdiEndAligned; rdi += 8, w += 8, b += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + __m256 v1 = _mm256_loadu_ps( w ); + __m256 v2 = _mm256_loadu_ps( b ); + v = _mm256_fmadd_ps( v, v1, v2 ); + _mm256_storeu_ps( rdi, v ); + } + if( 0 != rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 v = _mm256_maskload_ps( rdi, mask ); + __m256 v1 = _mm256_maskload_ps( w, mask ); + __m256 v2 = _mm256_maskload_ps( b, mask ); + v = _mm256_fmadd_ps( v, v1, v2 ); + _mm256_maskstore_ps( rdi, mask, v ); + } + } + else + { + // TODO: implement if this actually happens + throw E_NOTIMPL; + } +} + +void __vectorcall addRepeatScaleRow( float* rdi, size_t len, const float* b, size_t lenPattern, const __m256 scale ) +{ + float* rdiEndAligned = rdi + ( len & maskAlign8 ); + const size_t rem = len % 8; + + if( 1 == lenPattern ) + { + const __m256 v2 = _mm256_broadcast_ss( b ); + for( ; rdi < rdiEndAligned; rdi += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + v = _mm256_add_ps( v, v2 ); + v = _mm256_mul_ps( v, scale ); + _mm256_storeu_ps( rdi, v ); 
+ } + if( 0 != rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 v = _mm256_maskload_ps( rdi, mask ); + v = _mm256_add_ps( v, v2 ); + v = _mm256_mul_ps( v, scale ); + _mm256_maskstore_ps( rdi, mask, v ); + } + return; + } + else if( len == lenPattern ) + { + for( ; rdi < rdiEndAligned; rdi += 8, b += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + __m256 v2 = _mm256_loadu_ps( b ); + v = _mm256_add_ps( v, v2 ); + v = _mm256_mul_ps( v, scale ); + _mm256_storeu_ps( rdi, v ); + } + if( 0 != rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 v = _mm256_maskload_ps( rdi, mask ); + __m256 v2 = _mm256_maskload_ps( b, mask ); + v = _mm256_add_ps( v, v2 ); + v = _mm256_mul_ps( v, scale ); + _mm256_maskstore_ps( rdi, mask, v ); + } + return; + } + else + { + // TODO: implement if this actually happens + throw E_NOTIMPL; + } +} + +void addRepeatRow( float* rdi, size_t len, const float* b, size_t lenPattern ) +{ + float* rdiEndAligned = rdi + ( len & maskAlign8 ); + const size_t rem = len % 8; + + if( 1 == lenPattern ) + { + const __m256 v2 = _mm256_broadcast_ss( b ); + for( ; rdi < rdiEndAligned; rdi += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + v = _mm256_add_ps( v, v2 ); + _mm256_storeu_ps( rdi, v ); + } + if( 0 != rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 v = _mm256_maskload_ps( rdi, mask ); + v = _mm256_add_ps( v, v2 ); + _mm256_maskstore_ps( rdi, mask, v ); + } + return; + } + else if( len == lenPattern ) + { + for( ; rdi < rdiEndAligned; rdi += 8, b += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + __m256 v2 = _mm256_loadu_ps( b ); + v = _mm256_add_ps( v, v2 ); + _mm256_storeu_ps( rdi, v ); + } + if( 0 != rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 v = _mm256_maskload_ps( rdi, mask ); + __m256 v2 = _mm256_maskload_ps( b, mask ); + v = _mm256_add_ps( v, v2 ); + _mm256_maskstore_ps( rdi, mask, v ); + } + return; + } + else + { + // TODO: implement if this actually happens + throw E_NOTIMPL; + } 
+} + +namespace +{ + __forceinline __m256 gelu( __m256 x, const DirectCompute::LookupTablesData& lookup ) + { + __m128i iv = _mm256_cvtps_ph( x, 0 ); + alignas( 16 ) std::array<uint16_t, 8> arr; + _mm_store_si128( ( __m128i* )arr.data(), iv ); + for( uint16_t& a : arr ) + a = lookup.gelu[ a ]; + iv = _mm_load_si128( ( __m128i* )arr.data() ); + return _mm256_cvtph_ps( iv ); + } +} + +void addRepeatGeluRow( float* rdi, size_t len, const float* b, size_t lenPattern, const DirectCompute::LookupTablesData& lookup ) +{ + float* rdiEndAligned = rdi + ( len & maskAlign8 ); + const size_t rem = len % 8; + + if( 1 == lenPattern ) + { + const __m256 v2 = _mm256_broadcast_ss( b ); + for( ; rdi < rdiEndAligned; rdi += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + v = _mm256_add_ps( v, v2 ); + v = gelu( v, lookup ); + _mm256_storeu_ps( rdi, v ); + } + if( 0 != rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 v = _mm256_maskload_ps( rdi, mask ); + v = _mm256_add_ps( v, v2 ); + v = gelu( v, lookup ); + _mm256_maskstore_ps( rdi, mask, v ); + } + return; + } + else if( len == lenPattern ) + { + for( ; rdi < rdiEndAligned; rdi += 8, b += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + __m256 v2 = _mm256_loadu_ps( b ); + v = _mm256_add_ps( v, v2 ); + v = gelu( v, lookup ); + _mm256_storeu_ps( rdi, v ); + } + if( 0 != rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 v = _mm256_maskload_ps( rdi, mask ); + __m256 v2 = _mm256_maskload_ps( b, mask ); + v = _mm256_add_ps( v, v2 ); + v = gelu( v, lookup ); + _mm256_maskstore_ps( rdi, mask, v ); + } + return; + } + else + { + // TODO: implement if this actually happens + throw E_NOTIMPL; + } +} + +void __vectorcall scaleRow( float* rdi, size_t len, const __m256 scale ) +{ + float* rdiEndAligned = rdi + ( len & maskAlign8 ); + const size_t rem = len % 8; + for( ; rdi < rdiEndAligned; rdi += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + v = _mm256_mul_ps( v, scale ); + _mm256_storeu_ps( rdi, v ); + } + if( 0 
!= rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 v = _mm256_maskload_ps( rdi, mask ); + v = _mm256_mul_ps( v, scale ); + _mm256_maskstore_ps( rdi, mask, v ); + } +} + +namespace +{ + using DirectCompute::LookupTablesData; + + __forceinline float horizontalMax( __m256 vec ) + { + __m128 v = _mm256_extractf128_ps( vec, 1 ); + v = _mm_max_ps( v, _mm256_castps256_ps128( vec ) ); + v = _mm_max_ps( v, _mm_movehl_ps( v, v ) ); + v = _mm_max_ss( v, _mm_movehdup_ps( v ) ); + return _mm_cvtss_f32( v ); + } + + __forceinline float _cvtsh_ss( uint16_t f16 ) + { + __m128i i = _mm_cvtsi32_si128( f16 ); + __m128 f = _mm_cvtph_ps( i ); + return _mm_cvtss_f32( f ); + } + + __forceinline uint16_t _cvtss_sh( float f, int rounding ) + { + assert( 0 == rounding ); + __m128 v = _mm_set_ss( f ); + __m128i i = _mm_cvtps_ph( v, 0 ); + return (uint16_t)(uint32_t)_mm_cvtsi128_si32( i ); + } +} + +const LookupTablesData& getLookupTables() +{ + static const std::unique_ptr<LookupTablesData> res = std::make_unique<LookupTablesData>(); + return *res; +} + +void softMax( float* rdi, size_t length, const float inputScale ) +{ + float* const rdiBegin = rdi; + float* const rdiEndAligned = rdi + ( length & maskAlign8 ); + const size_t remainder = length % 8; + // First pass, compute maximum + __m256 max = _mm256_set1_ps( -INFINITY ); + for( rdi = rdiBegin; rdi < rdiEndAligned; rdi += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + max = _mm256_max_ps( max, v ); + } + __m256i tailMask; + if( 0 != remainder ) + { + tailMask = loadTailMaskInt( remainder ); + __m256 v = _mm256_maskload_ps( rdi, tailMask ); + v = _mm256_max_ps( max, v ); + max = _mm256_blendv_ps( max, v, _mm256_castsi256_ps( tailMask ) ); + } + + // Second pass: apply initial scale, compute the exponent, and compute total sum over the row + const LookupTablesData& lookup = getLookupTables(); + const float maxScalar = horizontalMax( max ); + + float* const rdiEnd = rdiBegin + length; + double sum = 0; + for( rdi = 
rdiBegin; rdi < rdiEnd; rdi++ ) + { + // Possible to vectorize, but relatively hard + // An easy way is upcast the complete lookup table to FP32 and then use two _mm256_i32gather_ps instructions per iteration + // However, that instruction is from AVX2 set. Let's hope this loop won't be a bottleneck. + float f = *rdi; + if( f != -INFINITY ) + { + f = ( f - maxScalar ) * inputScale; + uint16_t f16 = _cvtss_sh( f, 0 ); + f16 = lookup.exponent[ f16 ]; + f = _cvtsh_ss( f16 ); + sum += f; + } + else + f = 0; + + *rdi = f; + } + + // Final pass: apply the final scale + const __m256 finalScale = _mm256_set1_ps( (float)( 1.0 / sum ) ); + for( rdi = rdiBegin; rdi < rdiEndAligned; rdi += 8 ) + { + __m256 v = _mm256_loadu_ps( rdi ); + v = _mm256_mul_ps( v, finalScale ); + _mm256_storeu_ps( rdi, v ); + } + if( 0 != remainder ) + { + __m256 v = _mm256_maskload_ps( rdi, tailMask ); + v = _mm256_mul_ps( v, finalScale ); + _mm256_maskstore_ps( rdi, tailMask, v ); + } +} + +void floatsUpcast( float* rdi, const uint16_t* rsi, size_t length ) +{ + const uint16_t* rsiEndAligned = rsi + ( length & maskAlign8 ); + const size_t rem = length % 8; + + for( ; rsi < rsiEndAligned; rsi += 8, rdi += 8 ) + _mm256_storeu_ps( rdi, load8( rsi ) ); + + if( 0 != rem ) + { + __m256 v = loadPartial( rsi, rem ); + _mm256_maskstore_ps( rdi, loadTailMaskInt( rem ), v ); + } +} + +void floatsDowncast( uint16_t* rdi, const float* rsi, size_t length ) +{ + const float* rsiEndAligned = rsi + ( length & maskAlign8 ); + size_t rem = length % 8; + + for( ; rsi < rsiEndAligned; rsi += 8, rdi += 8 ) + { + __m256 vf = _mm256_loadu_ps( rsi ); + __m128i vi = _mm256_cvtps_ph( vf, 0 ); + store16( rdi, vi ); + } + + if( 0 != rem ) + { + __m256 vf = _mm256_maskload_ps( rsi, loadTailMaskInt( rem ) ); + __m128i vi = _mm256_cvtps_ph( vf, 0 ); + for( size_t i = 0; i < rem; i++, rdi++ ) + { + *rdi = (uint16_t)(uint32_t)_mm_cvtsi128_si32( vi ); + vi = _mm_srli_si128( vi, 2 ); + } + } +} + +void addRowInPlace( float* rdi, 
const float* rsi, size_t length ) +{ + const float* rdiEndAligned = rdi + ( length & maskAlign8 ); + size_t rem = length % 8; + + for( ; rdi < rdiEndAligned; rdi += 8, rsi += 8 ) + { + __m256 a = _mm256_loadu_ps( rdi ); + __m256 b = _mm256_loadu_ps( rsi ); + a = _mm256_add_ps( a, b ); + _mm256_storeu_ps( rdi, a ); + } + + if( 0 != rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 a = _mm256_maskload_ps( rdi, mask ); + __m256 b = _mm256_maskload_ps( rsi, mask ); + a = _mm256_add_ps( a, b ); + _mm256_maskstore_ps( rdi, mask, a ); + } +} + +void addRow( float* rdi, const float* a, const float* b, size_t length ) +{ + const float* aEndAligned = a + ( length & maskAlign8 ); + size_t rem = length % 8; + + for( ; a < aEndAligned; a += 8, b += 8, rdi += 8 ) + { + __m256 x = _mm256_loadu_ps( a ); + __m256 y = _mm256_loadu_ps( b ); + x = _mm256_add_ps( x, y ); + _mm256_storeu_ps( rdi, x ); + } + + if( 0 != rem ) + { + const __m256i mask = loadTailMaskInt( rem ); + __m256 x = _mm256_maskload_ps( a, mask ); + __m256 y = _mm256_maskload_ps( b, mask ); + x = _mm256_add_ps( x, y ); + _mm256_maskstore_ps( rdi, mask, x ); + } +}
\ No newline at end of file diff --git a/Whisper/CPU/simdUtils.h b/Whisper/CPU/simdUtils.h new file mode 100644 index 0000000..a7a4bac --- /dev/null +++ b/Whisper/CPU/simdUtils.h @@ -0,0 +1,82 @@ +#pragma once +#include <immintrin.h> + +void addF16to32( float* rdi, const uint16_t* a, const uint16_t* b, size_t length ); +void addF16to32( float* rdi, const uint16_t* a, const float* b, size_t length ); + +class AlignedSpan +{ + float* pointer; + +public: + AlignedSpan( void* data ) + { + size_t i = (size_t)data; + constexpr size_t mask32 = ~(size_t)31; + i = ( i + 31 ) & mask32; + pointer = (float*)i; + } + + operator float* ( ) { return pointer; } +}; + +inline size_t tempBufferForFloats( size_t count ) +{ + // Round up by 8 to be able to use full-vector loads and stores + constexpr size_t mask8 = ~(size_t)7; + count = ( count + 7 ) & mask8; + + // Add 32 more bytes to align the temporary buffer + return ( count * 4 ) + 32; +} + +#define ALIGNED_SPAN( name, countFloats ) AlignedSpan name{ _alloca( tempBufferForFloats( countFloats ) ) } + +void norm( float* rdi, float* temp, const float* rsi, size_t length ); + +void fmaRepeatRow( float* rdi, size_t len, const float* w, const float* b, size_t lenPattern ); +void __vectorcall addRepeatScaleRow( float* rdi, size_t len, const float* b, size_t lenPattern, const __m256 scale ); +void addRepeatRow( float* rdi, size_t len, const float* b, size_t lenPattern ); +void __vectorcall scaleRow( float* rdi, size_t len, const __m256 scale ); + +namespace DirectCompute +{ + struct LookupTablesData; +} +const DirectCompute::LookupTablesData& getLookupTables(); +void addRepeatGeluRow( float* rdi, size_t len, const float* b, size_t lenPattern, const DirectCompute::LookupTablesData& lookup ); + +void softMax( float* rdi, size_t length, const float inputScale ); + +// A cache line-aligned array where first 8 elements have all bits set, last 8 elements are zeros +extern const std::array<int, 16> s_zeroTailMask; + +// Load a tail mask as 
FP32 vector, for use with _mm256_and_ps or _mm256_blendv_ps instructions +__forceinline __m256 loadTailMaskFloats( size_t remainder ) +{ + assert( remainder > 0 && remainder < 8 ); + const float* rsi = (const float*)&s_zeroTailMask; + rsi += 8; + return _mm256_loadu_ps( rsi - remainder ); +} + +// Load a tail mask as int32 vector, for use with _mm256_maskstore_ps instruction +template<bool assertIncomplete = true> +__forceinline __m256i loadTailMaskInt( size_t remainder ) +{ + if constexpr( assertIncomplete ) + assert( remainder > 0 && remainder < 8 ); + else + assert( remainder >= 0 && remainder <= 8 ); + + const int* rsi = (const int*)&s_zeroTailMask; + rsi += 8; + return _mm256_loadu_si256( ( const __m256i* )( rsi - remainder ) ); +} + +void floatsUpcast( float* rdi, const uint16_t* rsi, size_t length ); + +void floatsDowncast( uint16_t* rdi, const float* rsi, size_t length ); + +void addRowInPlace( float* rdi, const float* rsi, size_t length ); +void addRow( float* rdi, const float* a, const float* b, size_t length );
\ No newline at end of file diff --git a/Whisper/D3D/Binder.cpp b/Whisper/D3D/Binder.cpp new file mode 100644 index 0000000..9caf7aa --- /dev/null +++ b/Whisper/D3D/Binder.cpp @@ -0,0 +1,63 @@ +#include "stdafx.h" +#include "Binder.h" +#include <algorithm> +using namespace DirectCompute; + +void Binder::bind( ID3D11ShaderResourceView* srv0, ID3D11UnorderedAccessView* uav0 ) +{ + ID3D11DeviceContext* const ctx = context(); + ctx->CSSetUnorderedAccessViews( 0, 1, &uav0, nullptr ); + + ctx->CSSetShaderResources( 0, 1, &srv0 ); + + maxSrv = std::max( maxSrv, (uint8_t)1 ); + maxUav = std::max( maxUav, (uint8_t)1 ); +} + +void Binder::bind( ID3D11UnorderedAccessView* uav0 ) +{ + context()->CSSetUnorderedAccessViews( 0, 1, &uav0, nullptr ); + maxUav = std::max( maxUav, (uint8_t)1 ); +} + +void Binder::bind( ID3D11ShaderResourceView* srv0, ID3D11ShaderResourceView* srv1, ID3D11UnorderedAccessView* uav0 ) +{ + ID3D11DeviceContext* const ctx = context(); + ctx->CSSetUnorderedAccessViews( 0, 1, &uav0, nullptr ); + + std::array< ID3D11ShaderResourceView*, 2> arr = { srv0, srv1 }; + ctx->CSSetShaderResources( 0, 2, arr.data() ); + + maxSrv = std::max( maxSrv, (uint8_t)2 ); + maxUav = std::max( maxUav, (uint8_t)1 ); +} + +void Binder::bind( std::initializer_list<ID3D11ShaderResourceView*> srvs, std::initializer_list<ID3D11UnorderedAccessView*> uavs ) +{ + ID3D11DeviceContext* const ctx = context(); + + const size_t lengthResources = srvs.size(); + const size_t lengthUnordered = uavs.size(); + assert( lengthResources > 0 && lengthResources < D3D11_COMMONSHADER_INPUT_RESOURCE_REGISTER_COUNT ); + assert( lengthUnordered > 0 && lengthUnordered < D3D11_PS_CS_UAV_REGISTER_COUNT ); + + ctx->CSSetUnorderedAccessViews( 0, (UINT)lengthUnordered, uavs.begin(), nullptr ); + ctx->CSSetShaderResources( 0, (UINT)lengthResources, srvs.begin() ); + + maxSrv = std::max( maxSrv, (uint8_t)lengthResources ); + maxUav = std::max( maxUav, (uint8_t)lengthUnordered ); +} + +Binder::~Binder() +{ + 
uint8_t count = std::max( maxSrv, maxUav ); + if( 0 == count ) + return; +#pragma warning (disable: 6255) // Compiler doesn't know we have very few of these things + size_t* arr = (size_t*)_alloca( count * sizeof( size_t ) ); + memset( arr, 0, count * sizeof( size_t ) ); + + ID3D11DeviceContext* const ctx = context(); + ctx->CSSetShaderResources( 0, maxSrv, (ID3D11ShaderResourceView**)arr ); + ctx->CSSetUnorderedAccessViews( 0, maxUav, (ID3D11UnorderedAccessView**)arr, nullptr ); +}
\ No newline at end of file diff --git a/Whisper/D3D/Binder.h b/Whisper/D3D/Binder.h new file mode 100644 index 0000000..bf7ffb2 --- /dev/null +++ b/Whisper/D3D/Binder.h @@ -0,0 +1,21 @@ +#pragma once +#include "device.h" + +namespace DirectCompute +{ + class Binder + { + uint8_t maxSrv = 0; + uint8_t maxUav = 0; + + public: + Binder() = default; + Binder( const Binder& ) = delete; + + void bind( ID3D11ShaderResourceView* srv0, ID3D11UnorderedAccessView* uav0 ); + void bind( ID3D11ShaderResourceView* srv0, ID3D11ShaderResourceView* srv1, ID3D11UnorderedAccessView* uav0 ); + void bind( std::initializer_list<ID3D11ShaderResourceView*> srvs, std::initializer_list<ID3D11UnorderedAccessView*> uavs ); + void bind( ID3D11UnorderedAccessView* uav0 ); + ~Binder(); + }; +}
\ No newline at end of file diff --git a/Whisper/D3D/MappedResource.cpp b/Whisper/D3D/MappedResource.cpp new file mode 100644 index 0000000..d6e8119 --- /dev/null +++ b/Whisper/D3D/MappedResource.cpp @@ -0,0 +1,33 @@ +#include "stdafx.h" +#include "MappedResource.h" +using namespace DirectCompute; +#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } + +MappedResource::MappedResource() +{ + mapped.pData = nullptr; + mapped.RowPitch = mapped.DepthPitch = 0; + resource = nullptr; +} + +HRESULT MappedResource::map( ID3D11Resource* res, bool reading ) +{ + if( nullptr == resource ) + { + D3D11_MAP mt = reading ? D3D11_MAP_READ : D3D11_MAP_WRITE_DISCARD; + CHECK( context()->Map( res, 0, mt, 0, &mapped ) ); + resource = res; + return S_OK; + } + return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); +} + +MappedResource::~MappedResource() +{ + if( nullptr != resource ) + { + context()->Unmap( resource, 0 ); + resource = nullptr; + mapped.pData = nullptr; + } +}
\ No newline at end of file diff --git a/Whisper/D3D/MappedResource.h b/Whisper/D3D/MappedResource.h new file mode 100644 index 0000000..a6b046b --- /dev/null +++ b/Whisper/D3D/MappedResource.h @@ -0,0 +1,22 @@ +#pragma once +#include "device.h" +#include <assert.h> + +namespace DirectCompute +{ + class MappedResource + { + D3D11_MAPPED_SUBRESOURCE mapped; + ID3D11Resource* resource; + public: + MappedResource(); + HRESULT map( ID3D11Resource* res, bool reading ); + ~MappedResource(); + + void* data() const + { + assert( nullptr != mapped.pData ); + return mapped.pData; + } + }; +}
\ No newline at end of file diff --git a/Whisper/D3D/RenderDoc/renderDoc.cpp b/Whisper/D3D/RenderDoc/renderDoc.cpp new file mode 100644 index 0000000..e811c1e --- /dev/null +++ b/Whisper/D3D/RenderDoc/renderDoc.cpp @@ -0,0 +1,72 @@ +#include "stdafx.h" +#include "renderDoc.h" +#include "renderdoc_app.h" +#include "../device.h" + +#define ENABLE_RENDERDOC_DEBUGGER 1 + +#if ENABLE_RENDERDOC_DEBUGGER +namespace +{ + static HMODULE hmRenderDoc = nullptr; + static RENDERDOC_API_1_6_0* api = nullptr; +} + +bool DirectCompute::initializeRenderDoc() +{ + hmRenderDoc = GetModuleHandleW( L"renderdoc.dll" ); + if( nullptr == hmRenderDoc ) + return false; + + pRENDERDOC_GetAPI getApi = (pRENDERDOC_GetAPI)GetProcAddress( hmRenderDoc, "RENDERDOC_GetAPI" ); + if( nullptr == getApi ) + return false; + if( 1 != getApi( eRENDERDOC_API_Version_1_6_0, (void**)&api ) ) + return false; + if( nullptr == api ) + return false; + + return true; +} + +namespace +{ + using namespace DirectCompute; + inline bool isKeyPressed( int vKey ) + { + return 0 != ( GetAsyncKeyState( vKey ) & 0x8000 ); + } +} + +CaptureRaii::CaptureRaii() : capturing( false ) +{ + if( nullptr == api ) + return; + if( !isKeyPressed( VK_F12 ) ) + return; + ID3D11Device* const dev = device(); + if( nullptr == dev ) + return; + + api->StartFrameCapture( dev, nullptr ); + capturing = true; +} + +CaptureRaii::~CaptureRaii() +{ + if( !capturing ) + return; + api->EndFrameCapture( device(), nullptr ); +} +#else // !ENABLE_RENDERDOC_DEBUGGER +bool DirectCompute::initializeRenderDoc() +{ + return false; +} +DirectCompute::CaptureRaii::CaptureRaii() : capturing( false ) +{ +} +DirectCompute::CaptureRaii::~CaptureRaii() +{ +} +#endif
// Whisper/D3D/RenderDoc/renderDoc.h
#pragma once

namespace DirectCompute
{
	// Try to obtain the RenderDoc API; returns true when the captures are available
	bool initializeRenderDoc();

	// Scope guard for a RenderDoc frame capture.
	// The constructor may start a capture, and the destructor completes the capture which was started.
	class CaptureRaii
	{
		// True when the constructor has started a capture which the destructor needs to end
		bool capturing;
	public:
		CaptureRaii();
		// A copy of the flag would end the same capture for the second time
		CaptureRaii( const CaptureRaii& ) = delete;
		CaptureRaii& operator=( const CaptureRaii& ) = delete;
		~CaptureRaii();
	};
}
\ No newline at end of file diff --git a/Whisper/D3D/RenderDoc/renderdoc_app.h b/Whisper/D3D/RenderDoc/renderdoc_app.h new file mode 100644 index 0000000..402dd3d --- /dev/null +++ b/Whisper/D3D/RenderDoc/renderdoc_app.h @@ -0,0 +1,724 @@ +/****************************************************************************** + * The MIT License (MIT) + * + * Copyright (c) 2019-2022 Baldur Karlsson + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ ******************************************************************************/ + +#pragma once + +////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Documentation for the API is available at https://renderdoc.org/docs/in_application_api.html +// + +#if !defined(RENDERDOC_NO_STDINT) +#include <stdint.h> +#endif + +#if defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER) +#define RENDERDOC_CC __cdecl +#elif defined(__linux__) +#define RENDERDOC_CC +#elif defined(__APPLE__) +#define RENDERDOC_CC +#else +#error "Unknown platform" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Constants not used directly in below API + +// This is a GUID/magic value used for when applications pass a path where shader debug +// information can be found to match up with a stripped shader. +// the define can be used like so: const GUID RENDERDOC_ShaderDebugMagicValue = +// RENDERDOC_ShaderDebugMagicValue_value +#define RENDERDOC_ShaderDebugMagicValue_struct \ + { \ + 0xeab25520, 0x6670, 0x4865, 0x84, 0x29, 0x6c, 0x8, 0x51, 0x54, 0x00, 0xff \ + } + +// as an alternative when you want a byte array (assuming x86 endianness): +#define RENDERDOC_ShaderDebugMagicValue_bytearray \ + { \ + 0x20, 0x55, 0xb2, 0xea, 0x70, 0x66, 0x65, 0x48, 0x84, 0x29, 0x6c, 0x8, 0x51, 0x54, 0x00, 0xff \ + } + +// truncated version when only a uint64_t is available (e.g. 
Vulkan tags): +#define RENDERDOC_ShaderDebugMagicValue_truncated 0x48656670eab25520ULL + +////////////////////////////////////////////////////////////////////////////////////////////////// +// RenderDoc capture options +// + +typedef enum RENDERDOC_CaptureOption { + // Allow the application to enable vsync + // + // Default - enabled + // + // 1 - The application can enable or disable vsync at will + // 0 - vsync is force disabled + eRENDERDOC_Option_AllowVSync = 0, + + // Allow the application to enable fullscreen + // + // Default - enabled + // + // 1 - The application can enable or disable fullscreen at will + // 0 - fullscreen is force disabled + eRENDERDOC_Option_AllowFullscreen = 1, + + // Record API debugging events and messages + // + // Default - disabled + // + // 1 - Enable built-in API debugging features and records the results into + // the capture, which is matched up with events on replay + // 0 - no API debugging is forcibly enabled + eRENDERDOC_Option_APIValidation = 2, + eRENDERDOC_Option_DebugDeviceMode = 2, // deprecated name of this enum + + // Capture CPU callstacks for API events + // + // Default - disabled + // + // 1 - Enables capturing of callstacks + // 0 - no callstacks are captured + eRENDERDOC_Option_CaptureCallstacks = 3, + + // When capturing CPU callstacks, only capture them from actions. + // This option does nothing without the above option being enabled + // + // Default - disabled + // + // 1 - Only captures callstacks for actions. + // Ignored if CaptureCallstacks is disabled + // 0 - Callstacks, if enabled, are captured for every event. + eRENDERDOC_Option_CaptureCallstacksOnlyDraws = 4, + eRENDERDOC_Option_CaptureCallstacksOnlyActions = 4, + + // Specify a delay in seconds to wait for a debugger to attach, after + // creating or injecting into a process, before continuing to allow it to run. 
+ // + // 0 indicates no delay, and the process will run immediately after injection + // + // Default - 0 seconds + // + eRENDERDOC_Option_DelayForDebugger = 5, + + // Verify buffer access. This includes checking the memory returned by a Map() call to + // detect any out-of-bounds modification, as well as initialising buffers with undefined contents + // to a marker value to catch use of uninitialised memory. + // + // NOTE: This option is only valid for OpenGL and D3D11. Explicit APIs such as D3D12 and Vulkan do + // not do the same kind of interception & checking and undefined contents are really undefined. + // + // Default - disabled + // + // 1 - Verify buffer access + // 0 - No verification is performed, and overwriting bounds may cause crashes or corruption in + // RenderDoc. + eRENDERDOC_Option_VerifyBufferAccess = 6, + + // The old name for eRENDERDOC_Option_VerifyBufferAccess was eRENDERDOC_Option_VerifyMapWrites. + // This option now controls the filling of uninitialised buffers with 0xdddddddd which was + // previously always enabled + eRENDERDOC_Option_VerifyMapWrites = eRENDERDOC_Option_VerifyBufferAccess, + + // Hooks any system API calls that create child processes, and injects + // RenderDoc into them recursively with the same options. + // + // Default - disabled + // + // 1 - Hooks into spawned child processes + // 0 - Child processes are not hooked by RenderDoc + eRENDERDOC_Option_HookIntoChildren = 7, + + // By default RenderDoc only includes resources in the final capture necessary + // for that frame, this allows you to override that behaviour. + // + // Default - disabled + // + // 1 - all live resources at the time of capture are included in the capture + // and available for inspection + // 0 - only the resources referenced by the captured frame are included + eRENDERDOC_Option_RefAllResources = 8, + + // **NOTE**: As of RenderDoc v1.1 this option has been deprecated. 
Setting or + // getting it will be ignored, to allow compatibility with older versions. + // In v1.1 the option acts as if it's always enabled. + // + // By default RenderDoc skips saving initial states for resources where the + // previous contents don't appear to be used, assuming that writes before + // reads indicate previous contents aren't used. + // + // Default - disabled + // + // 1 - initial contents at the start of each captured frame are saved, even if + // they are later overwritten or cleared before being used. + // 0 - unless a read is detected, initial contents will not be saved and will + // appear as black or empty data. + eRENDERDOC_Option_SaveAllInitials = 9, + + // In APIs that allow for the recording of command lists to be replayed later, + // RenderDoc may choose to not capture command lists before a frame capture is + // triggered, to reduce overheads. This means any command lists recorded once + // and replayed many times will not be available and may cause a failure to + // capture. + // + // NOTE: This is only true for APIs where multithreading is difficult or + // discouraged. Newer APIs like Vulkan and D3D12 will ignore this option + // and always capture all command lists since the API is heavily oriented + // around it and the overheads have been reduced by API design. + // + // 1 - All command lists are captured from the start of the application + // 0 - Command lists are only captured if their recording begins during + // the period when a frame capture is in progress. + eRENDERDOC_Option_CaptureAllCmdLists = 10, + + // Mute API debugging output when the API validation mode option is enabled + // + // Default - enabled + // + // 1 - Mute any API debug messages from being displayed or passed through + // 0 - API debugging is displayed as normal + eRENDERDOC_Option_DebugOutputMute = 11, + + // Option to allow vendor extensions to be used even when they may be + // incompatible with RenderDoc and cause corrupted replays or crashes. 
+ // + // Default - inactive + // + // No values are documented, this option should only be used when absolutely + // necessary as directed by a RenderDoc developer. + eRENDERDOC_Option_AllowUnsupportedVendorExtensions = 12, + +} RENDERDOC_CaptureOption; + +// Sets an option that controls how RenderDoc behaves on capture. +// +// Returns 1 if the option and value are valid +// Returns 0 if either is invalid and the option is unchanged +typedef int(RENDERDOC_CC *pRENDERDOC_SetCaptureOptionU32)(RENDERDOC_CaptureOption opt, uint32_t val); +typedef int(RENDERDOC_CC *pRENDERDOC_SetCaptureOptionF32)(RENDERDOC_CaptureOption opt, float val); + +// Gets the current value of an option as a uint32_t +// +// If the option is invalid, 0xffffffff is returned +typedef uint32_t(RENDERDOC_CC *pRENDERDOC_GetCaptureOptionU32)(RENDERDOC_CaptureOption opt); + +// Gets the current value of an option as a float +// +// If the option is invalid, -FLT_MAX is returned +typedef float(RENDERDOC_CC *pRENDERDOC_GetCaptureOptionF32)(RENDERDOC_CaptureOption opt); + +typedef enum RENDERDOC_InputButton { + // '0' - '9' matches ASCII values + eRENDERDOC_Key_0 = 0x30, + eRENDERDOC_Key_1 = 0x31, + eRENDERDOC_Key_2 = 0x32, + eRENDERDOC_Key_3 = 0x33, + eRENDERDOC_Key_4 = 0x34, + eRENDERDOC_Key_5 = 0x35, + eRENDERDOC_Key_6 = 0x36, + eRENDERDOC_Key_7 = 0x37, + eRENDERDOC_Key_8 = 0x38, + eRENDERDOC_Key_9 = 0x39, + + // 'A' - 'Z' matches ASCII values + eRENDERDOC_Key_A = 0x41, + eRENDERDOC_Key_B = 0x42, + eRENDERDOC_Key_C = 0x43, + eRENDERDOC_Key_D = 0x44, + eRENDERDOC_Key_E = 0x45, + eRENDERDOC_Key_F = 0x46, + eRENDERDOC_Key_G = 0x47, + eRENDERDOC_Key_H = 0x48, + eRENDERDOC_Key_I = 0x49, + eRENDERDOC_Key_J = 0x4A, + eRENDERDOC_Key_K = 0x4B, + eRENDERDOC_Key_L = 0x4C, + eRENDERDOC_Key_M = 0x4D, + eRENDERDOC_Key_N = 0x4E, + eRENDERDOC_Key_O = 0x4F, + eRENDERDOC_Key_P = 0x50, + eRENDERDOC_Key_Q = 0x51, + eRENDERDOC_Key_R = 0x52, + eRENDERDOC_Key_S = 0x53, + eRENDERDOC_Key_T = 0x54, + eRENDERDOC_Key_U = 0x55, 
+ eRENDERDOC_Key_V = 0x56, + eRENDERDOC_Key_W = 0x57, + eRENDERDOC_Key_X = 0x58, + eRENDERDOC_Key_Y = 0x59, + eRENDERDOC_Key_Z = 0x5A, + + // leave the rest of the ASCII range free + // in case we want to use it later + eRENDERDOC_Key_NonPrintable = 0x100, + + eRENDERDOC_Key_Divide, + eRENDERDOC_Key_Multiply, + eRENDERDOC_Key_Subtract, + eRENDERDOC_Key_Plus, + + eRENDERDOC_Key_F1, + eRENDERDOC_Key_F2, + eRENDERDOC_Key_F3, + eRENDERDOC_Key_F4, + eRENDERDOC_Key_F5, + eRENDERDOC_Key_F6, + eRENDERDOC_Key_F7, + eRENDERDOC_Key_F8, + eRENDERDOC_Key_F9, + eRENDERDOC_Key_F10, + eRENDERDOC_Key_F11, + eRENDERDOC_Key_F12, + + eRENDERDOC_Key_Home, + eRENDERDOC_Key_End, + eRENDERDOC_Key_Insert, + eRENDERDOC_Key_Delete, + eRENDERDOC_Key_PageUp, + eRENDERDOC_Key_PageDn, + + eRENDERDOC_Key_Backspace, + eRENDERDOC_Key_Tab, + eRENDERDOC_Key_PrtScrn, + eRENDERDOC_Key_Pause, + + eRENDERDOC_Key_Max, +} RENDERDOC_InputButton; + +// Sets which key or keys can be used to toggle focus between multiple windows +// +// If keys is NULL or num is 0, toggle keys will be disabled +typedef void(RENDERDOC_CC *pRENDERDOC_SetFocusToggleKeys)(RENDERDOC_InputButton *keys, int num); + +// Sets which key or keys can be used to capture the next frame +// +// If keys is NULL or num is 0, captures keys will be disabled +typedef void(RENDERDOC_CC *pRENDERDOC_SetCaptureKeys)(RENDERDOC_InputButton *keys, int num); + +typedef enum RENDERDOC_OverlayBits { + // This single bit controls whether the overlay is enabled or disabled globally + eRENDERDOC_Overlay_Enabled = 0x1, + + // Show the average framerate over several seconds as well as min/max + eRENDERDOC_Overlay_FrameRate = 0x2, + + // Show the current frame number + eRENDERDOC_Overlay_FrameNumber = 0x4, + + // Show a list of recent captures, and how many captures have been made + eRENDERDOC_Overlay_CaptureList = 0x8, + + // Default values for the overlay mask + eRENDERDOC_Overlay_Default = (eRENDERDOC_Overlay_Enabled | eRENDERDOC_Overlay_FrameRate | + 
eRENDERDOC_Overlay_FrameNumber | eRENDERDOC_Overlay_CaptureList), + + // Enable all bits + eRENDERDOC_Overlay_All = ~0U, + + // Disable all bits + eRENDERDOC_Overlay_None = 0, +} RENDERDOC_OverlayBits; + +// returns the overlay bits that have been set +typedef uint32_t(RENDERDOC_CC *pRENDERDOC_GetOverlayBits)(); +// sets the overlay bits with an and & or mask +typedef void(RENDERDOC_CC *pRENDERDOC_MaskOverlayBits)(uint32_t And, uint32_t Or); + +// this function will attempt to remove RenderDoc's hooks in the application. +// +// Note: that this can only work correctly if done immediately after +// the module is loaded, before any API work happens. RenderDoc will remove its +// injected hooks and shut down. Behaviour is undefined if this is called +// after any API functions have been called, and there is still no guarantee of +// success. +typedef void(RENDERDOC_CC *pRENDERDOC_RemoveHooks)(); + +// DEPRECATED: compatibility for code compiled against pre-1.4.1 headers. +typedef pRENDERDOC_RemoveHooks pRENDERDOC_Shutdown; + +// This function will unload RenderDoc's crash handler. +// +// If you use your own crash handler and don't want RenderDoc's handler to +// intercede, you can call this function to unload it and any unhandled +// exceptions will pass to the next handler. +typedef void(RENDERDOC_CC *pRENDERDOC_UnloadCrashHandler)(); + +// Sets the capture file path template +// +// pathtemplate is a UTF-8 string that gives a template for how captures will be named +// and where they will be saved. +// +// Any extension is stripped off the path, and captures are saved in the directory +// specified, and named with the filename and the frame number appended. If the +// directory does not exist it will be created, including any parent directories. 
+// +// If pathtemplate is NULL, the template will remain unchanged +// +// Example: +// +// SetCaptureFilePathTemplate("my_captures/example"); +// +// Capture #1 -> my_captures/example_frame123.rdc +// Capture #2 -> my_captures/example_frame456.rdc +typedef void(RENDERDOC_CC *pRENDERDOC_SetCaptureFilePathTemplate)(const char *pathtemplate); + +// returns the current capture path template, see SetCaptureFilePathTemplate above, as a UTF-8 string +typedef const char *(RENDERDOC_CC *pRENDERDOC_GetCaptureFilePathTemplate)(); + +// DEPRECATED: compatibility for code compiled against pre-1.1.2 headers. +typedef pRENDERDOC_SetCaptureFilePathTemplate pRENDERDOC_SetLogFilePathTemplate; +typedef pRENDERDOC_GetCaptureFilePathTemplate pRENDERDOC_GetLogFilePathTemplate; + +// returns the number of captures that have been made +typedef uint32_t(RENDERDOC_CC *pRENDERDOC_GetNumCaptures)(); + +// This function returns the details of a capture, by index. New captures are added +// to the end of the list. +// +// filename will be filled with the absolute path to the capture file, as a UTF-8 string +// pathlength will be written with the length in bytes of the filename string +// timestamp will be written with the time of the capture, in seconds since the Unix epoch +// +// Any of the parameters can be NULL and they'll be skipped. +// +// The function will return 1 if the capture index is valid, or 0 if the index is invalid +// If the index is invalid, the values will be unchanged +// +// Note: when captures are deleted in the UI they will remain in this list, so the +// capture path may not exist anymore. +typedef uint32_t(RENDERDOC_CC *pRENDERDOC_GetCapture)(uint32_t idx, char *filename, + uint32_t *pathlength, uint64_t *timestamp); + +// Sets the comments associated with a capture file. These comments are displayed in the +// UI program when opening. +// +// filePath should be a path to the capture file to add comments to.
If set to NULL or "" +// the most recent capture file created will be used instead. +// comments should be a NULL-terminated UTF-8 string to add as comments. +// +// Any existing comments will be overwritten. +typedef void(RENDERDOC_CC *pRENDERDOC_SetCaptureFileComments)(const char *filePath, + const char *comments); + +// returns 1 if the RenderDoc UI is connected to this application, 0 otherwise +typedef uint32_t(RENDERDOC_CC *pRENDERDOC_IsTargetControlConnected)(); + +// DEPRECATED: compatibility for code compiled against pre-1.1.1 headers. +// This was renamed to IsTargetControlConnected in API 1.1.1, the old typedef is kept here for +// backwards compatibility with old code, it is castable either way since it's ABI compatible +// as the same function pointer type. +typedef pRENDERDOC_IsTargetControlConnected pRENDERDOC_IsRemoteAccessConnected; + +// This function will launch the Replay UI associated with the RenderDoc library injected +// into the running application. +// +// if connectTargetControl is 1, the Replay UI will be launched with a command line parameter +// to connect to this application +// cmdline is the rest of the command line, as a UTF-8 string. E.g. a capture to open +// if cmdline is NULL, the command line will be empty. +// +// returns the PID of the replay UI if successful, 0 if not successful. +typedef uint32_t(RENDERDOC_CC *pRENDERDOC_LaunchReplayUI)(uint32_t connectTargetControl, + const char *cmdline); + +// RenderDoc can return a higher version than requested if it's backwards compatible, +// this function returns the actual version returned. If a parameter is NULL, it will be +// ignored and the others will be filled out. +typedef void(RENDERDOC_CC *pRENDERDOC_GetAPIVersion)(int *major, int *minor, int *patch); + +// Requests that the replay UI show itself (if hidden or not the current top window).
This can be +// used in conjunction with IsTargetControlConnected and LaunchReplayUI to intelligently handle +// showing the UI after making a capture. +// +// This will return 1 if the request was successfully passed on, though it's not guaranteed that +// the UI will be on top in all cases depending on OS rules. It will return 0 if there is no current +// target control connection to make such a request, or if there was another error +typedef uint32_t(RENDERDOC_CC *pRENDERDOC_ShowReplayUI)(); + +////////////////////////////////////////////////////////////////////////// +// Capturing functions +// + +// A device pointer is a pointer to the API's root handle. +// +// This would be an ID3D11Device, HGLRC/GLXContext, ID3D12Device, etc +typedef void *RENDERDOC_DevicePointer; + +// A window handle is the OS's native window handle +// +// This would be an HWND, GLXDrawable, etc +typedef void *RENDERDOC_WindowHandle; + +// A helper macro for Vulkan, where the device handle cannot be used directly. +// +// Passing the VkInstance to this macro will return the RENDERDOC_DevicePointer to use. +// +// Specifically, the value needed is the dispatch table pointer, which sits as the first +// pointer-sized object in the memory pointed to by the VkInstance. Thus we cast to a void** and +// indirect once. +#define RENDERDOC_DEVICEPOINTER_FROM_VKINSTANCE(inst) (*((void **)(inst))) + +// This sets the RenderDoc in-app overlay in the API/window pair as 'active' and it will +// respond to keypresses. 
Neither parameter can be NULL +typedef void(RENDERDOC_CC *pRENDERDOC_SetActiveWindow)(RENDERDOC_DevicePointer device, + RENDERDOC_WindowHandle wndHandle); + +// capture the next frame on whichever window and API is currently considered active +typedef void(RENDERDOC_CC *pRENDERDOC_TriggerCapture)(); + +// capture the next N frames on whichever window and API is currently considered active +typedef void(RENDERDOC_CC *pRENDERDOC_TriggerMultiFrameCapture)(uint32_t numFrames); + +// When choosing either a device pointer or a window handle to capture, you can pass NULL. +// Passing NULL specifies a 'wildcard' match against anything. This allows you to specify +// any API rendering to a specific window, or a specific API instance rendering to any window, +// or in the simplest case of one window and one API, you can just pass NULL for both. +// +// In either case, if there are two or more possible matching (device,window) pairs it +// is undefined which one will be captured. +// +// Note: for headless rendering you can pass NULL for the window handle and either specify +// a device pointer or leave it NULL as above. + +// Immediately starts capturing API calls on the specified device pointer and window handle. +// +// If there is no matching thing to capture (e.g. no supported API has been initialised), +// this will do nothing. +// +// The results are undefined (including crashes) if two captures are started overlapping, +// even on separate devices and/or windows. +typedef void(RENDERDOC_CC *pRENDERDOC_StartFrameCapture)(RENDERDOC_DevicePointer device, + RENDERDOC_WindowHandle wndHandle); + +// Returns whether or not a frame capture is currently ongoing anywhere. +// +// This will return 1 if a capture is ongoing, and 0 if there is no capture running +typedef uint32_t(RENDERDOC_CC *pRENDERDOC_IsFrameCapturing)(); + +// Ends capturing immediately. +// +// This will return 1 if the capture succeeded, and 0 if there was an error capturing.
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_EndFrameCapture)(RENDERDOC_DevicePointer device, + RENDERDOC_WindowHandle wndHandle); + +// Ends capturing immediately and discard any data stored without saving to disk. +// +// This will return 1 if the capture was discarded, and 0 if there was an error or no capture +// was in progress +typedef uint32_t(RENDERDOC_CC *pRENDERDOC_DiscardFrameCapture)(RENDERDOC_DevicePointer device, + RENDERDOC_WindowHandle wndHandle); + +// Only valid to be called between a call to StartFrameCapture and EndFrameCapture. Gives a custom +// title to the capture produced which will be displayed in the UI. +// +// If multiple captures are ongoing, this title will be applied to the first capture to end after +// this call. The second capture to end will have no title, unless this function is called again. +// +// Calling this function has no effect if no capture is currently running, and if it is called +// multiple times only the last title will be used. +typedef void(RENDERDOC_CC *pRENDERDOC_SetCaptureTitle)(const char *title); + +////////////////////////////////////////////////////////////////////////////////////////////////// +// RenderDoc API versions +// + +// RenderDoc uses semantic versioning (http://semver.org/). +// +// MAJOR version is incremented when incompatible API changes happen. +// MINOR version is incremented when functionality is added in a backwards-compatible manner. +// PATCH version is incremented when backwards-compatible bug fixes happen. +// +// Note that this means the API returned can be higher than the one you might have requested. +// e.g. if you are running against a newer RenderDoc that supports 1.0.1, it will be returned +// instead of 1.0.0. 
You can check this with the GetAPIVersion entry point +typedef enum RENDERDOC_Version { + eRENDERDOC_API_Version_1_0_0 = 10000, // RENDERDOC_API_1_0_0 = 1 00 00 + eRENDERDOC_API_Version_1_0_1 = 10001, // RENDERDOC_API_1_0_1 = 1 00 01 + eRENDERDOC_API_Version_1_0_2 = 10002, // RENDERDOC_API_1_0_2 = 1 00 02 + eRENDERDOC_API_Version_1_1_0 = 10100, // RENDERDOC_API_1_1_0 = 1 01 00 + eRENDERDOC_API_Version_1_1_1 = 10101, // RENDERDOC_API_1_1_1 = 1 01 01 + eRENDERDOC_API_Version_1_1_2 = 10102, // RENDERDOC_API_1_1_2 = 1 01 02 + eRENDERDOC_API_Version_1_2_0 = 10200, // RENDERDOC_API_1_2_0 = 1 02 00 + eRENDERDOC_API_Version_1_3_0 = 10300, // RENDERDOC_API_1_3_0 = 1 03 00 + eRENDERDOC_API_Version_1_4_0 = 10400, // RENDERDOC_API_1_4_0 = 1 04 00 + eRENDERDOC_API_Version_1_4_1 = 10401, // RENDERDOC_API_1_4_1 = 1 04 01 + eRENDERDOC_API_Version_1_4_2 = 10402, // RENDERDOC_API_1_4_2 = 1 04 02 + eRENDERDOC_API_Version_1_5_0 = 10500, // RENDERDOC_API_1_5_0 = 1 05 00 + eRENDERDOC_API_Version_1_6_0 = 10600, // RENDERDOC_API_1_6_0 = 1 06 00 +} RENDERDOC_Version; + +// API version changelog: +// +// 1.0.0 - initial release +// 1.0.1 - Bugfix: IsFrameCapturing() was returning false for captures that were triggered +// by keypress or TriggerCapture, instead of Start/EndFrameCapture. +// 1.0.2 - Refactor: Renamed eRENDERDOC_Option_DebugDeviceMode to eRENDERDOC_Option_APIValidation +// 1.1.0 - Add feature: TriggerMultiFrameCapture(). Backwards compatible with 1.0.x since the new +// function pointer is added to the end of the struct, the original layout is identical +// 1.1.1 - Refactor: Renamed remote access to target control (to better disambiguate from remote +// replay/remote server concept in replay UI) +// 1.1.2 - Refactor: Renamed "log file" in function names to just capture, to clarify that these +// are captures and not debug logging files. This is the first API version in the v1.0 +// branch. 
+// 1.2.0 - Added feature: SetCaptureFileComments() to add comments to a capture file that will be +// displayed in the UI program on load. +// 1.3.0 - Added feature: New capture option eRENDERDOC_Option_AllowUnsupportedVendorExtensions +// which allows users to opt-in to allowing unsupported vendor extensions to function. +// Should be used at the user's own risk. +// Refactor: Renamed eRENDERDOC_Option_VerifyMapWrites to +// eRENDERDOC_Option_VerifyBufferAccess, which now also controls initialisation to +// 0xdddddddd of uninitialised buffer contents. +// 1.4.0 - Added feature: DiscardFrameCapture() to discard a frame capture in progress and stop +// capturing without saving anything to disk. +// 1.4.1 - Refactor: Renamed Shutdown to RemoveHooks to better clarify what is happening +// 1.4.2 - Refactor: Renamed 'draws' to 'actions' in callstack capture option. +// 1.5.0 - Added feature: ShowReplayUI() to request that the replay UI show itself if connected +// 1.6.0 - Added feature: SetCaptureTitle() which can be used to set a title for a +// capture made with StartFrameCapture() or EndFrameCapture() + +typedef struct RENDERDOC_API_1_6_0 +{ + pRENDERDOC_GetAPIVersion GetAPIVersion; + + pRENDERDOC_SetCaptureOptionU32 SetCaptureOptionU32; + pRENDERDOC_SetCaptureOptionF32 SetCaptureOptionF32; + + pRENDERDOC_GetCaptureOptionU32 GetCaptureOptionU32; + pRENDERDOC_GetCaptureOptionF32 GetCaptureOptionF32; + + pRENDERDOC_SetFocusToggleKeys SetFocusToggleKeys; + pRENDERDOC_SetCaptureKeys SetCaptureKeys; + + pRENDERDOC_GetOverlayBits GetOverlayBits; + pRENDERDOC_MaskOverlayBits MaskOverlayBits; + + // Shutdown was renamed to RemoveHooks in 1.4.1. + // These unions allow old code to continue compiling without changes + union + { + pRENDERDOC_Shutdown Shutdown; + pRENDERDOC_RemoveHooks RemoveHooks; + }; + pRENDERDOC_UnloadCrashHandler UnloadCrashHandler; + + // Get/SetLogFilePathTemplate was renamed to Get/SetCaptureFilePathTemplate in 1.1.2. 
+ // These unions allow old code to continue compiling without changes + union + { + // deprecated name + pRENDERDOC_SetLogFilePathTemplate SetLogFilePathTemplate; + // current name + pRENDERDOC_SetCaptureFilePathTemplate SetCaptureFilePathTemplate; + }; + union + { + // deprecated name + pRENDERDOC_GetLogFilePathTemplate GetLogFilePathTemplate; + // current name + pRENDERDOC_GetCaptureFilePathTemplate GetCaptureFilePathTemplate; + }; + + pRENDERDOC_GetNumCaptures GetNumCaptures; + pRENDERDOC_GetCapture GetCapture; + + pRENDERDOC_TriggerCapture TriggerCapture; + + // IsRemoteAccessConnected was renamed to IsTargetControlConnected in 1.1.1. + // This union allows old code to continue compiling without changes + union + { + // deprecated name + pRENDERDOC_IsRemoteAccessConnected IsRemoteAccessConnected; + // current name + pRENDERDOC_IsTargetControlConnected IsTargetControlConnected; + }; + pRENDERDOC_LaunchReplayUI LaunchReplayUI; + + pRENDERDOC_SetActiveWindow SetActiveWindow; + + pRENDERDOC_StartFrameCapture StartFrameCapture; + pRENDERDOC_IsFrameCapturing IsFrameCapturing; + pRENDERDOC_EndFrameCapture EndFrameCapture; + + // new function in 1.1.0 + pRENDERDOC_TriggerMultiFrameCapture TriggerMultiFrameCapture; + + // new function in 1.2.0 + pRENDERDOC_SetCaptureFileComments SetCaptureFileComments; + + // new function in 1.4.0 + pRENDERDOC_DiscardFrameCapture DiscardFrameCapture; + + // new function in 1.5.0 + pRENDERDOC_ShowReplayUI ShowReplayUI; + + // new function in 1.6.0 + pRENDERDOC_SetCaptureTitle SetCaptureTitle; +} RENDERDOC_API_1_6_0; + +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_0_0; +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_0_1; +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_0_2; +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_1_0; +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_1_1; +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_1_2; +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_2_0; +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_3_0; +typedef 
RENDERDOC_API_1_6_0 RENDERDOC_API_1_4_0; +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_4_1; +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_4_2; +typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_5_0; + +////////////////////////////////////////////////////////////////////////////////////////////////// +// RenderDoc API entry point +// +// This entry point can be obtained via GetProcAddress/dlsym if RenderDoc is available. +// +// The name is the same as the typedef - "RENDERDOC_GetAPI" +// +// This function is not thread safe, and should not be called on multiple threads at once. +// Ideally, call this once as early as possible in your application's startup, before doing +// any API work, since some configuration functionality etc has to be done also before +// initialising any APIs. +// +// Parameters: +// version is a single value from the RENDERDOC_Version above. +// +// outAPIPointers will be filled out with a pointer to the corresponding struct of function +// pointers. +// +// Returns: +// 1 - if the outAPIPointers has been filled with a pointer to the API struct requested +// 0 - if the requested version is not supported or the arguments are invalid. 
+// +typedef int(RENDERDOC_CC *pRENDERDOC_GetAPI)(RENDERDOC_Version version, void **outAPIPointers); + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/Whisper/D3D/createBuffer.cpp b/Whisper/D3D/createBuffer.cpp new file mode 100644 index 0000000..3fbb13e --- /dev/null +++ b/Whisper/D3D/createBuffer.cpp @@ -0,0 +1,51 @@ +#include "stdafx.h" +#include "createBuffer.h" + +#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } + +HRESULT DirectCompute::createBuffer( eBufferUse use, size_t totalBytes, ID3D11Buffer** ppGpuBuffer, const void* rsi, ID3D11Buffer** ppStagingBuffer ) +{ + if( totalBytes > INT_MAX ) + return DISP_E_OVERFLOW; + if( nullptr == ppGpuBuffer ) + return E_POINTER; + + CD3D11_BUFFER_DESC bufferDesc{ (UINT)totalBytes, D3D11_BIND_SHADER_RESOURCE }; + switch( use ) + { + case eBufferUse::Immutable: + if( nullptr == rsi ) + return E_INVALIDARG; + bufferDesc.Usage = D3D11_USAGE_IMMUTABLE; + break; + case eBufferUse::ReadWrite: + case eBufferUse::ReadWriteDownload: + bufferDesc.BindFlags |= D3D11_BIND_UNORDERED_ACCESS; + break; + case eBufferUse::Dynamic: + bufferDesc.Usage = D3D11_USAGE_DYNAMIC; + bufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE; + break; + } + + D3D11_SUBRESOURCE_DATA srd; + D3D11_SUBRESOURCE_DATA* pSrd = nullptr; + if( nullptr != rsi ) + { + srd.pSysMem = rsi; + srd.SysMemPitch = srd.SysMemSlicePitch = 0; + pSrd = &srd; + } + + CHECK( device()->CreateBuffer( &bufferDesc, pSrd, ppGpuBuffer ) ); + + if( nullptr != ppStagingBuffer && use == eBufferUse::ReadWriteDownload ) + { + bufferDesc.Usage = D3D11_USAGE_STAGING; + bufferDesc.BindFlags = 0; + bufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_READ; + CHECK( device()->CreateBuffer( &bufferDesc, nullptr, ppStagingBuffer ) ); + } + + return S_OK; +}
\ No newline at end of file diff --git a/Whisper/D3D/createBuffer.h b/Whisper/D3D/createBuffer.h new file mode 100644 index 0000000..df0c2ab --- /dev/null +++ b/Whisper/D3D/createBuffer.h @@ -0,0 +1,8 @@ +#pragma once +#include "enums.h" +#include "device.h" + +namespace DirectCompute +{ + HRESULT createBuffer( eBufferUse use, size_t totalBytes, ID3D11Buffer** ppGpuBuffer, const void* rsi, ID3D11Buffer** ppStagingBuffer ); +}
\ No newline at end of file diff --git a/Whisper/D3D/device.cpp b/Whisper/D3D/device.cpp new file mode 100644 index 0000000..4eb5a60 --- /dev/null +++ b/Whisper/D3D/device.cpp @@ -0,0 +1,120 @@ +#include "stdafx.h" +#include "device.h" +#include <immintrin.h> +#include <ammintrin.h> +#pragma comment(lib, "D3D11.lib") +#include "RenderDoc/renderDoc.h" + +namespace DirectCompute +{ + CComPtr<ID3D11Device> g_device; + CComPtr<ID3D11DeviceContext> g_context; + D3D_FEATURE_LEVEL g_featureLevel = (D3D_FEATURE_LEVEL)0; + + ID3D11Device* device() { return g_device; } + ID3D11DeviceContext* context() { return g_context; } + D3D_FEATURE_LEVEL featureLevel() { return g_featureLevel; } + + void terminate() + { + g_context = nullptr; + g_device = nullptr; + } + + static HRESULT createDevice() + { + if( g_device ) + return S_FALSE; + + const std::array<D3D_FEATURE_LEVEL, 4> levels = { D3D_FEATURE_LEVEL_12_1 , D3D_FEATURE_LEVEL_12_0 , D3D_FEATURE_LEVEL_11_1 , D3D_FEATURE_LEVEL_11_0 }; + UINT flags = D3D11_CREATE_DEVICE_DISABLE_GPU_TIMEOUT | D3D11_CREATE_DEVICE_SINGLETHREADED; + bool renderDoc = initializeRenderDoc(); +#ifdef _DEBUG + if( !renderDoc ) + { + // Last time I checked, RenderDoc crashed with debug version of D3D11 runtime + // Only setting this flag when renderdoc.dll is not loaded into the current process + flags |= D3D11_CREATE_DEVICE_DEBUG; + } +#endif + constexpr UINT levelsCount = (UINT)levels.size(); + HRESULT hr = D3D11CreateDevice( nullptr, D3D_DRIVER_TYPE_HARDWARE, nullptr, flags, levels.data(), levelsCount, D3D11_SDK_VERSION, &g_device, &g_featureLevel, &g_context ); + if( SUCCEEDED( hr ) ) + return S_OK; + // D3D11_CREATE_DEVICE_DISABLE_GPU_TIMEOUT: This value is not supported until Direct3D 11.1 + // https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ne-d3d11-d3d11_create_device_flag + flags = _andn_u32( D3D11_CREATE_DEVICE_DISABLE_GPU_TIMEOUT, flags ); + + hr = D3D11CreateDevice( nullptr, D3D_DRIVER_TYPE_HARDWARE, nullptr, flags, levels.data(), 
levelsCount, D3D11_SDK_VERSION, &g_device, &g_featureLevel, &g_context ); + if( SUCCEEDED( hr ) ) + return S_OK; + return hr; + } + + sGpuInfo s_gpuInfo = {}; + const sGpuInfo& gpuInfo = s_gpuInfo; + + static HRESULT queryDeviceInfo() + { + if( nullptr == g_device ) + return OLE_E_BLANK; + CComPtr<IDXGIDevice> dd; + CHECK( g_device.QueryInterface( &dd ) ); + + CComPtr<IDXGIAdapter> adapter; + CHECK( dd->GetAdapter( &adapter ) ); + + DXGI_ADAPTER_DESC desc; + adapter->GetDesc( &desc ); + + const size_t descLen = wcsnlen_s( desc.Description, 128 ); + const wchar_t* rsi = &desc.Description[ 0 ]; + s_gpuInfo.description.assign( rsi, rsi + descLen ); + s_gpuInfo.vendor = (eGpuVendor)desc.VendorId; + s_gpuInfo.device = (uint16_t)desc.DeviceId; + s_gpuInfo.revision = (uint16_t)desc.Revision; + s_gpuInfo.subsystem = desc.SubSysId; + s_gpuInfo.vramDedicated = desc.DedicatedVideoMemory; + s_gpuInfo.ramDedicated = desc.DedicatedSystemMemory; + s_gpuInfo.ramShared = desc.SharedSystemMemory; + return S_OK; + } + + HRESULT initialize() + { + HRESULT hr = createDevice(); + if( hr != S_OK ) + return hr; + queryDeviceInfo(); + return S_OK; + } + + __m128i __declspec( noinline ) bufferMemoryUsage( ID3D11Buffer* buffer ) + { + if( nullptr != buffer ) + { + D3D11_BUFFER_DESC desc; + buffer->GetDesc( &desc ); + + if( desc.Usage != D3D11_USAGE_STAGING ) + return setHigh_size( desc.ByteWidth ); + else + return setLow_size( desc.ByteWidth ); + } + return _mm_setzero_si128(); + } + + __m128i __declspec( noinline ) resourceMemoryUsage( ID3D11ShaderResourceView* srv ) + { + if( nullptr != srv ) + { + CComPtr<ID3D11Resource> res; + srv->GetResource( &res ); + CComPtr<ID3D11Buffer> buff; + if( SUCCEEDED( res.QueryInterface( &buff ) ) ) + return bufferMemoryUsage( buff ); + assert( false ); // We don't use textures in this project + } + return _mm_setzero_si128(); + } +}
\ No newline at end of file diff --git a/Whisper/D3D/device.h b/Whisper/D3D/device.h new file mode 100644 index 0000000..dfcb766 --- /dev/null +++ b/Whisper/D3D/device.h @@ -0,0 +1,66 @@ +#pragma once +#include <atlcomcli.h> +#include <string> + +namespace DirectCompute +{ + ID3D11Device* device(); + ID3D11DeviceContext* context(); + D3D_FEATURE_LEVEL featureLevel(); + + HRESULT initialize(); + void terminate(); + + // DXGI_ADAPTER_DESC.VendorId magic numbers; they come from that database: https://pcisig.com/membership/member-companies + enum struct eGpuVendor : uint16_t + { + AMD = 0x1002, + NVidia = 0x10de, + Intel = 0x8086, + VMWare = 0x15ad, + }; + + struct sGpuInfo + { + std::wstring description; + eGpuVendor vendor; + uint16_t device, revision; + uint32_t subsystem; + size_t vramDedicated, ramDedicated, ramShared; + + inline bool wave64() const + { + return vendor == eGpuVendor::AMD; + } + + // On nVidia 1080Ti that approach is much slower, by a factor of 2.4 + // On AMD Cezanne that approach is faster by a factor of 0.69, i.e. 30% faster. + // Dunno why that is, maybe 'coz on that AMD complete panels fit in L3 cache. + // Anyway, we do want extra 30% perf on AMD Cezanne, so only using that code on AMD GPUs. + // Dunno how it gonna behave on other GPUs, need to test. +#if RESHAPED_MATRIX_MULTIPLY + inline bool useReshapedMatMul() const + { + // return true; + return vendor == eGpuVendor::AMD; + } +#else + constexpr bool useReshapedMatMul() const { return false; } +#endif + }; + extern const sGpuInfo& gpuInfo; + + inline bool available() + { + return nullptr != device(); + } + + inline void csSetCB( ID3D11Buffer* cb ) + { + context()->CSSetConstantBuffers( 0, 1, &cb ); + } + + __m128i bufferMemoryUsage( ID3D11Buffer* buffer ); + + __m128i resourceMemoryUsage( ID3D11ShaderResourceView* srv ); +}
\ No newline at end of file diff --git a/Whisper/D3D/downloadBuffer.cpp b/Whisper/D3D/downloadBuffer.cpp new file mode 100644 index 0000000..9790ab7 --- /dev/null +++ b/Whisper/D3D/downloadBuffer.cpp @@ -0,0 +1,72 @@ +#include "stdafx.h" +#include "downloadBuffer.h" +#include "device.h" +#include "MappedResource.h" + +namespace +{ + struct BufferInfo + { + D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc; + D3D11_BUFFER_DESC bufferDesc; + CComPtr<ID3D11Buffer> source; + + HRESULT create( ID3D11ShaderResourceView* srv ) + { + srv->GetDesc( &viewDesc ); + if( viewDesc.ViewDimension != D3D_SRV_DIMENSION_BUFFER ) + return E_INVALIDARG; + + CComPtr<ID3D11Resource> res; + srv->GetResource( &res ); + CHECK( res.QueryInterface( &source ) ); + + source->GetDesc( &bufferDesc ); + return S_OK; + } + + HRESULT download( void* rdi ) + { + bufferDesc.BindFlags = 0; + bufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_READ; + bufferDesc.Usage = D3D11_USAGE_STAGING; + CComPtr<ID3D11Buffer> staging; + using namespace DirectCompute; + CHECK( device()->CreateBuffer( &bufferDesc, nullptr, &staging ) ); + + context()->CopyResource( staging, source ); + + MappedResource mapped; + mapped.map( staging, true ); + memcpy( rdi, mapped.data(), bufferDesc.ByteWidth ); + return S_OK; + } + }; + + size_t dxgiSizeof( DXGI_FORMAT fmt ) + { + switch( fmt ) + { + case DXGI_FORMAT_R16_FLOAT: return 2; + case DXGI_FORMAT_R32_FLOAT: return 4; + } + return 0; + } +} + +template<class E> +HRESULT DirectCompute::downloadBuffer( ID3D11ShaderResourceView* srv, std::vector<E>& vec ) +{ + BufferInfo bi; + CHECK( bi.create( srv ) ); + + const size_t cb = dxgiSizeof( bi.viewDesc.Format ); + if( cb != sizeof( E ) ) + return E_INVALIDARG; + + vec.resize( bi.bufferDesc.ByteWidth / cb ); + return bi.download( vec.data() ); +} + +template HRESULT DirectCompute::downloadBuffer( ID3D11ShaderResourceView* srv, std::vector<uint16_t>& vec ); +template HRESULT DirectCompute::downloadBuffer( ID3D11ShaderResourceView* srv, 
std::vector<float>& vec );
\ No newline at end of file diff --git a/Whisper/D3D/downloadBuffer.h b/Whisper/D3D/downloadBuffer.h new file mode 100644 index 0000000..0885007 --- /dev/null +++ b/Whisper/D3D/downloadBuffer.h @@ -0,0 +1,9 @@ +#pragma once + +namespace DirectCompute +{ + // Download a buffer from VRAM into std::vector + // The function is relatively expensive, creates a temporary staging buffer on each call, and only used to test things. + template<class E> + HRESULT downloadBuffer( ID3D11ShaderResourceView* srv, std::vector<E>& vec ); +}
\ No newline at end of file diff --git a/Whisper/D3D/enums.cpp b/Whisper/D3D/enums.cpp new file mode 100644 index 0000000..7c31648 --- /dev/null +++ b/Whisper/D3D/enums.cpp @@ -0,0 +1,9 @@ +#include "stdafx.h" +#include "enums.h" + +static const alignas( 16 ) std::array<DXGI_FORMAT, 3> s_tensorViewFormats = { DXGI_FORMAT_R16_FLOAT, DXGI_FORMAT_R32_FLOAT, DXGI_FORMAT_R32_UINT }; + +DXGI_FORMAT DirectCompute::viewFormat( eDataType dt ) +{ + return s_tensorViewFormats[ (uint8_t)dt ]; +}
\ No newline at end of file diff --git a/Whisper/D3D/enums.h b/Whisper/D3D/enums.h new file mode 100644 index 0000000..c5d5350 --- /dev/null +++ b/Whisper/D3D/enums.h @@ -0,0 +1,34 @@ +#pragma once +#include <stdint.h> +#include <assert.h> + +namespace DirectCompute +{ + enum struct eDataType : uint8_t + { + FP16, + FP32, + U32, + }; + + inline size_t elementSize( eDataType dt ) + { + assert( dt == eDataType::FP16 || dt == eDataType::FP32 || dt == eDataType::U32 ); + + return ( dt == eDataType::FP16 ) ? 2 : 4; + } + + DXGI_FORMAT viewFormat( eDataType dt ); + + enum struct eBufferUse : uint8_t + { + // Immutable tensor, readable from GPU + Immutable, + // Read+write tensor, readable and writable on GPU + ReadWrite, + // Read+write tensor, readable and writable on GPU, which supports downloads from GPU + ReadWriteDownload, + // The tensor is accessible by both GPU (read only) and CPU (write only). Optimized for resources frequently updated from CPU. + Dynamic, + }; +}
\ No newline at end of file diff --git a/Whisper/D3D/shaderNames.cpp b/Whisper/D3D/shaderNames.cpp new file mode 100644 index 0000000..b52f5db --- /dev/null +++ b/Whisper/D3D/shaderNames.cpp @@ -0,0 +1,53 @@ +// This source file is generated by a tool +#include "stdafx.h" +#include "shaderNames.h" + +static const std::array<const char*, 38> s_shaderNames = +{ + "add", + "addInPlace", + "addRepeat", + "addRepeatGelu", + "addRepeatScale", + "addRows", + "convolutionMain", + "convolutionMain2", + "convolutionMain2Fixed", + "convolutionPrep1", + "convolutionPrep2", + "copyConvert", + "copyTranspose", + "diagMaskInf", + "flashAttention", + "flashAttentionCompat1", + "flashAttentionCompat2", + "flashAttentionCompat3", + "fmaRepeat1", + "fmaRepeat2", + "matReshapePanels", + "mulMatByRow", + "mulMatByRowTiled", + "mulMatByRowTiledEx", + "mulMatByScalar", + "mulMatDotMain", + "mulMatDotReshape", + "mulMatMadMain", + "mulMatTiled", + "mulMatTiledEx", + "norm", + "normCompat", + "normFixed", + "scaleInPlace", + "softMax", + "softMaxCompat", + "softMaxFixed", + "zeroMemory", +}; + +const char* DirectCompute::computeShaderName( eComputeShader cs ) +{ + const uint16_t i = (uint16_t)cs; + if( i < s_shaderNames.size() ) + return s_shaderNames[ i ]; + return nullptr; +}
// Whisper/D3D/shaderNames.h
// This header is generated by a tool
#pragma once
#include <stdint.h>

namespace DirectCompute
{
	// Identifiers of the compute shaders.
	// The values are sequential and zero-based, and are used as array indices, e.g. by computeShaderName()
	enum struct eComputeShader: uint16_t
	{
		add = 0,
		addInPlace = 1,
		addRepeat = 2,
		addRepeatGelu = 3,
		addRepeatScale = 4,
		addRows = 5,
		convolutionMain = 6,
		convolutionMain2 = 7,
		convolutionMain2Fixed = 8,
		convolutionPrep1 = 9,
		convolutionPrep2 = 10,
		copyConvert = 11,
		copyTranspose = 12,
		diagMaskInf = 13,
		flashAttention = 14,
		flashAttentionCompat1 = 15,
		flashAttentionCompat2 = 16,
		flashAttentionCompat3 = 17,
		fmaRepeat1 = 18,
		fmaRepeat2 = 19,
		matReshapePanels = 20,
		mulMatByRow = 21,
		mulMatByRowTiled = 22,
		mulMatByRowTiledEx = 23,
		mulMatByScalar = 24,
		mulMatDotMain = 25,
		mulMatDotReshape = 26,
		mulMatMadMain = 27,
		mulMatTiled = 28,
		mulMatTiledEx = 29,
		norm = 30,
		normCompat = 31,
		normFixed = 32,
		scaleInPlace = 33,
		softMax = 34,
		softMaxCompat = 35,
		softMaxFixed = 36,
		zeroMemory = 37,
	};

	// Name of the shader for diagnostics; returns nullptr for values outside of the enum
	const char* computeShaderName( eComputeShader cs );
}
\ No newline at end of file diff --git a/Whisper/D3D/shaders.cpp b/Whisper/D3D/shaders.cpp new file mode 100644 index 0000000..f7d9ce4 --- /dev/null +++ b/Whisper/D3D/shaders.cpp @@ -0,0 +1,104 @@ +#include "stdafx.h" +#include "shaders.h" +#include "startup.h" +#include "device.h" +#include <compressapi.h> +#pragma comment( lib, "Cabinet.lib" ) + +namespace +{ +#ifdef _DEBUG +#include "shaderData-Debug.inl" +#else +#include "shaderData-Release.inl" +#endif + + constexpr DWORD compressionAlgorithm = COMPRESS_ALGORITHM_MSZIP; + + class Decompressor + { + DECOMPRESSOR_HANDLE handle = nullptr; + + public: + + HRESULT create() + { + if( CreateDecompressor( compressionAlgorithm, nullptr, &handle ) ) + return S_OK; + return HRESULT_FROM_WIN32( GetLastError() ); + } + + HRESULT decompress( const uint8_t* src, size_t compressedLength, void* dest, size_t origLength ) const + { + if( Decompress( handle, src, compressedLength, dest, origLength, nullptr ) ) + return S_OK; + return HRESULT_FROM_WIN32( GetLastError() ); + } + + ~Decompressor() + { + if( nullptr != handle ) + { + CloseDecompressor( handle ); + handle = nullptr; + } + } + }; + + static std::vector<CComPtr<ID3D11ComputeShader>> s_shaders; +} + +HRESULT DirectCompute::createComputeShaders() +{ + constexpr size_t countBinaries = s_shaderOffsets.size() - 1; + const size_t cbDecompressedLength = s_shaderOffsets[ countBinaries ]; + constexpr size_t countShaders = s_shaderBlobs32.size(); + + std::vector<uint8_t> dxbc; + try + { + s_shaders.resize( countShaders ); + dxbc.resize( cbDecompressedLength ); + } + catch( const std::bad_alloc& ) + { + return E_OUTOFMEMORY; + } + + Decompressor decomp; + CHECK( decomp.create() ); + + decomp.decompress( s_compressedShaders.data(), s_compressedShaders.size(), dxbc.data(), cbDecompressedLength ); + ID3D11Device* const dev = device(); + + const auto& blobs = gpuInfo.wave64() ? 
s_shaderBlobs64 : s_shaderBlobs32; + + for( size_t i = 0; i < countShaders; i++ ) + { + const size_t idxBinary = blobs[ i ]; + const uint32_t offThis = s_shaderOffsets[ idxBinary ]; + const uint8_t* rsi = &dxbc[ offThis ]; + const size_t len = s_shaderOffsets[ idxBinary + 1 ] - offThis; + const HRESULT hr = dev->CreateComputeShader( rsi, len, nullptr, &s_shaders[ i ] ); + if( SUCCEEDED( hr ) ) + continue; + + const uint64_t binaryBit = ( 1ull << idxBinary ); + if( 0 != ( binaryBit & fp64ShadersBitmap ) ) + continue; // This shader uses FP64 math, the support for that is optional. When not supported, CreateComputeShader method is expected to fail. + // TODO [low]: ideally, query for the support when creating the device, and don't even try creating these compute shaders + return hr; + } + + return S_OK; +} + +void DirectCompute::destroyComputeShaders() +{ + s_shaders.clear(); +} + +void DirectCompute::bindShader( eComputeShader shader ) +{ + context()->CSSetShader( s_shaders[ (uint16_t)shader ], nullptr, 0 ); +}
// Whisper/D3D/shaders.h
#pragma once
#include "shaderNames.h"

namespace DirectCompute
{
	// Bind the specified compute shader; implemented in shaders.cpp
	void bindShader( eComputeShader shader );
}
\ No newline at end of file diff --git a/Whisper/D3D/startup.cpp b/Whisper/D3D/startup.cpp new file mode 100644 index 0000000..2ff0b0c --- /dev/null +++ b/Whisper/D3D/startup.cpp @@ -0,0 +1,17 @@ +#include "stdafx.h" +#include "startup.h" +#include "device.h" + +HRESULT DirectCompute::d3dStartup() +{ + HRESULT hr = DirectCompute::initialize(); + if( SUCCEEDED( hr ) ) + hr = createComputeShaders(); + return hr; +} + +void DirectCompute::d3dShutdown() +{ + destroyComputeShaders(); + terminate(); +}
// Whisper/D3D/startup.h
#pragma once
// Declared here so this header does not need any Windows SDK headers
using HRESULT = long;

namespace DirectCompute
{
	// Initialize everything: calls initialize(), then createComputeShaders()
	HRESULT d3dStartup();
	// Destroy the compute shaders, then call terminate()
	void d3dShutdown();

	// Create all compute shaders; implemented in shaders.cpp
	HRESULT createComputeShaders();
	// Release all compute shaders; implemented in shaders.cpp
	void destroyComputeShaders();
}
\ No newline at end of file diff --git a/Whisper/DllMain.cpp b/Whisper/DllMain.cpp new file mode 100644 index 0000000..4746d98 --- /dev/null +++ b/Whisper/DllMain.cpp @@ -0,0 +1,27 @@ +#include "stdafx.h" + +BOOL __stdcall DllMain( HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved ) +{ + // Perform actions based on the reason for calling. + switch( fdwReason ) + { + case DLL_PROCESS_ATTACH: + // Initialize once for each new process. Return FALSE to fail DLL load. + DisableThreadLibraryCalls( (HMODULE)hinstDLL ); + break; + case DLL_THREAD_ATTACH: + // Do thread-specific initialization. + break; + case DLL_THREAD_DETACH: + // Do thread-specific cleanup. + break; + case DLL_PROCESS_DETACH: + if( lpvReserved != nullptr ) + { + break; // do not do cleanup if process termination scenario + } + // Perform any necessary cleanup + break; + } + return TRUE; // Successful DLL_PROCESS_ATTACH. +}
// Whisper/Hybrid/HybridContext.cpp
#include "stdafx.h"
#include <immintrin.h>
#include <optional>
#include "HybridContext.h"
#include "../Utils/Trace/tracing.h"

// The rest of this source file is only compiled when the macro is non-zero
#if BUILD_HYBRID_VERSION
namespace
{
	// Resolve the count of threads.
	// Release builds: 0 means the count of processors in the system, other values are clamped to be at least 1.
	// Debug builds always use 1 thread.
	int threadsCount( int t )
	{
#ifdef NDEBUG
		if( t == 0 )
		{
			SYSTEM_INFO si;
			GetSystemInfo( &si );
			return (int)si.dwNumberOfProcessors;
		}
		if( t <= 1 )
			return 1;
		return t;
#else
		return 1;
#endif
	}

	// One megabyte, in bytes. NOTE(review): not referenced elsewhere in this source file
	constexpr size_t MB = 1u << 20;
}

// Keeps references to the model, and to the decoder tensors in wm.hybridTensors; does not allocate the buffers, see create()
HybridContext::HybridContext( const Whisper::WhisperModel& wm ) :
	ml( threadsCount( 0 ) ),
	model( wm.hybridTensors ),
	whisperModel( wm )
{ }

namespace
{
	// Size of the model; the values are used as indices into s_memRequirements
	enum struct eModelType : uint8_t
	{
		Tiny = 0,
		Base = 1,
		Small = 2,
		Medium = 3,
		Large = 4,
	};

	// Detect the size of the model from the n_audio_layer parameter.
	// Logs an error and returns E_INVALIDARG when the value is not one of the 5 known ones.
	static HRESULT detectModelType( const Whisper::sModelParams& modelParams, eModelType& mt )
	{
		switch( modelParams.n_audio_layer )
		{
		case 4:
			mt = eModelType::Tiny;
			return S_OK;
		case 6:
			mt = eModelType::Base;
			return S_OK;
		case 12:
			mt = eModelType::Small;
			return S_OK;
		case 24:
			mt = eModelType::Medium;
			return S_OK;
		case 32:
			mt = eModelType::Large;
			return S_OK;
		}
		logError( u8"Unrecognized model" );
		return E_INVALIDARG;
	}

	// Memory requirements in megabytes, two bytes in total.
	// The first one is for the compute buffer, the second one for the per-layer compute buffer, see HybridContext::create()
	struct alignas( 2 ) RamMB
	{
		uint8_t dec, decLayer;
		constexpr RamMB( uint8_t d, uint8_t dl ) : dec( d ), decLayer( dl ) { }

		// Load both numbers into the two 64-bit lanes of a vector, converted from megabytes to bytes
		__m128i loadBytes() const
		{
			// Load both bytes of the structure with a single 16-bit load
			__m128i v = _mm_loadu_si16( this );
			// Upcast bytes to int64_t. That instruction can load directly from memory, too bad the VC++ optimizer doesn't care
			v = _mm_cvtepu8_epi64( v );
			// Scale from megabytes into bytes, the multiplier is obviously 2^20
			v = _mm_slli_epi64( v, 20 );
			return v;
		}
	};

	// The magic numbers are from MEM_REQ_DECODE and MEM_REQ_DECODE_LAYER red/black maps in the reference version,
	// near the top of whisper.cpp source file
	static const std::array<RamMB, 5> s_memRequirements =
	{
		RamMB{ 200, 32 },  // Tiny
		RamMB{ 202, 44 },  // Base
		RamMB{ 204, 64 },  // Small
		RamMB{ 206, 84 },  // Medium
		RamMB{ 208, 110 }, // Large
	};
}

// Allocate the memory used by the decoder: two compute arenas, staging buffers for the cross-attention tensors,
// and the RAM buffers for keys and values
HRESULT HybridContext::create()
{
	// Allocate buffers for compute
	// We know they're large, so bypassing the heap
	eModelType modelType;
	CHECK( detectModelType( whisperModel.parameters, modelType ) );

	// Lower lane is the size of allocCompute, higher lane is the size of allocComputeLayer, both in bytes
	const __m128i bytes = s_memRequirements.at( (uint8_t)modelType ).loadBytes();
	CHECK( allocCompute.create( _mm_cvtsi128_si64( bytes ) ) );
	CHECK( allocComputeLayer.create( _mm_extract_epi64( bytes, 1 ) ) );

	// Create staging buffers to download output from encoder stage,
	// in the reference version they're named memory_cross_k / memory_cross_v
	CHECK( kvCross.create( whisperModel.parameters ) );

	// Create RAM buffers for memory_k / memory_v
	CHECK( kv.create( whisperModel.parameters ) );

	return S_OK;
}

// Scoped helper: the constructor installs an arena allocator into the ML context,
// the destructor restores the previous allocator, and resets the arena
class HybridContext::SetAllocatorRaii
{
	HybridContext& context;
	CpuCompute::iMemoryAllocator* prevAlloc;
	CpuCompute::iArenaAllocator* newAlloc;
public:

	SetAllocatorRaii( HybridContext* owner, CpuCompute::iArenaAllocator& a ) :
		context( *owner )
	{
		prevAlloc = context.ml.setAllocator( &a );
		newAlloc = &a;
	}
	~SetAllocatorRaii()
	{
		// Restores the allocator saved by the constructor, regardless of any setAllocator() calls made in between
		context.ml.setAllocator( prevAlloc );
		newAlloc->resetArena();
	}
};

// Run the decoder on the CPU for the specified tokens, and write the result of the final softMax into the probs vector.
// The cross-attention keys and values are read from the staging buffers of kvCross.
HRESULT HybridContext::decode( const int* tokens, const int n_tokens, const int n_past, const sDecParams& dp, std::vector<float>& probs )
{
	CHECK( ml.setThreadsCount( dp.n_threads ) );

	// whisper_decode
	const auto& hparams = whisperModel.parameters;
	// NOTE(review): n_vocab is not used in this function
	const uint32_t n_vocab = hparams.n_vocab;

	const uint32_t n_ctx = hparams.n_text_ctx;
	const uint32_t n_state = hparams.n_text_state;
	const uint32_t n_head = hparams.n_text_head;
	const uint32_t n_layer = hparams.n_text_layer;

	const uint32_t N = n_tokens;
	const uint32_t M = dp.M;

	// Tensors created outside of the per-layer scopes are allocated from allocCompute
	SetAllocatorRaii ac{ this, allocCompute };
	using namespace CpuCompute;
	Tensor cur = ml.addRows( model.tokenEmbedding, model.positionalEmbedding, tokens, n_tokens, n_past );
	Tracing::tensor( "dec-rows", cur );

	Tensor inpL = cur;
	// Map the staging buffers for the duration of this function; the local variable hides the class field of the same name
	auto kvCross = this->kvCross.map();

	for( uint32_t il = 0; il < n_layer; il++ )
	{
		// Tensors are only traced for the first layer
		if( 0 == il ) Tracing::tensor( "dec-inpL", inpL );
		const auto& layer = model.layers[ il ];
		// The temporary tensors of this layer are allocated from allocComputeLayer, which is reset at the end of each iteration
		SetAllocatorRaii acLayer{ this, allocComputeLayer };

		// norm
		// This local variable hides the `cur` declared before the loop
		Tensor cur = ml.norm( inpL );
		ml.fmaRepeat( cur, layer.attnLn0 );
		if( 0 == il ) Tracing::tensor( "dec-norm", cur );

		// self-attention
		{
			Tensor Qcur = ml.mulMat( layer.attnQuery.w, cur );
			if( 0 == il ) Tracing::tensor( "dec-Qcur-0", Qcur );
			// ( n_state / n_head ) ^ -0.25, applied to both queries and keys
			const float scaling = (float)pow( float( (int)n_state ) / (int)n_head, -0.25 );
			ml.addRepeatScale( Qcur, layer.attnQuery.b, scaling );
			if( 0 == il ) Tracing::tensor( "dec-Qcur-1", Qcur );

			// note: no bias for Key
			Tensor Kcur = ml.mulMat( layer.attnKey, cur );
			ml.scale( Kcur, scaling );
			if( 0 == il ) Tracing::tensor( "dec-Kcur", Kcur );

			Tensor Vcur = ml.mulMat( layer.attnValue.w, cur );
			ml.addRepeat( Vcur, layer.attnValue.b );
			if( 0 == il ) Tracing::tensor( "dec-Vcur", Vcur );

			// store key and value to memory
			{
				// Each layer owns n_ctx * n_state elements of the buffers, the new data goes after the n_past older tokens
				const uint32_t len = N * n_state;
				const uint32_t off = n_state * ( (uint32_t)il * n_ctx + n_past );
				Tensor k = kv.keysView( len, off );
				Tensor v = kv.valuesView( len, off );

				CHECK( ml.copyImpl( k, Kcur ) );
				CHECK( ml.copyImpl( v, Vcur ) );
			}

			// ------
			Tensor Q = ml.permute( ml.copy( Qcur, eDataType::FP32, { n_state / n_head, n_head, N } ), 0, 2, 1, 3 );
			// The view includes both the n_past older tokens, and the N tokens stored above
			Tensor K = ml.permute( kv.keysView( ( n_past + N ) * n_state, (uint32_t)il * n_ctx * n_state )
				.reshape3d( n_state / n_head, n_head, n_past + N ),
				0, 2, 1, 3 );
			Tensor KQ = ml.mulMat( K, Q );
			if( 0 == il ) Tracing::tensor( "dec-KQ-0", KQ );
			ml.diagMaskInf( KQ, n_past );
			if( 0 == il ) Tracing::tensor( "dec-KQ-1", KQ );
			ml.softMax( KQ );
			if( 0 == il ) Tracing::tensor( "dec-KQ-2", KQ );

			Tensor V_trans = ml.permute(
				kv.valuesView( ( n_past + N ) * n_state, (uint32_t)il * n_ctx * n_state )
				.reshape3d( n_state / n_head, n_head, n_past + N ),
				1, 2, 0, 3 );

			Tensor KQV = ml.mulMat( V_trans, KQ );
			if( 0 == il ) Tracing::tensor( "dec-KQV", KQV );

			Tensor KQV_merged = ml.permute( KQV, 0, 2, 1, 3 );
			ml.copyInPlace( cur, KQV_merged, eDataType::FP32, { n_state, N } );
		}

		{
			cur = ml.mulMat( layer.attnLn1.w, cur );
			ml.addRepeat( cur, layer.attnLn1.b );
		}

		// add the input
		Tensor inpCA = ml.add( cur, inpL );

		// norm
		{
			cur = ml.norm( inpCA );
			ml.fmaRepeat( cur, layer.crossAttnLn0 );
		}

		// cross-attention
		{
			Tensor Qcur = ml.mulMat( layer.crossAttnQuery.w, cur );
			ml.addRepeatScale( Qcur, layer.crossAttnQuery.b, (float)pow( float( (int)n_state ) / (int)n_head, -0.25 ) );

			// Kcross is already scaled
			const uint32_t len = M * n_state;
			const uint32_t off = (uint32_t)il * len;
			Tensor Kcross = kvCross.keysView( len, off ).reshape3d( n_state / n_head, n_head, M );
			Tensor Vcross = kvCross.valuesView( len, off ).reshape3d( n_state / n_head, n_head, M );

			// ------
			Tensor Q = ml.permute( ml.copy( Qcur, eDataType::FP32, { n_state / n_head, n_head, N } ), 0, 2, 1, 3 );
			Tensor K = ml.permute( Kcross, 0, 2, 1, 3 );
			Tensor KQ = ml.mulMat( K, Q );
			ml.softMax( KQ );
			Tensor V_trans = ml.permute( Vcross, 1, 2, 0, 3 );
			Tensor KQV = ml.mulMat( V_trans, KQ );
			// NOTE(review): same trace name as the one used in the self-attention block above
			if( 0 == il ) Tracing::tensor( "dec-KQV", KQV );
			Tensor KQV_merged = ml.permute( KQV, 0, 2, 1, 3 );

			ml.copyInPlace( cur, KQV_merged, eDataType::FP32, { n_state, N } );
		}

		// projection
		{
			cur = ml.mulMat( layer.crossAttnLn1.w, cur );
			ml.addRepeat( cur, layer.crossAttnLn1.b );
		}
		// add the input
		ml.addInPlace( cur, inpCA );
		Tensor inpFF = cur;

		// feed-forward network
		{
			// norm
			cur = ml.norm( inpFF );
			ml.fmaRepeat( cur, layer.mlpLn );

			cur = ml.mulMat( layer.mlp0.w, cur );
			ml.addRepeatGelu( cur, layer.mlp0.b );

			// The mulMat() below creates a tensor for the output of this layer.
			// We have a special memory storage for these tensors, that's how they survive resets of per-layer arenas
			allocLayerOutput.resetArena();
			ml.setAllocator( &allocLayerOutput );

			// projection
			cur = ml.mulMat( layer.mlp1.w, cur );
			ml.addRepeat( cur, layer.mlp1.b );
		}

		// output from this layer
		ml.addInPlace( cur, inpFF );
		inpL = cur;
	}

	// norm
	cur = ml.norm( inpL );
	ml.fmaRepeat( cur, model.ln );

	cur = ml.mulMat( model.tokenEmbedding, cur );

	// logits -> probs
	ml.softMax( cur );

	// Copy all elements of the final tensor into the output vector
	const float* rsi = cur.fp32();
	probs.assign( rsi, rsi + cur.countElements() );
	Tracing::vector( "probs", probs );
	return S_OK;
}

// Allocate memory for the single tensor this arena supports.
// The buffer is retained across resets, and is only re-allocated when the requested size exceeds the capacity.
// Throws an HRESULT when the allocation fails, or when the arena already has a tensor.
// The align argument is not used.
void* HybridContext::AllocSingle::allocate( size_t cb, size_t align )
{
	if( !allocated )
	{
		allocated = true;
		if( cb <= capacity )
		{
			// The existing buffer is large enough
			CpuCompute::dbgMarkUninitializedMemory( buffer.pointer(), capacity );
			return buffer.pointer();
		}
		else
		{
			HRESULT hr = buffer.allocate( cb );
			if( SUCCEEDED( hr ) )
			{
				capacity = cb;
				CpuCompute::dbgMarkUninitializedMemory( buffer.pointer(), capacity );
				return buffer.pointer();
			}
			logErrorHr( hr, u8"HybridContext.AllocSingle.allocate" );
			throw hr;
		}
	}
	else
	{
		logError( u8"HybridContext.AllocSingle only supports 1 tensor" );
		throw E_UNEXPECTED;
	}
}

// Allow the next allocation; the buffer itself is kept for reuse
void HybridContext::AllocSingle::resetArena()
{
	allocated = false;
	if( capacity > 0 )
		CpuCompute::dbgMarkFreedMemory( buffer.pointer(), capacity );
}
#endif
// Whisper/Hybrid/HybridContext.h
#pragma once
#include "../Whisper/WhisperModel.h"
#include "../CPU/MlContext.h"
#include "../CPU/BufferAllocator.h"
#include "KeyValueDownloader.h"
#include "../CPU/KvTensors.h"

// This version of the hybrid context uses the new, custom-built kernels
class HybridContext
{
	CpuCompute::MlContext ml;
	// Arena allocators for the temporary tensors: one for the complete decode() call, another one reset after each layer
	CpuCompute::VirtualAllocator allocCompute, allocComputeLayer;

	// Arena allocator which only supports a single allocation between resets
	class AllocSingle : public CpuCompute::iArenaAllocator
	{
		CpuCompute::LargeBuffer buffer;
		// Size of the buffer in bytes, 0 when not yet allocated
		size_t capacity = 0;
		// True when the single supported tensor is currently allocated
		bool allocated = false;
		// Inherited via iArenaAllocator
		virtual void* allocate( size_t cb, size_t align ) override final;

	public:
		virtual void resetArena() override final;
	};
	// Storage for the output tensor of a decoder layer
	AllocSingle allocLayerOutput;

	// Both references are retained from the constructor argument, the model must outlive this object
	const CpuCompute::DecoderTensors& model;
	const Whisper::WhisperModel& whisperModel;
	// Staging buffers with the cross-attention keys and values, downloaded from VRAM
	KeyValueDownloader kvCross;
	// Keys and values of the self-attention, in system RAM
	CpuCompute::KvTensors kv;

	class SetAllocatorRaii;

public:

	HybridContext( const Whisper::WhisperModel& wm );

	// Allocate the buffers; call once before the rest of the methods
	HRESULT create();

	// Copy the cross-attention keys and values from the GPU buffers into the staging buffers
	HRESULT downloadKeyValues( const DirectCompute::KeyValueBuffers& source )
	{
		return kvCross.download( source );
	}

	// Extra parameters for the decode method
	struct sDecParams
	{
		// Passed to MlContext::setThreadsCount
		int n_threads;
		// Count of rows of the cross-attention keys and values consumed for each layer
		int M;
	};

	HRESULT decode( const int* tokens, const int n_tokens, const int n_past, const sDecParams& dp, std::vector<float>& probs_out );
};
\ No newline at end of file diff --git a/Whisper/Hybrid/KeyValueDownloader.cpp b/Whisper/Hybrid/KeyValueDownloader.cpp new file mode 100644 index 0000000..ad50136 --- /dev/null +++ b/Whisper/Hybrid/KeyValueDownloader.cpp @@ -0,0 +1,32 @@ +#include "stdafx.h" +#include "KeyValueDownloader.h" + +HRESULT KeyValueDownloader::create( const Whisper::sModelParams& mp ) +{ + const uint32_t n_audio_ctx = mp.n_audio_ctx; + const uint32_t n_mem = mp.n_text_layer * mp.n_audio_ctx; + const uint32_t n_elements = mp.n_text_state * n_mem; + + CD3D11_BUFFER_DESC desc{ n_elements * 2, 0, D3D11_USAGE_STAGING, D3D11_CPU_ACCESS_READ }; + ID3D11Device* dev = DirectCompute::device(); + CHECK( dev->CreateBuffer( &desc, nullptr, &keys ) ); + CHECK( dev->CreateBuffer( &desc, nullptr, &values ) ); + + length = n_elements; + return S_OK; +} + +HRESULT KeyValueDownloader::download( const DirectCompute::KeyValueBuffers& source ) +{ + ID3D11DeviceContext* ctx = DirectCompute::context(); + ctx->CopyResource( keys, source.keys.getBuffer() ); + ctx->CopyResource( values, source.values.getBuffer() ); + return S_OK; +} + +KeyValueDownloader::ReadMap::ReadMap( KeyValueDownloader& owner ) : + length( owner.length ) +{ + check( mappedKeys.map( owner.keys, true ) ); + check( mappedValues.map( owner.values, true ) ); +}
\ No newline at end of file diff --git a/Whisper/Hybrid/KeyValueDownloader.h b/Whisper/Hybrid/KeyValueDownloader.h new file mode 100644 index 0000000..e0e9644 --- /dev/null +++ b/Whisper/Hybrid/KeyValueDownloader.h @@ -0,0 +1,63 @@ +#pragma once +#include "../Whisper/sModelParams.h" +#include "../Whisper/KeyValueBuffers.h" +#include "../D3D/MappedResource.h" +#include "../CPU/Tensor.h" + +class KeyValueDownloader +{ + CComPtr<ID3D11Buffer> keys, values; + uint32_t length = 0; + + using E = uint16_t; + static constexpr DirectCompute::eDataType dataType = DirectCompute::eDataType::FP16; + +public: + // Create the staging resources to download kvCross tensors produced by the GPGPU encoder + HRESULT create( const Whisper::sModelParams& mp ); + + // Download these two tensors from VRAM to the staging buffers in system RAM + HRESULT download( const DirectCompute::KeyValueBuffers& source ); + + class ReadMap + { + const uint32_t length; + DirectCompute::MappedResource mappedKeys, mappedValues; + + public: + ReadMap( KeyValueDownloader& owner ); + ~ReadMap() = default; + ReadMap( const ReadMap& ) = delete; + + // A slice of model.memory_k tensor + CpuCompute::Tensor keysView( uint32_t len, uint32_t off ) const + { + if( len + off <= length ) + { + E* rsi = (E*)mappedKeys.data(); + rsi += off; + return CpuCompute::Tensor::fromData( rsi, dataType, len ); + } + throw E_BOUNDS; + } + + // A slice of model.memory_v tensor + CpuCompute::Tensor valuesView( uint32_t len, uint32_t off ) const + { + if( len + off <= length ) + { + E* rsi = (E*)mappedValues.data(); + rsi += off; + return CpuCompute::Tensor::fromData( rsi, dataType, len ); + } + throw E_BOUNDS; + } + }; + + // Map both staging buffers, return RAII object which unmaps when destroyed, + // which can supply the data in the shape of CpuCompute::Tensor vector + decltype( auto ) map() + { + return ReadMap( *this ); + } +};
\ No newline at end of file diff --git a/Whisper/Hybrid/Readme.txt b/Whisper/Hybrid/Readme.txt new file mode 100644 index 0000000..702813d --- /dev/null +++ b/Whisper/Hybrid/Readme.txt @@ -0,0 +1 @@ +The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change the BUILD_HYBRID_VERSION macro in stdafx.h.
\ No newline at end of file diff --git a/Whisper/MF/AudioBuffer.cpp b/Whisper/MF/AudioBuffer.cpp new file mode 100644 index 0000000..ba4752d --- /dev/null +++ b/Whisper/MF/AudioBuffer.cpp @@ -0,0 +1,93 @@ +#include "stdafx.h" +#include "AudioBuffer.h" +using namespace Whisper; + +void AudioBuffer::appendMono( const float* rsi, size_t countFloats ) +{ + mono.insert( mono.end(), rsi, rsi + countFloats ); +} + +void AudioBuffer::appendStereo( const float* rsi, size_t countFloats ) +{ + assert( 0 == ( countFloats % 2 ) ); + const size_t countSamples = countFloats / 2; + + const size_t oldLength = mono.size(); + assert( oldLength * 2 == stereo.size() ); + mono.resize( oldLength + countSamples ); + stereo.resize( ( oldLength + countSamples ) * 2 ); + + const float* const rsiEnd = rsi + countSamples * 2; + const float* const rsiEndAligned = rsiEnd - ( countSamples * 2 ) % 8; + + float* rdiStereo = &stereo[ oldLength * 2 ]; + float* rdiMono = &mono[ oldLength ]; + + const __m128 half = _mm_set1_ps( 0.5f ); + for( ; rsi < rsiEndAligned; rsi += 8, rdiStereo += 8, rdiMono += 4 ) + { + // Load 4 samples = 8 floats + __m128 v0 = _mm_loadu_ps( rsi ); // L0, R0, L1, R1 + __m128 v1 = _mm_loadu_ps( rsi + 4 );// L2, R2, L3, R3 + + // Store into the stereo PCM vector + _mm_storeu_ps( rdiStereo, v0 ); + _mm_storeu_ps( rdiStereo + 4, v1 ); + + // Compute and store the average of these channels + __m128 left = _mm_shuffle_ps( v0, v1, _MM_SHUFFLE( 2, 0, 2, 0 ) ); + __m128 right = _mm_shuffle_ps( v0, v1, _MM_SHUFFLE( 3, 1, 3, 1 ) ); + __m128 sum = _mm_add_ps( left, right ); + sum = _mm_mul_ps( sum, half ); + _mm_storeu_ps( rdiMono, sum ); + } + +#pragma loop (no_vector) + for( ; rsi < rsiEnd; rsi += 2, rdiStereo += 2, rdiMono++ ) + { + __m128 vec = _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) ); + _mm_store_sd( (double*)rdiStereo, _mm_castps_pd( vec ) ); + + vec = _mm_add_ss( vec, _mm_movehdup_ps( vec ) ); + vec = _mm_mul_ss( vec, half ); + _mm_store_ss( rdiMono, vec ); + } +} + +void 
AudioBuffer::appendDownmixedStereo( const float* rsi, size_t countFloats ) +{ + assert( 0 == ( countFloats % 2 ) ); + const size_t countSamples = countFloats / 2; + + const size_t oldLength = mono.size(); + mono.resize( oldLength + countSamples ); + + const float* const rsiEnd = rsi + countSamples * 2; + const float* const rsiEndAligned = rsiEnd - ( countSamples * 2 ) % 8; + + float* rdiMono = &mono[ oldLength ]; + + const __m128 half = _mm_set1_ps( 0.5f ); + for( ; rsi < rsiEndAligned; rsi += 8, rdiMono += 4 ) + { + // Load 4 samples = 8 floats + __m128 v0 = _mm_loadu_ps( rsi ); // L0, R0, L1, R1 + __m128 v1 = _mm_loadu_ps( rsi + 4 );// L2, R2, L3, R3 + + // Compute and store the average of these channels + __m128 left = _mm_shuffle_ps( v0, v1, _MM_SHUFFLE( 2, 0, 2, 0 ) ); + __m128 right = _mm_shuffle_ps( v0, v1, _MM_SHUFFLE( 3, 1, 3, 1 ) ); + __m128 sum = _mm_add_ps( left, right ); + sum = _mm_mul_ps( sum, half ); + _mm_storeu_ps( rdiMono, sum ); + } + +#pragma loop (no_vector) + for( ; rsi < rsiEnd; rsi += 2, rdiMono++ ) + { + __m128 vec = _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) ); + vec = _mm_add_ss( vec, _mm_movehdup_ps( vec ) ); + vec = _mm_mul_ss( vec, half ); + _mm_store_ss( rdiMono, vec ); + } +}
// Whisper/MF/AudioBuffer.h
#pragma once
#include <assert.h>
#include <stddef.h>
#include <vector>

namespace Whisper
{
	// PCM audio in memory: mono samples, and optionally interleaved stereo samples for the same time interval
	struct AudioBuffer
	{
		// One float per sample
		std::vector<float> mono;
		// Two floats per sample; empty when the stereo data is not being collected
		std::vector<float> stereo;

		// Append mono samples to the mono vector
		void appendMono( const float* rsi, size_t countFloats );
		// Append the average of the two channels of interleaved stereo samples to the mono vector
		void appendDownmixedStereo( const float* rsi, size_t countFloats );
		// Append interleaved stereo samples to the stereo vector, and their average to the mono vector
		void appendStereo( const float* rsi, size_t countFloats );

		// Pointer to one of the three append methods above
		using pfnAppendSamples = void( AudioBuffer::* )( const float* rsi, size_t countFloats );

		// Select the append method for the format of the source, and for the requested output
		inline static pfnAppendSamples appendSamplesFunc( bool sourceMono, bool wantStereo )
		{
			if( sourceMono )
				return &AudioBuffer::appendMono;
			else if( !wantStereo )
				return &AudioBuffer::appendDownmixedStereo;
			else
				return &AudioBuffer::appendStereo;
		}

		// Remove all samples from both vectors
		void clear()
		{
			mono.clear();
			stereo.clear();
		}

		// Truncate the buffer to the specified count of samples.
		// The stereo vector is only resized when it has data.
		void resize( size_t len )
		{
			assert( len <= mono.size() );
			mono.resize( len );
			if( !stereo.empty() )
				stereo.resize( len * 2 );
		}
	};
}
\ No newline at end of file diff --git a/Whisper/MF/AudioCapture.cpp b/Whisper/MF/AudioCapture.cpp new file mode 100644 index 0000000..17f34dc --- /dev/null +++ b/Whisper/MF/AudioCapture.cpp @@ -0,0 +1,167 @@ +#include "stdafx.h" +#include <atlstr.h> +#include <mfapi.h> +#include <mfidl.h> +#include <mfreadwrite.h> +#include "AudioCapture.h" +#include "../API/iMediaFoundation.cl.h" +#include "../ComLightLib/comLightServer.h" +#pragma comment(lib, "Mf.lib") + +namespace +{ + struct Strings + { + CString displayName, endpoint; + }; + + HRESULT getAllocString( IMFActivate* activate, const GUID& id, CString& rdi ) + { + wchar_t* pointer = nullptr; + UINT32 cchName; + HRESULT hr = activate->GetAllocatedString( id, &pointer, &cchName ); + if( SUCCEEDED( hr ) ) + rdi.SetString( pointer, cchName ); + CoTaskMemFree( pointer ); + return hr; + } + + HRESULT getInfo( IMFActivate* activate, Strings& rdi ) + { + CHECK( getAllocString( activate, MF_DEVSOURCE_ATTRIBUTE_FRIENDLY_NAME, rdi.displayName ) ); + CHECK( getAllocString( activate, MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE_AUDCAP_ENDPOINT_ID, rdi.endpoint ) ); + return S_OK; + } + + HRESULT __stdcall supplyDevices( Whisper::pfnFoundCaptureDevices pfn, void* pv, IMFActivate** ppDevices, UINT32 count ) + { + if( ppDevices == nullptr || count == 0 ) + return pfn( 0, nullptr, pv ); + + std::vector<Strings> strings; + strings.reserve( count ); + + for( UINT i = 0; i < count; i++ ) + { + IMFActivate* const activate = ppDevices[ i ]; + if( nullptr == activate ) + continue; + Strings info; + HRESULT hr = getInfo( activate, info ); + if( FAILED( hr ) ) + continue; + + strings.emplace_back( std::move( info ) ); + } + + const size_t len = strings.size(); + if( 0 == len ) + return pfn( 0, nullptr, pv ); + + std::vector<Whisper::sCaptureDevice> pointers; + pointers.resize( len ); + for( size_t i = 0; i < len; i++ ) + { + const auto& src = strings[ i ]; + auto& dest = pointers[ i ]; + dest.displayName = src.displayName; + dest.endpoint = 
src.endpoint; + } + return pfn( (int)len, pointers.data(), pv ); + } +} + +HRESULT __stdcall Whisper::captureDeviceList( pfnFoundCaptureDevices pfn, void* pv ) +{ + // Create an attribute store to hold the search criteria. + CComPtr<IMFAttributes> attrs; + CHECK( MFCreateAttributes( &attrs, 1 ) ); + // Request audio capture devices + CHECK( attrs->SetGUID( MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE, MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE_AUDCAP_GUID ) ); + + // Enumerate the devices + IMFActivate** ppDevices = nullptr; + UINT32 count = 0; + CHECK( MFEnumDeviceSources( attrs, &ppDevices, &count ) ); + + // Feed the data to the caller + HRESULT hr = supplyDevices( pfn, pv, ppDevices, count ); + + // Free the memory + for( DWORD i = 0; i < count; i++ ) + ppDevices[ i ]->Release(); + CoTaskMemFree( ppDevices ); + + return hr; +} + +namespace +{ + using namespace Whisper; + + class Capture : public ComLight::ObjectRoot<iAudioCapture> + { + CComPtr<IMFSourceReader> reader; + CComPtr<iMediaFoundation> mediaFoundation; + sCaptureParams captureParams; + + HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const noexcept override final + { + if( pp == nullptr ) + return E_POINTER; + CComPtr<IMFSourceReader> res = reader; + *pp = res.Detach();; + return S_OK; + } + const sCaptureParams& COMLIGHTCALL getParams() const noexcept override final + { + return captureParams; + } + public: + HRESULT open( iMediaFoundation* owner, const wchar_t* endpoint, const sCaptureParams& cp ); + }; + + HRESULT Capture::open( iMediaFoundation* owner, const wchar_t* endpoint, const sCaptureParams& cp ) + { + // Create an attribute store to hold the search criteria. 
+ CComPtr<IMFAttributes> attrs; + CHECK( MFCreateAttributes( &attrs, 2 ) ); + // Request audio capture devices + CHECK( attrs->SetGUID( MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE, MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE_AUDCAP_GUID ) ); + CHECK( attrs->SetString( MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE_AUDCAP_ENDPOINT_ID, endpoint ) ); + + CComPtr<IMFMediaSource> source; + HRESULT hr = MFCreateDeviceSource( attrs, &source ); + if( FAILED( hr ) ) + { + logErrorHr( hr, u8"MFCreateDeviceSource" ); + return hr; + } + + // TODO: implement IMFSourceReaderCallback, pass into MF_SOURCE_READER_ASYNC_CALLBACK attribute + // This is to support cancellation + hr = MFCreateSourceReaderFromMediaSource( source, nullptr, &reader ); + if( FAILED( hr ) ) + { + logErrorHr( hr, u8"MFCreateSourceReaderFromMediaSource" ); + return hr; + } + + captureParams = cp; + mediaFoundation = owner; + return S_OK; + } +} + +HRESULT __stdcall Whisper::captureOpen( iMediaFoundation* owner, const wchar_t* endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) noexcept +{ + if( nullptr == endpoint || nullptr == pp ) + return E_POINTER; + + ComLight::CComPtr<ComLight::Object<Capture>> res; + CHECK( ComLight::Object<Capture>::create( res ) ); + CHECK( res->open( owner, endpoint, captureParams ) ); + + res.detach( pp ); + return S_OK; +}
// Whisper/MF/AudioCapture.h
#pragma once
#include "../API/MfStructs.h"

namespace Whisper
{
	struct iAudioCapture;
	struct iMediaFoundation;

	// Enumerate audio capture devices, and pass the list to the callback; pv is forwarded to the callback as is
	HRESULT __stdcall captureDeviceList( pfnFoundCaptureDevices pfn, void* pv );

	// Open the audio capture device with the specified endpoint ID, and return the new object in the last argument
	HRESULT __stdcall captureOpen( iMediaFoundation* owner, const wchar_t* endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) noexcept;
}
\ No newline at end of file diff --git a/Whisper/MF/MediaFoundation.cpp b/Whisper/MF/MediaFoundation.cpp new file mode 100644 index 0000000..4a4f6a2 --- /dev/null +++ b/Whisper/MF/MediaFoundation.cpp @@ -0,0 +1,109 @@ +#include "stdafx.h" +#include "../API/iMediaFoundation.cl.h" +#include "mfStartup.h" +#include "../ComLightLib/comLightServer.h" +#include "loadAudioFile.h" +#include <mfidl.h> +#include <mfreadwrite.h> +#include "mfUtils.h" +#include "AudioCapture.h" + +namespace Whisper +{ + class AudioReader : public ComLight::ObjectRoot<iAudioReader> + { + CComPtr<IMFSourceReader> reader; + bool wantStereo; + CComPtr<iMediaFoundation> mediaFoundation; + + HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const noexcept override final + { + if( pp == nullptr ) + return E_POINTER; + CComPtr<IMFSourceReader> res = reader; + *pp = res.Detach();; + return S_OK; + } + HRESULT COMLIGHTCALL requestedStereo() const noexcept override final + { + return wantStereo ? S_OK : S_FALSE; + } + HRESULT COMLIGHTCALL getDuration( int64_t& rdi ) const noexcept override final + { + if( reader ) + return getStreamDuration( reader, rdi ); + return OLE_E_BLANK; + } + public: + HRESULT open( iMediaFoundation* owner, LPCTSTR path, bool stereo ) + { + HRESULT hr = MFCreateSourceReaderFromURL( path, nullptr, &reader ); + if( FAILED( hr ) ) + { + logErrorHr( hr, u8"MFCreateSourceReaderFromURL failed" ); + return hr; + } + wantStereo = stereo; + mediaFoundation = owner; + logDebug16( L"Created source reader from the file \"%s\"", path ); + return S_OK; + } + }; + + class MediaFoundation : public ComLight::ObjectRoot<iMediaFoundation> + { + MfStartupRaii raii; + DWORD tid = ~(DWORD)0; + + virtual HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const noexcept override final + { + return Whisper::loadAudioFile( path, stereo, pp ); + } + virtual HRESULT COMLIGHTCALL openAudioFile( LPCTSTR path, bool stereo, iAudioReader** pp ) noexcept override final + { + 
if( nullptr == path || nullptr == pp ) + return E_POINTER; + + ComLight::CComPtr<ComLight::Object<AudioReader>> res; + CHECK( ComLight::Object<AudioReader>::create( res ) ); + CHECK( res->open( this, path, stereo ) ); + + res.detach( pp ); + return S_OK; + } + HRESULT COMLIGHTCALL listCaptureDevices( pfnFoundCaptureDevices pfn, void* pv ) noexcept override final + { + return captureDeviceList( pfn, pv ); + } + HRESULT COMLIGHTCALL openCaptureDevice( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) noexcept override final + { + return captureOpen( this, endpoint, captureParams, pp ); + } + protected: + + HRESULT FinalConstruct() + { + CHECK( raii.startup() ); + tid = GetCurrentThreadId(); + return S_OK; + } + + public: + + ~MediaFoundation() override + { + assert( tid == GetCurrentThreadId() ); + } + }; +} + +HRESULT COMLIGHTCALL Whisper::initMediaFoundation( iMediaFoundation** pp ) +{ + if( nullptr == pp ) + return E_POINTER; + + ComLight::CComPtr<ComLight::Object<MediaFoundation>> obj; + CHECK( ComLight::Object<MediaFoundation>::create( obj ) ); + obj.detach( pp ); + return S_OK; +}
\ No newline at end of file diff --git a/Whisper/MF/PcmReader.cpp b/Whisper/MF/PcmReader.cpp new file mode 100644 index 0000000..ab92fc3 --- /dev/null +++ b/Whisper/MF/PcmReader.cpp @@ -0,0 +1,274 @@ +#include "stdafx.h" +#include "PcmReader.h" +#include <mfapi.h> +#include "mfUtils.h" + +namespace Whisper +{ + __interface iSampleHandler + { + void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, PcmStereoChunk* pStereo ) const; + void moveBufferData( AudioBuffer& rdi, size_t amount ) const; + void appendPcm( AudioBuffer& rdi, const float* rsi, size_t countFloats ) const; + void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, size_t samples, PcmStereoChunk* pStereo ) const; + uint32_t readerChannelsCount() const; + }; +} + +namespace +{ + using namespace Whisper; + + __forceinline void copyMono( PcmMonoChunk* rdi, const AudioBuffer& rsi, size_t sourceOffset, size_t samples ) + { + assert( sourceOffset + samples <= rsi.mono.size() ); + memcpy( rdi->mono.data(), &rsi.mono[ sourceOffset ], samples * 4 ); + if( samples < FFT_STEP ) + memset( rdi->mono.data() + samples, 0, ( FFT_STEP - samples ) * 4 ); + } + + __forceinline void copyStereo( PcmStereoChunk* rdi, const AudioBuffer& rsi, size_t sourceOffset, size_t samples ) + { + memcpy( rdi->stereo.data(), &rsi.stereo[ sourceOffset * 2 ], samples * 8 ); + if( samples < FFT_STEP ) + memset( rdi->stereo.data() + samples * 2, 0, ( FFT_STEP - samples ) * 8 ); + } + + struct HandlerMono : iSampleHandler + { + void appendPcm( AudioBuffer& rdi, const float* rsi, size_t countFloats ) const override + { + rdi.appendMono( rsi, countFloats ); + } + void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, PcmStereoChunk* pStereo ) const override final + { + copyMono( pMono, rsi, sourceOffset, FFT_STEP ); + } + void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, size_t samples, PcmStereoChunk* pStereo ) const override final 
+ { + copyMono( pMono, rsi, sourceOffset, samples ); + } + void moveBufferData( AudioBuffer& rdi, size_t amount ) const override final + { + const size_t len = rdi.mono.size(); + assert( amount <= len ); + if( amount < len ) + { + const size_t block = len - amount; + memmove( rdi.mono.data(), rdi.mono.data() + amount, block * 4 ); + rdi.mono.resize( block ); + } + else + rdi.mono.clear(); + } + uint32_t readerChannelsCount() const override { return 1; } + }; + struct HandlerDownmixedStereo : HandlerMono + { + void appendPcm( AudioBuffer& rdi, const float* rsi, size_t countFloats ) const override final + { + rdi.appendDownmixedStereo( rsi, countFloats ); + } + uint32_t readerChannelsCount() const override final { return 2; } + }; + struct HandlerStereo : iSampleHandler + { + void appendPcm( AudioBuffer& rdi, const float* rsi, size_t countFloats ) const override final + { + rdi.appendStereo( rsi, countFloats ); + } + void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, PcmStereoChunk* pStereo ) const override final + { + copyMono( pMono, rsi, sourceOffset, FFT_STEP ); + copyStereo( pStereo, rsi, sourceOffset, FFT_STEP ); + } + void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, size_t samples, PcmStereoChunk* pStereo ) const override final + { + copyMono( pMono, rsi, sourceOffset, samples ); + copyStereo( pStereo, rsi, sourceOffset, samples ); + } + void moveBufferData( AudioBuffer& rdi, size_t amount ) const override final + { + const size_t len = rdi.mono.size(); + assert( amount <= len ); + if( amount < len ) + { + const size_t block = len - amount; + memmove( rdi.mono.data(), rdi.mono.data() + amount, block * 4 ); + rdi.mono.resize( block ); + memmove( rdi.stereo.data(), rdi.stereo.data() + amount * 2, block * 8 ); + rdi.mono.resize( block * 2 ); + } + else + { + rdi.mono.clear(); + rdi.stereo.clear(); + } + } + uint32_t readerChannelsCount() const override final { return 2; } + }; + static const 
HandlerMono s_mono; + static const HandlerDownmixedStereo s_downmix; + static const HandlerStereo s_stereo; +} + +PcmReader::PcmReader( IMFSourceReader* reader, bool stereo ) +{ + if( nullptr == reader ) + throw E_POINTER; + this->reader = reader; + + // Set up media type, and figure out sample handler + check( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) ); + check( reader->SetStreamSelection( MF_SOURCE_READER_FIRST_AUDIO_STREAM, TRUE ) ); + + CComPtr<IMFMediaType> mtNative; + check( reader->GetNativeMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, MF_SOURCE_READER_CURRENT_TYPE_INDEX, &mtNative ) ); + UINT32 numChannels; + check( mtNative->GetUINT32( MF_MT_AUDIO_NUM_CHANNELS, &numChannels ) ); + + const bool sourceMono = numChannels < 2; + if( sourceMono ) + sampleHandler = &s_mono; + else if( !stereo ) + sampleHandler = &s_downmix; + else + { + sampleHandler = &s_stereo; + m_stereoOutput = true; + } + + CComPtr<IMFMediaType> mt; + check( createMediaType( !sourceMono, &mt ) ); + check( reader->SetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, nullptr, mt ) ); + + // Find out the length + int64_t durationTicks; + check( getStreamDuration( reader, durationTicks ) ); + + // Convert length to chunks + // Seconds = Ticks / 10^7 + // Samples = Seconds * SAMPLE_RATE = Ticks * SAMPLE_RATE / 10^7 + // Chunks = Samples / FFT_STEP = Ticks * SAMPLE_RATE / ( FFT_STEP * 10^7 ), and we want that integer rounded down + constexpr __int64 mul = SAMPLE_RATE; + constexpr __int64 div = (__int64)FFT_STEP * 10'000'000; + m_length = (size_t)MFllMulDiv( durationTicks, mul, div, 0 ); +} + +HRESULT PcmReader::readNextSample() +{ + const size_t off = bufferReadOffset; + const size_t availableSamples = pcm.mono.size() - off; + + // If needed, move the remaining PCM data to the start of these vectors + if( availableSamples > 0 ) + { + if( 0 != off ) + sampleHandler->moveBufferData( pcm, off ); + } + else + pcm.clear(); + bufferReadOffset = 0; + + while( true ) + { + 
DWORD dwFlags = 0; + CComPtr<IMFSample> sample; + + // Read the next sample + HRESULT hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample ); + if( FAILED( hr ) ) + { + logErrorHr( hr, u8"IMFSourceReader.ReadSample" ); + return hr; + } + + if( dwFlags & MF_SOURCE_READERF_CURRENTMEDIATYPECHANGED ) + { + // logError( u8"Media type changes ain’t supported by the library." ); + // return E_UNEXPECTED; + + // This happens for some video files at the very start of the reading, with Dolby AC3 audio track. + // Instead of failing the transcribe process, verify the important attributes (FP32 samples, sample rate, count of channels) haven’t changed. + CHECK( validateCurrentMediaType( reader, sampleHandler->readerChannelsCount() ) ); + } + + if( dwFlags & MF_SOURCE_READERF_ENDOFSTREAM ) + return E_EOF; + + if( !sample ) + { + // printf( "No sample\n" ); + continue; + } + + // Get a pointer to the audio data in the sample. + CComPtr<IMFMediaBuffer> buffer; + hr = sample->ConvertToContiguousBuffer( &buffer ); + if( FAILED( hr ) ) + return hr; + + const float* pAudioData = nullptr; + DWORD cbBuffer; + hr = buffer->Lock( (BYTE**)&pAudioData, nullptr, &cbBuffer ); + if( FAILED( hr ) ) + return hr; + + try + { + assert( 0 == ( cbBuffer % sizeof( float ) ) ); + const size_t countFloats = cbBuffer / sizeof( float ); + sampleHandler->appendPcm( pcm, pAudioData, countFloats ); + } + catch( const std::bad_alloc& ) + { + buffer->Unlock(); + return E_OUTOFMEMORY; + } + + // Unlock the buffer + hr = buffer->Unlock(); + if( FAILED( hr ) ) + return hr; + + return S_OK; + } +} + +HRESULT PcmReader::readChunk( PcmMonoChunk& mono, PcmStereoChunk* stereo ) +{ + while( true ) + { + const size_t off = bufferReadOffset; + const size_t availableSamples = pcm.mono.size() - off; + if( availableSamples >= FFT_STEP ) + { + // We have enough data in the buffer + sampleHandler->copyChunk( &mono, pcm, off, stereo ); + bufferReadOffset = off + FFT_STEP; 
+ return S_OK; + } + + if( !m_readerEndOfFile ) + { + // We don't have enough data, but the stream has not ended yet, can load moar samples from the reader + HRESULT hr = readNextSample(); + if( SUCCEEDED( hr ) ) + continue; + if( hr != E_EOF ) + return hr; + m_readerEndOfFile = true; + } + + if( availableSamples > 0 ) + { + // We have reached the end of stream of the reader, but the buffer still has a few samples. + // Return the final incomplete chunk padded with zeros + sampleHandler->copyChunk( &mono, pcm, off, availableSamples, stereo ); + bufferReadOffset = off + availableSamples; + return S_OK; + } + + return E_EOF; + } +}
\ No newline at end of file diff --git a/Whisper/MF/PcmReader.h b/Whisper/MF/PcmReader.h new file mode 100644 index 0000000..9e3757e --- /dev/null +++ b/Whisper/MF/PcmReader.h @@ -0,0 +1,63 @@ +#pragma once +#include "../Whisper/audioConstants.h" +#include <mfidl.h> +#include <mfreadwrite.h> +#include "AudioBuffer.h" + +namespace Whisper +{ + // PCM buffer with 10 milliseconds of single-channel audio + struct PcmMonoChunk + { + std::array<float, FFT_STEP> mono; + }; + // PCM buffer with 10 milliseconds of interleaved stereo + struct PcmStereoChunk + { + std::array<float, FFT_STEP * 2> stereo; + }; + + __interface iSampleHandler; + + constexpr HRESULT E_EOF = HRESULT_FROM_WIN32( ERROR_HANDLE_EOF ); + + // Utility class which reads chunks of FFT_STEP FP32 PCM samples from the MF source reader + // The class always delivers mono chunks, and can optionally deliver stereo in a separate buffer. + class PcmReader + { + // A small intermediate buffer with PCM data for complete media foundation samples + AudioBuffer pcm; + // Index of the first unconsumed sample in the pcm buffer + size_t bufferReadOffset = 0; + // Utility object to abstract away mono versus stereo shenanigans + const iSampleHandler* sampleHandler; + // The underlying MF source reader which delivers audio data + CComPtr<IMFSourceReader> reader; + // True after we consumed all available media samples from the reader + bool m_readerEndOfFile = false; + // True if this object delivers stereo samples + bool m_stereoOutput = false; + // The count of chunks we expect to get from the reader + size_t m_length = 0; + // Read next sample from the reader, store in the PCM buffer in this class + HRESULT readNextSample(); + + public: + + PcmReader( IMFSourceReader* source, bool stereo ); + + // Count of chunks in the MEL spectrogram. + // The PCM audio is generally slightly longer than that, due to the incomplete last chunk. 
+ size_t getLength() const noexcept + { + return m_length; + } + + // True when the stereo flag passed to constructor, and the audio stream actually has 2 or more audio channels + bool outputsStereo() const { return m_stereoOutput; } + + // Load another 10ms chunk from the stream + // For the last chunk in the stream, the output buffers are padded with zeros + HRESULT readChunk( PcmMonoChunk& mono, PcmStereoChunk* stereo ); + }; +}
\ No newline at end of file diff --git a/Whisper/MF/loadAudioFile.cpp b/Whisper/MF/loadAudioFile.cpp new file mode 100644 index 0000000..d1a439a --- /dev/null +++ b/Whisper/MF/loadAudioFile.cpp @@ -0,0 +1,151 @@ +#include "stdafx.h" +#include "../ComLightLib/comLightServer.h" +#include "loadAudioFile.h" +#include "mfUtils.h" +#include "AudioBuffer.h" +#include <mfidl.h> +#include <mfreadwrite.h> +#include <mfapi.h> +#pragma comment(lib, "Mfreadwrite.lib") +#pragma comment(lib, "mfuuid.lib") + +namespace Whisper +{ + class MediaFileBuffer : public ComLight::ObjectRoot<iAudioBuffer> + { + AudioBuffer pcm; + uint32_t channels = 0; + + uint32_t COMLIGHTCALL countSamples() const noexcept override final + { + return (uint32_t)( pcm.mono.size() ); + } + const float* COMLIGHTCALL getPcmMono() const noexcept override final + { + if( !pcm.mono.empty() ) + return pcm.mono.data(); + return nullptr; + } + const float* COMLIGHTCALL getPcmStereo() const noexcept override final + { + if( !pcm.stereo.empty() ) + return pcm.stereo.data(); + return nullptr; + } + HRESULT COMLIGHTCALL getTime( int64_t& rdi ) const noexcept override final + { + rdi = 0; + return S_OK; + } + public: + HRESULT load( LPCTSTR path, bool stereo ); + }; + + HRESULT MediaFileBuffer::load( LPCTSTR path, bool stereo ) + { + CComPtr<IMFSourceReader> reader; + HRESULT hr = MFCreateSourceReaderFromURL( path, nullptr, &reader ); + if( FAILED( hr ) ) + { + logErrorHr( hr, u8"MFCreateSourceReaderFromURL failed" ); + return hr; + } + + CHECK( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) ); + CHECK( reader->SetStreamSelection( MF_SOURCE_READER_FIRST_AUDIO_STREAM, TRUE ) ); + + CComPtr<IMFMediaType> mtNative; + CHECK( reader->GetNativeMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, MF_SOURCE_READER_CURRENT_TYPE_INDEX, &mtNative ) ); + UINT32 numChannels; + CHECK( mtNative->GetUINT32( MF_MT_AUDIO_NUM_CHANNELS, &numChannels ) ); + const bool sourceMono = numChannels == 1; + const 
AudioBuffer::pfnAppendSamples pfn = AudioBuffer::appendSamplesFunc( sourceMono, stereo ); + channels = ( stereo && !sourceMono ) ? 2 : 1; + + CComPtr<IMFMediaType> mt; + CHECK( createMediaType( !sourceMono, &mt ) ); + + CHECK( reader->SetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, nullptr, mt ) ); + + while( true ) + { + DWORD dwFlags = 0; + CComPtr<IMFSample> sample; + + // Read the next sample. + hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample ); + if( FAILED( hr ) ) + { + logErrorHr( hr, u8"IMFSourceReader.ReadSample" ); + return hr; + } + + if( dwFlags & MF_SOURCE_READERF_CURRENTMEDIATYPECHANGED ) + { + logError( u8"Media type changes ain’t supported by the library." ); + return E_UNEXPECTED; + } + + if( dwFlags & MF_SOURCE_READERF_ENDOFSTREAM ) + break; + + if( !sample ) + { + // printf( "No sample\n" ); + continue; + } + + // Get a pointer to the audio data in the sample. + CComPtr<IMFMediaBuffer> buffer; + hr = sample->ConvertToContiguousBuffer( &buffer ); + if( FAILED( hr ) ) + return hr; + + const float* pAudioData = nullptr; + DWORD cbBuffer; + hr = buffer->Lock( (BYTE**)&pAudioData, nullptr, &cbBuffer ); + if( FAILED( hr ) ) + return hr; + + try + { + const size_t countFloats = cbBuffer / sizeof( float ); + ( pcm.*pfn )( pAudioData, countFloats ); + } + catch( const std::bad_alloc& ) + { + return E_OUTOFMEMORY; + } + + // Unlock the buffer + hr = buffer->Unlock(); + if( FAILED( hr ) ) + return hr; + } + + const size_t len = pcm.mono.size(); + if( len == 0 ) + { + logError16( L"The audio file \"%s\" has no samples", path ); + return E_INVALIDARG; + } + if( len < SAMPLE_RATE / 2 ) + logError16( L"The file \"%s\" only has %zu samples, less than 0.5 seconds of audio", path, len ); + else + logDebug16( L"Loaded audio file from \"%s\": %zu samples, %g seconds", path, len, (int)len * ( 1.0 / SAMPLE_RATE ) ); + return S_OK; + + } +} + +HRESULT COMLIGHTCALL Whisper::loadAudioFile( LPCTSTR 
path, bool stereo, iAudioBuffer** pp ) +{ + if( nullptr == path || nullptr == pp ) + return E_POINTER; + + ComLight::CComPtr<ComLight::Object<MediaFileBuffer>> obj; + CHECK( ComLight::Object<MediaFileBuffer>::create( obj ) ); + CHECK( obj->load( path, stereo ) ); + obj.detach( pp ); + return S_OK; +}
\ No newline at end of file diff --git a/Whisper/MF/loadAudioFile.h b/Whisper/MF/loadAudioFile.h new file mode 100644 index 0000000..9736ccd --- /dev/null +++ b/Whisper/MF/loadAudioFile.h @@ -0,0 +1,7 @@ +#pragma once +#include "../API/iMediaFoundation.cl.h" + +namespace Whisper +{ + HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ); +}
\ No newline at end of file diff --git a/Whisper/MF/mfStartup.cpp b/Whisper/MF/mfStartup.cpp new file mode 100644 index 0000000..b7ab829 --- /dev/null +++ b/Whisper/MF/mfStartup.cpp @@ -0,0 +1,128 @@ +#include "stdafx.h" +#include "mfStartup.h" +#include <atlbase.h> +#include <mfapi.h> +#pragma comment(lib, "Mfplat.lib") + +namespace +{ + struct sCoInitStatus + { + // Possible state: + // -1 is the initial state, coInitialize never called + // S_OK - CoInitializeEx succeeded, in this state the counter tracks the count of coInitialize() for the current thread + // S_FALSE - CoInitializeEx failed with RPC_E_CHANGED_MODE, or did nothing because already initialized for the current thread + // Error status - CoInitializeEx failed for some other reason + HRESULT code = -1; + uint32_t counter = 0; + }; + thread_local sCoInitStatus coInitStatus; + + static HRESULT coInitialize() + { + sCoInitStatus& cis = coInitStatus; + HRESULT hr = cis.code; + if( SUCCEEDED( hr ) ) + { + if( S_OK == hr ) + cis.counter++; + return S_FALSE; + } + + if( hr == HRESULT( -1 ) ) + { + hr = CoInitializeEx( nullptr, COINIT_MULTITHREADED ); + if( S_OK == hr ) + { + cis.counter = 1; + return cis.code = S_OK; + } + if( S_FALSE == hr || RPC_E_CHANGED_MODE == hr ) + { + return cis.code = S_FALSE; + } + cis.code = hr; + return hr; + } + + return hr; + } + + static void coUninitialize() + { + sCoInitStatus& cis = coInitStatus; + if( cis.code == S_OK ) + { + assert( cis.counter > 0 ); + cis.counter--; + if( 0 == cis.counter ) + CoUninitialize(); + } + } + + static CComAutoCriticalSection s_lock; +#define LOCK() CComCritSecLock<CComAutoCriticalSection> lock{ s_lock } + static uint32_t mfStartupCounter = 0; + + constexpr uint8_t FlagCOM = 1; + constexpr uint8_t FlagMF = 0x10; +} + +using namespace Whisper; + +MfStartupRaii::~MfStartupRaii() +{ + if( 0 != ( successFlags & FlagMF ) ) + { + LOCK(); + assert( mfStartupCounter > 0 ); + mfStartupCounter--; + if( mfStartupCounter > 0 ) + return; + MFShutdown(); + 
successFlags &= ~FlagMF; + } + + if( 0 != ( successFlags & FlagCOM ) ) + { + coUninitialize(); + successFlags &= ~FlagCOM; + } +} + +HRESULT MfStartupRaii::startup() +{ + if( 0 != ( successFlags & FlagMF ) ) + return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); + + HRESULT hr = coInitialize(); + CHECK( hr ); + if( hr == S_OK ) + successFlags |= FlagCOM; + + LOCK(); + + if( 0 == mfStartupCounter ) + { + HRESULT hr = MFStartup( MF_VERSION, MFSTARTUP_LITE ); + if( SUCCEEDED( hr ) ) + { + mfStartupCounter = 1; + successFlags |= FlagMF; + return S_OK; + } + + if( 0 != ( successFlags & FlagCOM ) ) + { + coUninitialize(); + successFlags &= ~FlagCOM; + } + return hr; + } + else + { + mfStartupCounter++; + successFlags |= FlagMF; + return S_FALSE; + } +}
\ No newline at end of file diff --git a/Whisper/MF/mfStartup.h b/Whisper/MF/mfStartup.h new file mode 100644 index 0000000..1434ffc --- /dev/null +++ b/Whisper/MF/mfStartup.h @@ -0,0 +1,15 @@ +#pragma once + +namespace Whisper +{ + class MfStartupRaii + { + uint8_t successFlags = 0; + public: + MfStartupRaii() = default; + ~MfStartupRaii(); + MfStartupRaii( const MfStartupRaii& ) = delete; + + HRESULT startup(); + }; +}
\ No newline at end of file diff --git a/Whisper/MF/mfUtils.cpp b/Whisper/MF/mfUtils.cpp new file mode 100644 index 0000000..e739079 --- /dev/null +++ b/Whisper/MF/mfUtils.cpp @@ -0,0 +1,69 @@ +#include "stdafx.h" +#include "mfUtils.h" +#include <mfapi.h> + +HRESULT Whisper::createMediaType( bool stereo, IMFMediaType** pp ) +{ + if( nullptr == pp ) + return E_POINTER; + + CComPtr<IMFMediaType> mt; + CHECK( MFCreateMediaType( &mt ) ); + CHECK( mt->SetGUID( MF_MT_MAJOR_TYPE, MFMediaType_Audio ) ); + CHECK( mt->SetGUID( MF_MT_SUBTYPE, MFAudioFormat_Float ) ); + CHECK( mt->SetUINT32( MF_MT_AUDIO_SAMPLES_PER_SECOND, SAMPLE_RATE ) ); + + const uint32_t channels = stereo ? 2 : 1; + CHECK( mt->SetUINT32( MF_MT_AUDIO_NUM_CHANNELS, channels ) ); + CHECK( mt->SetUINT32( MF_MT_AUDIO_BLOCK_ALIGNMENT, channels * 4 ) ); + CHECK( mt->SetUINT32( MF_MT_AUDIO_AVG_BYTES_PER_SECOND, channels * 4 * SAMPLE_RATE ) ); + CHECK( mt->SetUINT32( MF_MT_AUDIO_BITS_PER_SAMPLE, 32 ) ); + CHECK( mt->SetUINT32( MF_MT_ALL_SAMPLES_INDEPENDENT, TRUE ) ); + + *pp = mt.Detach(); + + return S_OK; +} + +HRESULT Whisper::getStreamDuration( IMFSourceReader* reader, int64_t& duration ) +{ + PROPVARIANT var; + PropVariantInit( &var ); + CHECK( reader->GetPresentationAttribute( MF_SOURCE_READER_MEDIASOURCE, MF_PD_DURATION, &var ) ); + + if( var.vt == VT_UI8 ) + { + // The documentation says the type of that attribute is UINT64 + // https://learn.microsoft.com/en-us/windows/win32/medfound/mf-pd-duration-attribute + duration = var.uhVal.QuadPart; + return S_OK; + } + logError( u8"Unexpected type of MF_PD_DURATION attribute" ); + return E_INVALIDARG; +} + +HRESULT Whisper::validateCurrentMediaType( IMFSourceReader* reader, uint32_t expectedChannels ) +{ + CComPtr<IMFMediaType> mt; + CHECK( reader->GetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, &mt ) ); + + GUID guid; + CHECK( mt->GetGUID( MF_MT_MAJOR_TYPE, &guid ) ); + if( guid != MFMediaType_Audio ) + return E_FAIL; + + CHECK( mt->GetGUID( 
MF_MT_SUBTYPE, &guid ) ); + if( guid != MFAudioFormat_Float ) + return E_FAIL; + + UINT32 u32; + CHECK( mt->GetUINT32( MF_MT_AUDIO_SAMPLES_PER_SECOND, &u32 ) ); + if( u32 != SAMPLE_RATE ) + return E_FAIL; + + CHECK( mt->GetUINT32( MF_MT_AUDIO_NUM_CHANNELS, &u32 ) ); + if( u32 != expectedChannels ) + return E_FAIL; + + return S_OK; +}
\ No newline at end of file diff --git a/Whisper/MF/mfUtils.h b/Whisper/MF/mfUtils.h new file mode 100644 index 0000000..c889a92 --- /dev/null +++ b/Whisper/MF/mfUtils.h @@ -0,0 +1,15 @@ +#pragma once +#include <stdint.h> +#include <mfidl.h> +#include <mfobjects.h> +#include <mfreadwrite.h> +#include "../Whisper/audioConstants.h" + +namespace Whisper +{ + HRESULT createMediaType( bool stereo, IMFMediaType** pp ); + + HRESULT getStreamDuration( IMFSourceReader* reader, int64_t& duration ); + + HRESULT validateCurrentMediaType( IMFSourceReader* reader, uint32_t expectedChannels ); +}
\ No newline at end of file diff --git a/Whisper/ML/ConstantBuffer.cpp b/Whisper/ML/ConstantBuffer.cpp new file mode 100644 index 0000000..5f3bfbe --- /dev/null +++ b/Whisper/ML/ConstantBuffer.cpp @@ -0,0 +1,63 @@ +#include "stdafx.h" +#include "ConstantBuffer.h" +#include "../D3D/MappedResource.h" +using namespace DirectCompute; + +HRESULT ConstantBuffer::create() +{ + if( nullptr == buffer ) + { + CD3D11_BUFFER_DESC desc{ 16 * 3 * 2, D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE }; + return device()->CreateBuffer( &desc, nullptr, &buffer ); + } + return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); +} + +namespace +{ + __forceinline void copy32( __m128i* rdi, const TensorShape& ts ) + { + _mm_storeu_si128( rdi, ts.sizeVec() ); + _mm_storeu_si128( rdi + 1, ts.stridesVec() ); + } +} + +HRESULT ConstantBuffer::update( const TensorShape& t0 ) +{ + MappedResource mapped; + CHECK( mapped.map( buffer, false ) ); + + __m128i* const rdi = ( __m128i* )mapped.data(); + copy32( rdi, t0 ); + return S_OK; +} + +HRESULT ConstantBuffer::update( const TensorShape& t0, const TensorShape& t1 ) +{ + MappedResource mapped; + CHECK( mapped.map( buffer, false ) ); + + __m128i* const rdi = ( __m128i* )mapped.data(); + copy32( rdi, t0 ); + copy32( rdi + 2, t1 ); + return S_OK; +} + +HRESULT ConstantBuffer::update( const TensorShape& t0, const TensorShape& t1, const TensorShape& t2 ) +{ + MappedResource mapped; + CHECK( mapped.map( buffer, false ) ); + + __m128i* const rdi = ( __m128i* )mapped.data(); + copy32( rdi, t0 ); + copy32( rdi + 2, t1 ); + copy32( rdi + 4, t2 ); + return S_OK; +} + +void ConstantBuffer::bind() const +{ + ID3D11Buffer* p = buffer; + assert( nullptr != p ); + context()->CSSetConstantBuffers( 0, 1, &p ); +}
\ No newline at end of file diff --git a/Whisper/ML/ConstantBuffer.h b/Whisper/ML/ConstantBuffer.h new file mode 100644 index 0000000..3e6664d --- /dev/null +++ b/Whisper/ML/ConstantBuffer.h @@ -0,0 +1,25 @@ +#pragma once +#include "../D3D/device.h" +#include "TensorShape.h" + +namespace DirectCompute +{ + // 96 bytes dynamic constant buffers, with dimensions and VRAM layout of 2-3 tensors + class ConstantBuffer + { + CComPtr<ID3D11Buffer> buffer; + + public: + HRESULT create(); + HRESULT update( const TensorShape& t0 ); + HRESULT update( const TensorShape& t0, const TensorShape& t1 ); + HRESULT update( const TensorShape& t0, const TensorShape& t1, const TensorShape& t2 ); + + void bind() const; + + __m128i getMemoryUse() const + { + return bufferMemoryUsage( buffer ); + } + }; +}
\ No newline at end of file diff --git a/Whisper/ML/Context.ops.cpp b/Whisper/ML/Context.ops.cpp new file mode 100644 index 0000000..7dfca9f --- /dev/null +++ b/Whisper/ML/Context.ops.cpp @@ -0,0 +1,280 @@ +#include "stdafx.h" +#include "MlContext.h" +#include "testUtils.h" +using namespace DirectCompute; + +Tensor MlContext::createTensor( eDataType type, const std::array<uint32_t, 4>& ne ) +{ + Tensor res; + check( res.create( type, ne ) ); + return res; +} + +Tensor MlContext::createTensor( eDataType type, std::initializer_list<uint32_t> ne ) +{ + size_t nDims = ne.size(); + if( 0 == nDims || nDims > 4 ) + throw E_INVALIDARG; + std::array<uint32_t, 4> arr; + for( size_t i = 0; i < nDims; i++ ) + arr[ i ] = ne.begin()[ i ]; + for( size_t i = nDims; i < 4; i++ ) + arr[ i ] = 1; + return createTensor( type, arr ); +} + +Tensor MlContext::conv_1d_1s( const Tensor& a, const Tensor& b ) +{ + assert( b.isMatrix() ); + assert( a.ne[ 1 ] == b.ne[ 1 ] ); + assert( a.ne[ 3 ] == 1 ); + + Tensor res = createTensor( eDataType::FP32, { b.ne[ 0 ], a.ne[ 2 ] } ); + + convolution( a, b, res ); + return res; +} + +Tensor MlContext::conv_1d_2s( const Tensor& a, const Tensor& b ) +{ + assert( b.isMatrix() ); + assert( a.ne[ 1 ] == b.ne[ 1 ] ); + assert( a.ne[ 3 ] == 1 ); + + Tensor res = createTensor( eDataType::FP32, { b.ne[ 0 ] / 2, a.ne[ 2 ] } ); +#if 0 + static PrintUniqueTensorSizes printSize( "conv_1d_2s" ); + printSize.print( a, b ); +#endif + convolution2( a, b, res ); + return res; +} + +namespace +{ + inline bool canRepeat( const TensorShape& t0, const TensorShape& t1 ) + { + return ( t1.ne[ 0 ] % t0.ne[ 0 ] == 0 ) && + ( t1.ne[ 1 ] % t0.ne[ 1 ] == 0 ) && + ( t1.ne[ 2 ] % t0.ne[ 2 ] == 0 ) && + ( t1.ne[ 3 ] % t0.ne[ 3 ] == 0 ); + } +} + +Tensor MlContext::cwiseBinary( const Tensor& a, const Tensor& b, eComputeShader cs ) +{ + assert( isSameShape( a, b ) ); + Tensor res = createTensor( a.getType(), a.ne ); + cwiseBinary( a, b, res, cs ); + return res; +} + +Tensor 
__declspec( noinline ) MlContext::view2d( const Tensor& a, uint32_t ne0, uint32_t ne1, uint32_t nb1, uint32_t offset ) +{ + if( 0 != offset ) + throw E_NOTIMPL; + + Tensor res = a; + res.ne = { ne0, ne1, 1, 1 }; + + res.nb[ 1 ] = nb1; + res.nb[ 2 ] = res.nb[ 3 ] = nb1 * ne1; + return res; +} + +Tensor MlContext::transpose( const Tensor& a ) +{ + Tensor result = a; + std::swap( result.ne[ 0 ], result.ne[ 1 ] ); + std::swap( result.nb[ 0 ], result.nb[ 1 ] ); + return result; +} + +Tensor MlContext::norm( const Tensor& a ) +{ + Tensor res = createTensor( a.getType(), a.ne ); + norm( a, res ); + return res; +} + +Tensor MlContext::mulMat( const Tensor& a, const Tensor& b ) +{ + if( !canMulMat( a, b ) ) + throw E_INVALIDARG; + Tensor res = createTensor( eDataType::FP32, { a.ne[ 1 ], b.ne[ 1 ], a.ne[ 2 ], b.ne[ 3 ] } ); + if constexpr( enableInexactOptimizations ) + mulMatTiled( a, b, res ); + else + mulMat( a, b, res ); +#if 0 + Tensor testTiled; + check( testTiled.create( eDataType::FP32, res.ne ) ); + mulMatTiled( a, b, testTiled ); + + std::vector<float> current, tiled; + res.download( current ); + testTiled.download( tiled ); + sTensorDiff diff = computeDiff( current.data(), tiled.data(), current.size() ); + diff.print( "mulMatTiled" ); +#endif + return res; +} + +Tensor MlContext::mulMatEx( const Tensor& a, const Tensor& b, const char* tagName ) +{ + if( !canMulMat( a, b ) ) + throw E_INVALIDARG; + if( 0 != a.nb[ 0 ] ) + throw E_INVALIDARG; // The first argument is expected to be pre-transposed + + const uint16_t tag = profiler.setNextTag( tagName ); + + if( b.ne[ 1 ] != 1 ) + { + if( b.nb[ 0 ] != 0 ) + { + Tensor rhs = reshapePanels( b ); + profiler.setNextTag( tag ); + return mulMatTiledEx( a, rhs ); + } + else + { + // Second argument already reshaped into these panels + return mulMatTiledEx( a, b ); + } + } + else + { + if( 0 != b.nb[ 0 ] ) + return mulMatByRowTiledEx( a, b ); + + // That shader requires classic VRAM layout of the second argument, gonna fail 
with pre-transposed one + throw E_INVALIDARG; + } +} + +Tensor MlContext::permute( const Tensor& a, uint8_t axis0, uint8_t axis1, uint8_t axis2, uint8_t axis3 ) +{ + assert( axis0 < 4 ); + assert( axis1 < 4 ); + assert( axis2 < 4 ); + assert( axis3 < 4 ); + + assert( axis0 != axis1 ); + assert( axis0 != axis2 ); + assert( axis0 != axis3 ); + assert( axis1 != axis2 ); + assert( axis1 != axis3 ); + assert( axis2 != axis3 ); + + Tensor res = a; + res.ne[ axis0 ] = a.ne[ 0 ]; + res.ne[ axis1 ] = a.ne[ 1 ]; + res.ne[ axis2 ] = a.ne[ 2 ]; + res.ne[ axis3 ] = a.ne[ 3 ]; + + res.nb[ axis0 ] = a.nb[ 0 ]; + res.nb[ axis1 ] = a.nb[ 1 ]; + res.nb[ axis2 ] = a.nb[ 2 ]; + res.nb[ axis3 ] = a.nb[ 3 ]; + return res; +} + +Tensor MlContext::flashAttention( const Tensor& q, const Tensor& k, const Tensor& v, bool masked ) +{ + if( !canMulMat( k, q ) ) + throw E_INVALIDARG; + + if constexpr( enableInexactOptimizations ) + { + if( !masked ) + { + profiler.setNextTag( "flashAttn.1" ); + Tensor tmp = mulMat( k, q ); + + const float tempScale = (float)( 1.0 / sqrt( (double)(int)q.ne[ 0 ] ) ); + softMax( tmp, tempScale ); + + profiler.setNextTag( "flashAttn.2" ); + return mulMat( v, tmp ); + } + } + + Tensor res = createTensor( eDataType::FP32, q.ne ); + flashAttention( q, k, v, res, masked ); + +#if 0 + Tensor tmpMat = mulMat( k, q ); + float scale = (float)( 1.0 / sqrt( (double)(int)q.ne[ 0 ] ) ); + softMax( tmpMat, scale ); + Tensor testRes = mulMat( v, tmpMat ); + computeDiff( res, testRes ).print( "flashAttention mulmat" ); +#endif + + return res; +} + +Tensor MlContext::copy( const Tensor& a, eDataType type, std::initializer_list<uint32_t> size ) +{ + const size_t dims = size.size(); + if( 0 == dims || dims > 4 ) + throw E_BOUNDS; + + size_t nRequested = 1; + for( size_t i = 0; i < dims; i++ ) + { + uint32_t n = size.begin()[ i ]; + nRequested *= n; + } + if( nRequested != a.countElements() ) + throw E_INVALIDARG; + + const eDataType st = a.getType(); + Tensor res; + if( 
a.isContinuous() && st == type ) + { + // Same type, and it's dense - no need to call any compute shaders, equal to reshape + res = a; + for( size_t i = 0; i < dims; i++ ) + res.ne[ i ] = size.begin()[ i ];; + for( size_t i = dims; i < 4; i++ ) + res.ne[ i ] = 1; + res.setDenseStrides(); + } + else + { + // Either converting non-continuous to continuous, or converting types + res = createTensor( type, size ); + copyImpl( a, res, st == eDataType::FP32 && type == eDataType::FP16 ); + } + return res; +} + +void MlContext::copyInPlace( Tensor& dest, const Tensor& a, eDataType type, std::initializer_list<uint32_t> size ) +{ + assert( type == dest.getType() ); + + const size_t dims = size.size(); + if( 0 == dims || dims > 4 ) + throw E_BOUNDS; + + size_t nRequested = 1; + for( size_t i = 0; i < dims; i++ ) + { + uint32_t n = size.begin()[ i ]; + nRequested *= n; + } + if( nRequested != a.countElements() || nRequested != dest.countElements() ) + throw E_INVALIDARG; + + // Reshape the destination + for( size_t i = 0; i < dims; i++ ) + dest.ne[ i ] = size.begin()[ i ]; + for( size_t i = dims; i < 4; i++ ) + dest.ne[ i ] = 1; + dest.setDenseStrides(); + + // Call the shader + const eDataType st = a.getType(); + copyImpl( a, dest, st == eDataType::FP32 && type == eDataType::FP16 ); +}
\ No newline at end of file diff --git a/Whisper/ML/LookupTables.cpp b/Whisper/ML/LookupTables.cpp new file mode 100644 index 0000000..2fc1cc8 --- /dev/null +++ b/Whisper/ML/LookupTables.cpp @@ -0,0 +1,54 @@ +#include "stdafx.h" +#include "LookupTables.h" +#include "LookupTablesData.h" +#include <memory> +using namespace DirectCompute; + +namespace +{ + HRESULT uploadLookupTable( const std::array<uint16_t, 0x10000>& rsi, CComPtr<ID3D11ShaderResourceView>& rdi ) + { + rdi = nullptr; + CComPtr<ID3D11Buffer> buffer; + + CD3D11_BUFFER_DESC desc{ 0x10000 * 2, D3D11_BIND_SHADER_RESOURCE, D3D11_USAGE_IMMUTABLE }; + D3D11_SUBRESOURCE_DATA srd{ rsi.data(), 0, 0 }; + CHECK( device()->CreateBuffer( &desc, &srd, &buffer ) ); + + CD3D11_SHADER_RESOURCE_VIEW_DESC viewDesc{ D3D11_SRV_DIMENSION_BUFFER, DXGI_FORMAT_R16_UINT, 0, 0x10000 }; + CHECK( device()->CreateShaderResourceView( buffer, &viewDesc, &rdi ) ); + + return S_OK; + } +} + +HRESULT LookupTables::create() +{ + std::unique_ptr<LookupTablesData> data; + try + { + data = std::make_unique<LookupTablesData>(); + } + catch( const std::bad_alloc& ) + { + return E_OUTOFMEMORY; + } + + CHECK( uploadLookupTable( data->gelu, m_gelu ) ); + CHECK( uploadLookupTable( data->exponent, m_exponent ) ); + + return S_OK; +} + +void LookupTables::clear() +{ + m_gelu = nullptr; + m_exponent = nullptr; +} + +__m128i LookupTables::getMemoryUsage() const +{ + __m128i v = resourceMemoryUsage( m_gelu ); + v = _mm_add_epi64( v, resourceMemoryUsage( m_exponent ) ); + return v; +}
\ No newline at end of file diff --git a/Whisper/ML/LookupTables.h b/Whisper/ML/LookupTables.h new file mode 100644 index 0000000..306b548 --- /dev/null +++ b/Whisper/ML/LookupTables.h @@ -0,0 +1,22 @@ +#pragma once +#include "../D3D/device.h" + +namespace DirectCompute +{ + class LookupTables + { + CComPtr<ID3D11ShaderResourceView> m_gelu, m_exponent; + + public: + + HRESULT create(); + void clear(); + ID3D11ShaderResourceView* gelu() const { return m_gelu; } + ID3D11ShaderResourceView* exponent() const { return m_exponent; } + + __m128i getMemoryUsage() const; + }; + + // Singleton instance, defined in mlStartup.cpp + extern const LookupTables& lookupTables; +}
\ No newline at end of file diff --git a/Whisper/ML/LookupTablesData.cpp b/Whisper/ML/LookupTablesData.cpp new file mode 100644 index 0000000..263bcf7 --- /dev/null +++ b/Whisper/ML/LookupTablesData.cpp @@ -0,0 +1,40 @@ +#include "stdafx.h" +#include "LookupTablesData.h" +#include <immintrin.h> +using namespace DirectCompute; + +namespace +{ + inline float fp32( uint16_t f16 ) + { + __m128i i = _mm_cvtsi32_si128( f16 ); + __m128 f = _mm_cvtph_ps( i ); + return _mm_cvtss_f32( f ); + } + + inline uint16_t fp16( float fp32 ) + { + __m128 f = _mm_set_ss( fp32 ); + __m128i i = _mm_cvtps_ph( f, 0 ); + uint32_t res = (uint32_t)_mm_cvtsi128_si32( i ); + return (uint16_t)res; + } + + constexpr double GELU_COEF_A = 0.044715; + constexpr double SQRT_2_OVER_PI = 0.79788456080286535587989211986876; + + inline float computeGelu( float x ) + { + return (float)( 0.5 * x * ( 1.0 + tanh( SQRT_2_OVER_PI * x * ( 1.0 + GELU_COEF_A * x * x ) ) ) ); + } +} + +LookupTablesData::LookupTablesData() +{ + for( int i = 0; i < 0x10000; i++ ) + { + const float f = fp32( i ); + gelu[ i ] = fp16( computeGelu( f ) ); + exponent[ i ] = fp16( (float)exp( f ) ); + } +}
// File: Whisper/ML/LookupTablesData.h (new file in this commit)
#pragma once
#include <stdint.h>
#include <array>

namespace DirectCompute
{
	// CPU-side storage for two lookup tables of 0x10000 16-bit entries each.
	// The constructor is declared here and defined in another source file.
	struct LookupTablesData
	{
		// 16-bit entries of the `gelu` table
		std::array<uint16_t, 0x10000> gelu;
		// 16-bit entries of the `exponent` table
		std::array<uint16_t, 0x10000> exponent;

		LookupTablesData();
	};
}
\ No newline at end of file diff --git a/Whisper/ML/MlContext.cpp b/Whisper/ML/MlContext.cpp new file mode 100644 index 0000000..5a29b85 --- /dev/null +++ b/Whisper/ML/MlContext.cpp @@ -0,0 +1,744 @@ +#include "stdafx.h" +#include "MlContext.h" +#include "../D3D/shaderNames.h" +#include "LookupTables.h" +#include "../D3D/shaders.h" +#include "../D3D/Binder.h" +#include "../D3D/MappedResource.h" +#include "../D3D/downloadBuffer.h" +#include "testUtils.h" +#include "reshapedMultiply.h" +using namespace DirectCompute; + +// TODO: change this to a field, and set to false when the GPU doesn't support FP64 math +// Most notably, Intel has dropped the support recently: +// https://www.intel.com/content/www/us/en/developer/articles/guide/lp-api-developer-optimization-guide.html#inpage-nav-3-8-undefined +// "To improve power and performance", LOL +constexpr bool usePreciseComputeShaders = true; + +MlContext::MlContext( Whisper::ProfileCollection& profileColl ) : + profiler( profileColl ) +{ + check( cb.create() ); + check( profiler.create() ); +} + +void MlContext::bindShader( eComputeShader cs ) +{ + DirectCompute::bindShader( cs ); + profiler.computeShader( cs ); +} + +void MlContext::mulMatDot( const Tensor& src0, const Tensor& src1, Tensor& res ) +{ + const auto& size1 = src1.ne; + if( 1 != size1[ 3 ] ) + throw E_UNEXPECTED; + + const size_t tempLength = size1[ 0 ] * size1[ 1 ] * size1[ 2 ] * size1[ 3 ]; + const TensorGpuViews& tempBuffer = temp.fp16( tempLength ); + cb.bind(); + + bindShader( eComputeShader::mulMatDotReshape ); + cb.update( src1 ); + Binder bind; + bind.bind( src1, tempBuffer ); + context()->Dispatch( size1[ 1 ], size1[ 2 ], 1 ); + + bindShader( eComputeShader::mulMatDotMain ); + cb.update( src0, src1, res ); + bind.bind( src0, tempBuffer, res ); + + const auto& size0 = src0.ne; + // total rows in src0 + const uint32_t nr = size0[ 1 ] * size0[ 2 ] * size0[ 3 ]; + context()->Dispatch( size1[ 1 ], nr, 1 ); +} + +void MlContext::mulMatMad( const Tensor& 
a, const Tensor& b, Tensor& res ) +{ + // CaptureRaii renderDoc; + const uint32_t resultElts = res.countElements(); + constexpr uint32_t nth = 4; + + uint32_t fp16; + TensorGpuViews tempBuffer; + + const eDataType dataType = a.getType(); + if( dataType == eDataType::FP16 ) + { + fp16 = TRUE; + tempBuffer = temp.fp16( resultElts * nth ); + } + else if( dataType == eDataType::FP32 ) + { + fp16 = FALSE; + tempBuffer = temp.fp32( resultElts * nth ); + } + else + throw E_INVALIDARG; + + TensorShape resultShape = res; + resultShape.nb = { fp16, resultElts, 0, 0 }; + + cb.update( a, b, resultShape ); + bindShader( eComputeShader::mulMatMadMain ); + cb.bind(); + + Binder bind; + bind.bind( { a, b }, { res, tempBuffer } ); + context()->Dispatch( b.ne[ 1 ], b.ne[ 2 ], b.ne[ 3 ] ); +} + +void MlContext::mulMatTiled( const Tensor& a, const Tensor& b, Tensor& res ) +{ + cb.update( a, b, res ); + cb.bind(); + + Binder bind; + bind.bind( a, b, res ); + + if( b.ne[ 1 ] == 1 ) + { + if( b.ne[ 0 ] != 1 ) + { +#if 0 + static PrintUniqueTensorSizes printSize( "mulMatByRow" ); + printSize.print( a, b ); +#endif + // Tensor B is a single row, we have optimized compute shaders for that use case + // Even 2 of them, tiled and sequential. Select between these two shaders. 
+ constexpr uint32_t minHeightToTile = 2; + if( a.ne[ 1 ] < minHeightToTile ) + { + bindShader( eComputeShader::mulMatByRow ); + context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] ); + } + else + { + bindShader( eComputeShader::mulMatByRowTiled ); + uint32_t groupsX; + if( gpuInfo.wave64() ) + { + constexpr uint32_t TILE_Y = 128; + groupsX = ( a.ne[ 1 ] + TILE_Y - 1 ) / TILE_Y; + } + else + { + constexpr uint32_t TILE_Y = 64; + groupsX = ( a.ne[ 1 ] + TILE_Y - 1 ) / TILE_Y; + } + context()->Dispatch( groupsX, a.ne[ 2 ], a.ne[ 3 ] ); + } + } + else + { + // Tensor B is a single element: we have an optimized shader for that as well + bindShader( eComputeShader::mulMatByScalar ); + context()->Dispatch( a.ne[ 2 ], a.ne[ 3 ], 1 ); + } + } + else + { + // According to visual studio debugger, when the second argument of this method is a 2D matrix, the first argument is 2D as well. + // Assuming both arguments are 2D matrices. + // For optimal VRAM bandwidth utilization, we compute such matrix products in square tiles, a tile is 32x32 elements. + // Dispatching one thread group for each tile of the output matrix. + bindShader( eComputeShader::mulMatTiled ); + + uint32_t x, y; + // These compute shaders correctly handle partial tiles on the right and bottom edges of the output matrix, that's why rounding up. 
+ if( gpuInfo.wave64() ) + { + constexpr uint32_t TILE_SIZE = 64; + x = ( res.ne[ 0 ] + TILE_SIZE - 1 ) / TILE_SIZE; + y = ( res.ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE; + } + else + { + constexpr uint32_t TILE_SIZE = 32; + x = ( res.ne[ 0 ] + TILE_SIZE - 1 ) / TILE_SIZE; + y = ( res.ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE; + } + + const uint32_t z = res.ne[ 2 ] * res.ne[ 3 ]; + context()->Dispatch( x, y, z ); + } +} + +void MlContext::mulMat( const Tensor& src0, const Tensor& src1, Tensor& res ) +{ + const uint32_t nb00 = src0.nb[ 0 ]; + const uint32_t nb01 = src0.nb[ 1 ]; + if( nb01 >= nb00 ) + mulMatDot( src0, src1, res ); + else + mulMatMad( src0, src1, res ); +} + +namespace +{ + // Must match the HLSL in flashAttention.hlsl + struct sFlashAttentionConstants + { + TensorShape q, k, v, res; + BOOL masked; + float scale; + uint32_t tempBufferStride; + uint32_t zzPadding; + }; + + struct sFlashAttnDispatchInfo + { + uint32_t tempStride; + uint32_t groupsCount; + }; + + sFlashAttnDispatchInfo makeFlashAttentionConstants( CComPtr<ID3D11Buffer>& buffer, const Tensor& q, const Tensor& k, const Tensor& v, Tensor& res, bool masked ) + { + if( nullptr == buffer ) + { + CD3D11_BUFFER_DESC desc{ sizeof( sFlashAttentionConstants ), D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE }; + check( device()->CreateBuffer( &desc, nullptr, &buffer ) ); + } + + sFlashAttnDispatchInfo result; + + sFlashAttentionConstants cb; + cb.q = q; + cb.k = k; + cb.v = v; + cb.res = res; + cb.masked = masked ? 
TRUE : FALSE; + + const int neq0 = (int)cb.q.ne[ 0 ]; + const int D = neq0; + cb.scale = (float)( 1.0 / sqrt( (double)(int)D ) ); + + const uint32_t nek1 = cb.k.ne[ 1 ]; + constexpr uint32_t align = 32 / 4; + result.tempStride = ( ( nek1 + align - 1 ) / align ) * align; + cb.tempBufferStride = result.tempStride; + cb.zzPadding = 0; + result.groupsCount = cb.q.ne[ 1 ] * cb.q.ne[ 2 ] * cb.q.ne[ 3 ]; + + MappedResource mapped; + check( mapped.map( buffer, false ) ); + memcpy( mapped.data(), &cb, sizeof( cb ) ); + return result; + } +} + +void MlContext::flashAttention( const Tensor& q, const Tensor& k, const Tensor& v, Tensor& res, bool masked ) +{ + sFlashAttnDispatchInfo di = makeFlashAttentionConstants( flashAttentionConstants, q, k, v, res, masked ); + + const uint32_t tempLength = di.tempStride * di.groupsCount; + const TensorGpuViews& tb = temp.fp32( tempLength ); + + csSetCB( flashAttentionConstants ); + ID3D11DeviceContext* const ctx = context(); + + Binder bind; + bind.bind( { q, k, v, lookupTables.exponent() }, { res, tb } ); + + if constexpr( usePreciseComputeShaders && !enableInexactOptimizations ) + { + bindShader( eComputeShader::flashAttentionCompat1 ); + ctx->Dispatch( di.groupsCount, 1, 1 ); + + bindShader( eComputeShader::flashAttentionCompat2 ); + ctx->Dispatch( ( di.groupsCount + 31 ) / 32, 1, 1 ); + + bindShader( eComputeShader::flashAttentionCompat3 ); + ctx->Dispatch( di.groupsCount, 1, 1 ); + } + else + { + // This version is not too bad, e.g. maxAbsDiff = 2.7895e-05, avgDiffSquared = 1.61783e-14 + // And probably much faster. 
+ // But still, it does not deliver bitwise equality with the reference CPU version + bindShader( eComputeShader::flashAttention ); + ctx->Dispatch( di.groupsCount, 1, 1 ); + } +} + +namespace +{ + // Round up the number to be a multiple of 32 + inline uint32_t roundUp32( uint32_t x ) + { + return ( x + 31 ) & ( ~31u ); + } +} + +void MlContext::convolutionImpl( const Tensor& a, const Tensor& b, Tensor& res, bool is2 ) +{ + const uint32_t ne00 = a.ne[ 0 ]; + const uint32_t ne01 = a.ne[ 1 ]; + const uint32_t ne02 = a.ne[ 2 ]; + + const uint32_t ne10 = b.ne[ 0 ]; + const uint32_t ne11 = b.ne[ 1 ]; + + const uint32_t nb00 = a.nb[ 0 ]; + const uint32_t nb01 = a.nb[ 1 ]; + const uint32_t nb02 = a.nb[ 2 ]; + + const uint32_t nb10 = b.nb[ 0 ]; + const uint32_t nb11 = b.nb[ 1 ]; + + const uint32_t nb1 = res.nb[ 1 ]; + + const uint32_t ew0 = roundUp32( ne01 ); + + const uint32_t nk = ne00; + const uint32_t nh = nk / 2; + + const uint32_t lenTemp1 = ne02 * ew0 * ne00; + const uint32_t lenTemp2 = ( ne10 + ne00 ) * ew0; + + const TensorGpuViews& temp1 = temp.fp16( lenTemp1, true ); + const TensorGpuViews& temp2 = temp.fp16_2( lenTemp2, true ); + + cb.bind(); + + bindShader( eComputeShader::convolutionPrep1 ); + cb.update( a ); + Binder bind; + bind.bind( a, temp1 ); + context()->Dispatch( ne01, ne02, 1 ); + + bindShader( eComputeShader::convolutionPrep2 ); + cb.update( a, b ); + bind.bind( b, temp2 ); + context()->Dispatch( ne11, 1, 1 ); + + cb.update( a, b, res ); + bind.bind( temp1, temp2, res ); + if( is2 ) + { + if constexpr( enableInexactOptimizations ) + { + constexpr uint32_t KERNEL = 3; + constexpr uint32_t TILE_Y = 8; + if( a.ne[ 0 ] == KERNEL ) + { + const uint32_t x = ( ( ne10 / 2 ) + TILE_Y - 1 ) / TILE_Y; + bindShader( eComputeShader::convolutionMain2Fixed ); + context()->Dispatch( x, ne02, 1 ); + return; + } + } + bindShader( eComputeShader::convolutionMain2 ); + context()->Dispatch( ne10 / 2, ne02, 1 ); + } + else + { + bindShader( 
eComputeShader::convolutionMain ); + context()->Dispatch( ne10, ne02, 1 ); + } +#if 0 + std::vector<uint16_t> tmp; + downloadBuffer( temp1, tmp ); + dbgWriteBinaryFile( L"conv-gpu-arg1.bin", tmp.data(), lenTemp1 * 2 ); + downloadBuffer( temp2, tmp ); + dbgWriteBinaryFile( L"conv-gpu-arg2.bin", tmp.data(), lenTemp1 * 2 ); + res.download( tempVector ); + dbgWriteBinaryFile( L"conv-gpu-result.bin", tempVector.data(), tempVector.size() * 4 ); +#endif +} + +void MlContext::norm( const Tensor& a, Tensor& res ) +{ + const uint32_t ne01 = a.ne[ 1 ]; + const uint32_t ne02 = a.ne[ 2 ]; + const uint32_t ne03 = a.ne[ 3 ]; + + cb.bind(); + cb.update( a, res ); + Binder bind; + bind.bind( a, res ); + + if constexpr( usePreciseComputeShaders && !enableInexactOptimizations ) + { + bindShader( eComputeShader::normCompat ); + context()->Dispatch( ( ne01 + 31 ) / 32, ne02, ne03 ); + } + else + { + constexpr uint32_t FIXED_ROW_SIZE = 1024; + eComputeShader cs = ( a.ne[ 0 ] == FIXED_ROW_SIZE ) ? eComputeShader::normFixed : eComputeShader::norm; + bindShader( cs ); + context()->Dispatch( ne01, ne02, ne03 ); + } +} + +void MlContext::cwiseBinary( const Tensor& a, const Tensor& b, Tensor& res, eComputeShader cs ) +{ + assert( isSameShape( a, b ) ); + assert( isSameShape( a, res ) ); + + bindShader( cs ); + cb.bind(); + check( cb.update( a, b, res ) ); + Binder bind; + bind.bind( a, b, res ); + + uint32_t rows = a.countRows(); + context()->Dispatch( rows, 1, 1 ); +} + +Tensor MlContext::add( const Tensor& a, const Tensor& b ) +{ + return cwiseBinary( a, b, eComputeShader::add ); +} + +void MlContext::addInPlace( Tensor& a, const Tensor& b ) +{ + if( !isSameShape( a, b ) ) + throw E_INVALIDARG; + assert( a.getType() == eDataType::FP32 ); + + check( cb.update( a, b ) ); + bindShader( eComputeShader::addInPlace ); + cb.bind(); + + Binder bind; + bind.bind( b, a ); + context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] ); +} + +void MlContext::copyImpl( const Tensor& a, Tensor& res, bool 
downcastFp32 ) +{ + assert( res.isContinuous() ); + const eComputeShader cs = a.isContinuous() ? eComputeShader::copyConvert : eComputeShader::copyTranspose; + bindShader( cs ); + + cb.bind(); + // These two shaders don't need shape of the destination because dense, but they wants a boolean flag whether to implement rounding while downcasting + TensorShape dummyShape; + dummyShape.setZero(); + dummyShape.ne[ 0 ] = downcastFp32 ? TRUE : FALSE; + check( cb.update( a, dummyShape ) ); + + Binder bind; + bind.bind( a, res ); + context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] ); +} + +namespace +{ + uint32_t bitcast( float val ) + { + __m128 f = _mm_set_ss( val ); + __m128i i = _mm_castps_si128( f ); + return (uint32_t)_mm_cvtsi128_si32( i ); + } +} + +void MlContext::scale( Tensor& a, float mul ) +{ + if( !a.isContinuous() ) + throw E_INVALIDARG; + + bindShader( eComputeShader::scaleInPlace ); + cb.bind(); + TensorShape dummyShape; + dummyShape.setZero(); + dummyShape.ne[ 0 ] = bitcast( mul ); + check( cb.update( a, dummyShape ) ); + + Binder bind; + bind.bind( a ); + context()->Dispatch( a.countRows(), 1, 1 ); +} + +void MlContext::addRepeat( Tensor& a, const Tensor& b ) +{ + check( cb.update( a, b ) ); + bindShader( eComputeShader::addRepeat ); + cb.bind(); + + Binder bind; + bind.bind( b, a ); + context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] ); +} + +void MlContext::addRepeatScale( Tensor& a, const Tensor& b, float scale ) +{ +#if 0 + addRepeat( a, b ); + this->scale( a, scale ); + return; +#endif + + TensorShape dummyShape; + dummyShape.setZero(); + dummyShape.ne[ 0 ] = bitcast( scale ); + check( cb.update( a, b, dummyShape ) ); + bindShader( eComputeShader::addRepeatScale ); + cb.bind(); + + Binder bind; + bind.bind( b, a ); + context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] ); +} + +void MlContext::fmaRepeat( Tensor& a, const Tensor& mul, const Tensor& add ) +{ + eComputeShader cs; + if( isSameShapeAndLayout( mul, add ) ) + { + cs = 
eComputeShader::fmaRepeat1; + check( cb.update( a, mul ) ); + } + else + { + cs = eComputeShader::fmaRepeat2; + check( cb.update( a, mul, add ) ); + } + + bindShader( cs ); + cb.bind(); + Binder bind; + bind.bind( mul, add, a ); + context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] ); +} + +void MlContext::diagMaskInf( Tensor& a, uint32_t n_past ) +{ + if( !a.isContinuous() ) + throw E_INVALIDARG; + + bindShader( eComputeShader::diagMaskInf ); + TensorShape dummyShape; + dummyShape.setZero(); + dummyShape.ne[ 0 ] = n_past; + + cb.bind(); + check( cb.update( a, dummyShape ) ); + + Binder bind; + bind.bind( a ); + + const uint32_t n = a.countRows(); + const uint32_t nr = a.ne[ 1 ]; + const uint32_t nz = n / nr; + context()->Dispatch( nr, nz, 1 ); +} + +void MlContext::softMax( Tensor& a, float inputScale ) +{ + if( !a.isContinuous() ) + throw E_INVALIDARG; + + if constexpr( usePreciseComputeShaders && !enableInexactOptimizations ) + { + assert( inputScale == 1.0f ); + bindShader( eComputeShader::softMaxCompat ); + const uint32_t nr = a.countRows(); + TensorShape dummyShape; + dummyShape.setZero(); + dummyShape.ne[ 0 ] = nr; + + cb.bind(); + check( cb.update( a, dummyShape ) ); + + Binder bind; + bind.bind( lookupTables.exponent(), a ); + context()->Dispatch( ( nr + 31 ) / 32, 1, 1 ); + } + else + { +#if 0 + static PrintUniqueTensorSizes printSizes( "softMax" ); + printSizes.print( a ); +#endif + constexpr uint32_t FIXED_ROW_SIZE = 1500; + eComputeShader cs = ( a.ne[ 0 ] == FIXED_ROW_SIZE ) ? 
eComputeShader::softMaxFixed : eComputeShader::softMax; + bindShader( cs ); + const uint32_t nr = a.countRows(); + TensorShape dummyShape; + dummyShape.setZero(); + dummyShape.ne[ 0 ] = nr; + dummyShape.ne[ 1 ] = bitcast( inputScale ); + + cb.bind(); + check( cb.update( a, dummyShape ) ); + + Binder bind; + bind.bind( lookupTables.exponent(), a ); + context()->Dispatch( nr, 1, 1 ); + } +} + +void MlContext::addRepeatGelu( Tensor& a, const Tensor& b ) +{ + check( cb.update( a, b ) ); + bindShader( eComputeShader::addRepeatGelu ); + cb.bind(); + + Binder bind; + bind.bind( b, lookupTables.gelu(), a ); + context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] ); +} + +namespace +{ + inline bool canAddRows( const Tensor& tokenEmbedding, const Tensor& positionalEmbedding, const Tensor& embd, uint32_t pastTokensCount ) + { + if( tokenEmbedding.ne[ 0 ] != positionalEmbedding.ne[ 0 ] ) + return false; // Different row lengths + if( embd.ne[ 0 ] + pastTokensCount > positionalEmbedding.ne[ 1 ] ) + return false; // Too many rows requested, positionalEmbedding matrix doesn't have that many + return true; + } +} + +Tensor MlContext::addRows( const Tensor& tokenEmbedding, const Tensor& positionalEmbedding, const Tensor& embd, uint32_t pastTokensCount ) +{ + if( !canAddRows( tokenEmbedding, positionalEmbedding, embd, pastTokensCount ) ) + throw E_INVALIDARG; + + const uint32_t rowLength = tokenEmbedding.ne[ 0 ]; + const uint32_t rows = embd.ne[ 0 ]; + Tensor result = createTensor( eDataType::FP32, { rowLength, rows } ); + + TensorShape constants; + // rowLength + constants.ne[ 0 ] = rowLength; + // pastTokensCount + constants.ne[ 1 ] = pastTokensCount; + // outputRowStride + constants.ne[ 2 ] = result.nb[ 1 ]; + // embStrides + constants.nb[ 0 ] = tokenEmbedding.nb[ 0 ]; + constants.nb[ 1 ] = tokenEmbedding.nb[ 1 ]; + // posStrides + constants.nb[ 2 ] = positionalEmbedding.nb[ 0 ]; + constants.nb[ 3 ] = positionalEmbedding.nb[ 1 ]; + check( cb.update( constants ) ); + + bindShader( 
eComputeShader::addRows ); + cb.bind(); + Binder bind; + bind.bind( { tokenEmbedding, positionalEmbedding, embd }, { result } ); + context()->Dispatch( rows, 1, 1 ); + return result; +} + +Tensor MlContext::reshapePanels( const Tensor& a ) +{ + constexpr uint32_t TILE_SIZE = ReshapedMultiply::TILE_SIZE; + + const eDataType dataType = a.getType(); + // Reshaping into column major horizontal panels, height = TILE_SIZE, width = width of the source matrix + + // Round height to multiple of tile size + std::array<uint32_t, 4> ne = a.ne; + // Dispatch a group of threads thread per panel + const uint32_t groupsX = ( ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE; + ne[ 1 ] = groupsX * TILE_SIZE;; + // Each panel has [ size.x, TILE_SIZE ] elements + const uint32_t panelSize = ne[ 0 ] * TILE_SIZE; + + Tensor result = createTensor( dataType, ne ); + + TensorShape constants; + constants.setZero(); + // uint panelSize : packoffset( c2.y ); + constants.ne[ 1 ] = panelSize; + // uint2 layerStrides: packoffset( c2.z ); + constants.ne[ 2 ] = result.nb[ 2 ]; + constants.ne[ 3 ] = result.nb[ 3 ]; + + check( cb.update( a, constants ) ); + bindShader( eComputeShader::matReshapePanels ); + cb.bind(); + + Binder bind; + bind.bind( a, result ); + context()->Dispatch( groupsX, a.ne[ 2 ], a.ne[ 3 ] ); + +#if 0 + if( dataType == eDataType::FP32 ) + { + std::vector<float> v1, v2; + a.download( v1 ); + result.download( v2 ); + __debugbreak(); + } + else if( dataType == eDataType::FP16 ) + { + std::vector<uint16_t> v1, v2; + a.download( v1 ); + result.download( v2 ); + __debugbreak(); + } +#endif + + // Set up size and stride expected by the mulMatTiledEx compute shader + result.ne = a.ne; + result.nb[ 0 ] = 0; + result.nb[ 1 ] = panelSize; + return result; +} + +Tensor MlContext::mulMatTiledEx( const Tensor& a, const Tensor& b ) +{ + constexpr uint32_t TILE_SIZE = ReshapedMultiply::TILE_SIZE; + + if( !canMulMat( a, b ) ) + throw E_INVALIDARG; // Wrong size + if( 0 != ( a.nb[ 0 ] | b.nb[ 0 ] ) ) + 
throw E_INVALIDARG; // Both tensors are expected to be pre-transposed into these panels + + Tensor res = createTensor( eDataType::FP32, { a.ne[ 1 ], b.ne[ 1 ], a.ne[ 2 ], b.ne[ 3 ] } ); + + check( cb.update( a, b, res ) ); + bindShader( eComputeShader::mulMatTiledEx ); + cb.bind(); + + Binder bind; + bind.bind( a, b, res ); + + const uint32_t x = ( res.ne[ 0 ] + TILE_SIZE - 1 ) / TILE_SIZE; + const uint32_t y = ( res.ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE; + const uint32_t z = res.ne[ 2 ] * res.ne[ 3 ]; + context()->Dispatch( x, y, z ); + + return res; +} + +Tensor MlContext::mulMatByRowTiledEx( const Tensor& a, const Tensor& b ) +{ + constexpr uint32_t TILE_SIZE = ReshapedMultiply::TILE_SIZE; + assert( canMulMat( a, b ) ); + assert( b.ne[ 1 ] == 1 ); + + Tensor res = createTensor( eDataType::FP32, { a.ne[ 1 ], 1, a.ne[ 2 ], b.ne[ 3 ] } ); + + check( cb.update( a, b, res ) ); + bindShader( eComputeShader::mulMatByRowTiledEx ); + cb.bind(); + + Binder bind; + bind.bind( a, b, res ); + + const uint32_t x = ( res.ne[ 0 ] + TILE_SIZE - 1 ) / TILE_SIZE; + const uint32_t y = res.ne[ 2 ]; + const uint32_t z = res.ne[ 3 ]; + context()->Dispatch( x, y, z ); + + return res; +} + +__m128i MlContext::getMemoryUse() const +{ + __m128i v = cb.getMemoryUse(); + v = _mm_add_epi64( v, temp.getMemoryUse() ); + v = _mm_add_epi64( v, bufferMemoryUsage( flashAttentionConstants ) ); + v = _mm_add_epi64( v, lookupTables.getMemoryUsage() ); + return v; +}
\ No newline at end of file diff --git a/Whisper/ML/MlContext.dbg.cpp b/Whisper/ML/MlContext.dbg.cpp new file mode 100644 index 0000000..df095aa --- /dev/null +++ b/Whisper/ML/MlContext.dbg.cpp @@ -0,0 +1,59 @@ +#include "stdafx.h" +#include "MlContext.h" +#include "../source/ggml.h" +#include "testUtils.h" +using namespace DirectCompute; + +#define E_TYPE HRESULT_FROM_WIN32( ERROR_DATATYPE_MISMATCH ) + +static void dbgPrintSizeDiff( const char* what, __m128i ref, __m128i gpu ) +{ + std::array<int, 8> a; + _mm_storeu_si128( ( __m128i* ) & a[ 0 ], ref ); + _mm_storeu_si128( ( __m128i* ) & a[ 4 ], gpu ); + printf( "%s; reference [ %i, %i, %i, %i ], GPGPU [ %i, %i, %i, %i ]\n", + what, + a[ 0 ], a[ 1 ], a[ 2 ], a[ 3 ], + a[ 4 ], a[ 5 ], a[ 6 ], a[ 7 ] ); +} + +void MlContext::dbgPrintDifference( const ggml_tensor* reference, const Tensor& gpu, const char* what, bool trapToDebugger ) +{ + sTensorDiff diff; + const __m128i gpuSize = gpu.sizeVec(); + const __m128i gpuStrides = gpu.stridesVec(); + __m128i expectedStrides; + if( reference->type == GGML_TYPE_F32 ) + { + if( gpu.getType() != eDataType::FP32 ) + throw E_TYPE; + expectedStrides = _mm_slli_epi32( gpuStrides, 2 ); + + std::vector<float> v; + gpu.download( v ); + diff = computeDiff( v.data(), (const float*)reference->data, v.size() ); + } + else if( reference->type == GGML_TYPE_F16 ) + { + if( gpu.getType() != eDataType::FP16 ) + throw E_TYPE; + expectedStrides = _mm_slli_epi32( gpuStrides, 1 ); + + std::vector<uint16_t> v; + gpu.download( v ); + diff = computeDiff( v.data(), (const uint16_t*)reference->data, v.size() ); + } + else + throw E_NOTIMPL; + + const __m128i ggmlSize = _mm_loadu_si128( ( const __m128i* ) & reference->ne[ 0 ] ); + const __m128i ggmlStrides = _mm_loadu_si128( ( const __m128i* ) & reference->nb[ 0 ] ); + if( !vectorEqual( gpuSize, ggmlSize ) ) + dbgPrintSizeDiff( "dbgPrintDifference - size is different", ggmlSize, gpuSize ); + // if( !vectorEqual( expectedStrides, ggmlStrides ) ) 
dbgPrintSizeDiff( "dbgPrintDifference - stride is different", ggmlStrides, expectedStrides ); + + diff.print( what ); + + if( trapToDebugger ) + __debugbreak(); +}
\ No newline at end of file diff --git a/Whisper/ML/MlContext.h b/Whisper/ML/MlContext.h new file mode 100644 index 0000000..d0c2d9e --- /dev/null +++ b/Whisper/ML/MlContext.h @@ -0,0 +1,111 @@ +#pragma once +#include <vector> +#include "TempBuffers.h" +#include "ConstantBuffer.h" +#include "Tensor.h" +#include "../Utils/GpuProfiler.h" +#include "../Utils/ProfileCollection.h" + +namespace DirectCompute +{ + enum struct eComputeShader : uint16_t; + + class MlContext + { + // When false, the implementation is 100% compatible with the CPU-running code written by Georgi Gerganov + // When true, the implementation is much faster, and doesn't require FP64 support in the compute shaders. + // FP64 is an optional feature, not all GPUs support that. + static constexpr bool enableInexactOptimizations = true; + + ConstantBuffer cb; + TempBuffers temp; + CComPtr<ID3D11Buffer> flashAttentionConstants; + + void convolutionImpl( const Tensor& a, const Tensor& b, Tensor& res, bool is2 ); + + void cwiseBinary( const Tensor& a, const Tensor& b, Tensor& res, eComputeShader cs ); + Tensor cwiseBinary( const Tensor& a, const Tensor& b, eComputeShader cs ); + + void mulMatDot( const Tensor& a, const Tensor& b, Tensor& res ); + void mulMatMad( const Tensor& a, const Tensor& b, Tensor& res ); + void mulMatTiled( const Tensor& a, const Tensor& b, Tensor& res ); + + void bindShader( eComputeShader cs ); + + protected: + void copyImpl( const Tensor& a, Tensor& res, bool downcastFp32 ); + + // Create a dense output tensor for the results of a computation + // Override this method to implement a pool of these tensors + virtual Tensor createTensor( eDataType type, const std::array<uint32_t, 4>& ne ); + + Tensor createTensor( eDataType type, std::initializer_list<uint32_t> ne ); + + GpuProfiler profiler; + + public: + MlContext( Whisper::ProfileCollection& profileColl ); + MlContext( const MlContext& ) = delete; + + // res = a * b + void mulMat( const Tensor& a, const Tensor& b, Tensor& res ); + 
+ void flashAttention( const Tensor& q, const Tensor& k, const Tensor& v, Tensor& res, bool masked ); + + inline void convolution( const Tensor& a, const Tensor& b, Tensor& res ) + { + convolutionImpl( a, b, res, false ); + } + void convolution2( const Tensor& a, const Tensor& b, Tensor& res ) + { + convolutionImpl( a, b, res, true ); + } + + void norm( const Tensor& a, Tensor& res ); + + Tensor conv_1d_1s( const Tensor& a, const Tensor& b ); + Tensor conv_1d_2s( const Tensor& a, const Tensor& b ); + + Tensor add( const Tensor& a, const Tensor& b ); + void addInPlace( Tensor& a, const Tensor& b ); + + Tensor view2d( const Tensor& a, uint32_t ne0, uint32_t ne1, uint32_t nb1, uint32_t offset ); + Tensor transpose( const Tensor& a ); + + Tensor norm( const Tensor& a ); + Tensor mulMat( const Tensor& a, const Tensor& b ); + Tensor mulMatEx( const Tensor& a, const Tensor& b, const char* tagName ); + Tensor permute( const Tensor& a, uint8_t axis0, uint8_t axis1, uint8_t axis2, uint8_t axis3 ); + Tensor flashAttention( const Tensor& q, const Tensor& k, const Tensor& v, bool masked ); + + Tensor copy( const Tensor& a, eDataType type, std::initializer_list<uint32_t> size ); + void copyInPlace( Tensor& dest, const Tensor& a, eDataType type, std::initializer_list<uint32_t> size ); + + void dbgPrintDifference( const ggml_tensor* reference, const Tensor& gpu, const char * what, bool trapToDebugger = true ); + + void scale( Tensor& a, float mul ); + + void addRepeat( Tensor& a, const Tensor& b ); + void addRepeatScale( Tensor& a, const Tensor& b, float scale ); + void fmaRepeat( Tensor& a, const Tensor& mul, const Tensor& add ); + + // ggml_diag_mask_inf + void diagMaskInf( Tensor& a, uint32_t n_past ); + // ggml_soft_max + void softMax( Tensor& a, float inputScale = 1.0f ); + + void addRepeatGelu( Tensor& a, const Tensor& b ); + + // Extract rows from tokenEmbedding matrix, row indices are taken from the `embd` R32_UINT row vector + // Extract same count of rows from 
positionalEmbedding matrix, starting at the `pastTokensCount` row + // Return a new FP32 matrix with the sum of these rows + Tensor addRows( const Tensor& tokenEmbedding, const Tensor& positionalEmbedding, const Tensor& embd, uint32_t pastTokensCount ); + + Tensor reshapePanels( const Tensor& a ); + + Tensor mulMatTiledEx( const Tensor& a, const Tensor& b ); + Tensor mulMatByRowTiledEx( const Tensor& a, const Tensor& b ); + + __m128i getMemoryUse() const; + }; +}
\ No newline at end of file diff --git a/Whisper/ML/Reshaper.cpp b/Whisper/ML/Reshaper.cpp new file mode 100644 index 0000000..af66929 --- /dev/null +++ b/Whisper/ML/Reshaper.cpp @@ -0,0 +1,80 @@ +#include "stdafx.h" +#include "Reshaper.h" +#include "../D3D/MappedResource.h" +#include "../D3D/Binder.h" +#include "../D3D/shaders.h" +#include "reshapedMultiply.h" + +namespace +{ + using namespace DirectCompute; + struct Constants + { + // Size and strides of the source tensor + TensorShape arg0; + uint32_t zzPadding; + // Count of elements per panel + uint32_t panelSize; + // Layer strides of the output matrix + std::array<uint32_t, 2> layerStrides; + }; +} + +HRESULT DirectCompute::Reshaper::createConstants() +{ + constexpr uint32_t cb = sizeof( Constants ); + CD3D11_BUFFER_DESC desc{ cb, D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE }; + return device()->CreateBuffer( &desc, nullptr, &constantBuffer ); +} + +HRESULT DirectCompute::Reshaper::makePanels( Tensor& tensor, eDataType dataType ) +{ + if( !constantBuffer ) + CHECK( createConstants() ); + + constexpr uint32_t TILE_SIZE = ReshapedMultiply::TILE_SIZE; + + // Reshaping into column major horizontal panels, height = TILE_SIZE, width = width of the source matrix + + std::array<uint32_t, 4> ne = tensor.ne; + const uint32_t groupsX = ( ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE; + ne[ 1 ] = groupsX * TILE_SIZE;; + // Each panel has [ size.x, TILE_SIZE ] elements + const uint32_t panelSize = ne[ 0 ] * TILE_SIZE; + + Tensor result; + result.create( dataType, ne ); + + { + MappedResource mapped; + CHECK( mapped.map( constantBuffer, false ) ); + Constants& cb = *(Constants*)mapped.data(); + + store( cb.arg0.ne, tensor.sizeVec() ); + store( cb.arg0.nb, tensor.stridesVec() ); + cb.panelSize = panelSize; + cb.layerStrides[ 0 ] = result.nb[ 2 ]; + cb.layerStrides[ 1 ] = result.nb[ 3 ]; + } + + csSetCB( constantBuffer ); + { + Binder bind; + bind.bind( tensor, result ); + bindShader( 
eComputeShader::matReshapePanels ); + context()->Dispatch( groupsX, tensor.ne[ 2 ], tensor.ne[ 3 ] ); + } + + tensor.nb[ 0 ] = 0; + tensor.nb[ 1 ] = panelSize; + tensor.nb[ 2 ] = result.nb[ 2 ]; + tensor.nb[ 3 ] = result.nb[ 3 ]; + tensor.setGpuViews( result ); + return S_OK; +} + +DirectCompute::Reshaper::~Reshaper() +{ + if( constantBuffer ) + csSetCB( nullptr ); +}
// File: Whisper/ML/Reshaper.h
#pragma once
#include "Tensor.h"

namespace DirectCompute
{
	// This class reshapes some of the model’s tensors, immediately after they’re loaded.
	// That feature is used on all AMD GPUs.
	class Reshaper
	{
		// Constant buffer for the reshaping shader, created by createConstants()
		CComPtr<ID3D11Buffer> constantBuffer;
		HRESULT createConstants();

	public:
		~Reshaper();
		// Reshape the tensor into panels, replacing GPU views and strides of the tensor.
		// The second argument is the data type of the newly created buffer.
		HRESULT makePanels( Tensor& tensor, eDataType dataType );
	};
}
\ No newline at end of file diff --git a/Whisper/ML/TempBuffers.cpp b/Whisper/ML/TempBuffers.cpp new file mode 100644 index 0000000..8823364 --- /dev/null +++ b/Whisper/ML/TempBuffers.cpp @@ -0,0 +1,88 @@ +#include "stdafx.h" +#include "TempBuffers.h" +#include "../D3D/createBuffer.h" +#include "../D3D/MappedResource.h" +#include "../D3D/shaders.h" +using namespace DirectCompute; + +#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } + +HRESULT TempBuffers::Buffer::resize( DXGI_FORMAT format, size_t elements, size_t cbElement, bool zeroMemory, CComPtr<ID3D11Buffer>& cb ) +{ + if( elements <= capacity ) + { + if( zeroMemory ) + TempBuffers::zeroMemory( *this, (uint32_t)elements, cb ); + return S_OK; + } + clear(); + + CComPtr<ID3D11Buffer> buffer; + const size_t totalBytes = elements * cbElement; + CHECK( createBuffer( eBufferUse::ReadWrite, totalBytes, &buffer, nullptr, nullptr ) ); + CHECK( TensorGpuViews::create( buffer, format, elements, true ) ); + capacity = elements; + return S_OK; +} + +void TempBuffers::zeroMemory( ID3D11UnorderedAccessView* uav, uint32_t length, CComPtr<ID3D11Buffer>& cb ) +{ + const __m128i cbData = _mm_cvtsi32_si128( (int)length ); + if( cb ) + { + MappedResource mapped; + check( mapped.map( cb, false ) ); + store16( mapped.data(), cbData ); + } + else + { + CD3D11_BUFFER_DESC desc{ 16, D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE }; + std::array<uint32_t, 4> cbBuffer; + store( cbBuffer, cbData ); + D3D11_SUBRESOURCE_DATA srd{ cbBuffer.data(), 0, 0 }; + check( device()->CreateBuffer( &desc, &srd, &cb ) ); + } + + ID3D11DeviceContext* ctx = context(); + ctx->CSSetUnorderedAccessViews( 0, 1, &uav, nullptr ); + csSetCB( cb ); + + constexpr uint32_t THREADS = 512; + constexpr uint32_t ITERATIONS = 128; + constexpr uint32_t elementsPerGroup = THREADS * ITERATIONS; + const uint32_t countGroups = ( length + elementsPerGroup - 1 ) / elementsPerGroup; + bindShader( 
eComputeShader::zeroMemory ); + ctx->Dispatch( countGroups, 1, 1 ); +} + +const TensorGpuViews& TempBuffers::fp16( size_t countElements, bool zeroMemory ) +{ + HRESULT hr = m_fp16.resize( DXGI_FORMAT_R16_FLOAT, countElements, 2, zeroMemory, smallCb ); + if( FAILED( hr ) ) + throw hr; + return m_fp16; +} + +const TensorGpuViews& TempBuffers::fp16_2( size_t countElements, bool zeroMemory ) +{ + HRESULT hr = m_fp16_2.resize( DXGI_FORMAT_R16_FLOAT, countElements, 2, zeroMemory, smallCb ); + if( FAILED( hr ) ) + throw hr; + return m_fp16_2; +} + +const TensorGpuViews& TempBuffers::fp32( size_t countElements, bool zeroMemory ) +{ + HRESULT hr = m_fp32.resize( DXGI_FORMAT_R32_FLOAT, countElements, 4, zeroMemory, smallCb ); + if( FAILED( hr ) ) + throw hr; + return m_fp32; +} + +__m128i TempBuffers::getMemoryUse() const +{ + size_t cb = m_fp16.getCapacity() * 2; + cb += m_fp16_2.getCapacity() * 2; + cb += m_fp32.getCapacity() * 4; + return setHigh_size( cb ); +}
// File: Whisper/ML/TempBuffers.h
#pragma once
#include "TensorGpuViews.h"

namespace DirectCompute
{
	// A set of 3 reusable temporary GPU buffers: two with FP16 elements, one with FP32 elements
	class TempBuffers
	{
		// GPU views of a buffer, along with the capacity of that buffer expressed in elements
		class Buffer : public TensorGpuViews
		{
			size_t capacity = 0;

		public:

			// Release the views, and reset the capacity to 0
			void clear()
			{
				TensorGpuViews::clear();
				capacity = 0;
			}

			// Ensure the capacity is at least `elements`; the buffer is re-created when the current capacity is insufficient
			HRESULT resize( DXGI_FORMAT format, size_t elements, size_t cbElement, bool zeroMemory, CComPtr<ID3D11Buffer>& cb );

			size_t getCapacity() const { return capacity; }
		};

		Buffer m_fp16;
		Buffer m_fp16_2;
		Buffer m_fp32;

	public:

		// Constant buffer passed to zeroMemory() by the methods of this class
		CComPtr<ID3D11Buffer> smallCb;

		// Dispatch a compute shader to zero the first `length` elements of the view; creates the constant buffer when `cb` is null
		static void zeroMemory( ID3D11UnorderedAccessView* uav, uint32_t length, CComPtr<ID3D11Buffer>& cb );

		// These 3 methods resize the corresponding buffer when needed, and return the views.
		// They throw HRESULT codes on failures.
		const TensorGpuViews& fp16( size_t countElements, bool zeroMemory = false );
		const TensorGpuViews& fp16_2( size_t countElements, bool zeroMemory = false );
		const TensorGpuViews& fp32( size_t countElements, bool zeroMemory = false );

		// Release all 3 buffers
		void clear()
		{
			m_fp16.clear();
			m_fp16_2.clear();
			m_fp32.clear();
		}

		__m128i getMemoryUse() const;
	};
}
\ No newline at end of file diff --git a/Whisper/ML/Tensor.cpp b/Whisper/ML/Tensor.cpp new file mode 100644 index 0000000..5542f5a --- /dev/null +++ b/Whisper/ML/Tensor.cpp @@ -0,0 +1,340 @@ +#include "stdafx.h" +#include "Tensor.h" +#include "../D3D/MappedResource.h" +#include "../D3D/createBuffer.h" +#include "../source/ggml.h" +using namespace DirectCompute; + +Tensor::Tensor( const Tensor& that ) +{ + ne = that.ne; + nb = that.nb; + srv = that.srv; + uav = that.uav; +#ifdef _DEBUG + dbgType = that.dbgType; +#endif +} + +Tensor::Tensor( Tensor&& that ) noexcept +{ + ne = that.ne; + nb = that.nb; + srv.Attach( that.srv.Detach() ); + uav.Attach( that.uav.Detach() ); +#ifdef _DEBUG + dbgType = that.dbgType; +#endif +} + +Tensor& Tensor::operator=( const Tensor& that ) +{ + ne = that.ne; + nb = that.nb; + srv = that.srv; + uav = that.uav; +#ifdef _DEBUG + dbgType = that.dbgType; +#endif + return *this; +} + +Tensor& Tensor::operator=( Tensor&& that ) noexcept +{ + ne = that.ne; + nb = that.nb; + srv.Attach( that.srv.Detach() ); + uav.Attach( that.uav.Detach() ); +#ifdef _DEBUG + dbgType = that.dbgType; +#endif + return *this; +} + +Tensor::Tensor( const TensorShape& shape, CComPtr<ID3D11ShaderResourceView>& srv, CComPtr<ID3D11UnorderedAccessView>& uav ) noexcept : + TensorShape( shape ) +{ + TensorGpuViews::srv.Attach( srv.Detach() ); + TensorGpuViews::uav.Attach( uav.Detach() ); +} + +Tensor::Tensor( const TensorShape& shape, const TensorGpuViews& views ) : + TensorShape( shape ) +{ + srv = views; + uav = views; +} + +HRESULT Tensor::create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData ) +{ + TensorGpuViews::clear(); + + switch( usage ) + { + case eBufferUse::Immutable: + case eBufferUse::ReadWriteDownload: + break; + default: + return E_INVALIDARG; + } + + CComPtr<ID3D11Buffer> buffer; + + CHECK( TensorShape::create( ggml ) ); + const ggml_type dataType = ggml.type; + const uint32_t cbElement = (uint32_t)ggml_type_size( dataType ); + + const size_t 
totalBytes = ggml_nbytes( &ggml ); + if( totalBytes > INT_MAX ) + return DISP_E_OVERFLOW; + const uint32_t countElements = (uint32_t)( totalBytes / cbElement ); + + { + const void* const rsi = uploadData ? ggml.data : nullptr; + CHECK( createBuffer( usage, totalBytes, &buffer, rsi, nullptr ) ); + } + + DXGI_FORMAT format; + eDataType type; + switch( dataType ) + { + case GGML_TYPE_F16: + format = DXGI_FORMAT_R16_FLOAT; + type = eDataType::FP16; + break; + case GGML_TYPE_F32: + format = DXGI_FORMAT_R32_FLOAT; + type = eDataType::FP32; + break; + default: + return E_NOTIMPL; + } + + const bool makeUav = ( usage == eBufferUse::ReadWrite ); + + CHECK( TensorGpuViews::create( buffer, format, totalBytes / cbElement, makeUav ) ); +#ifdef _DEBUG + dbgType.type = type; + dbgType.usage = usage; + dbgType.hasInitialData = uploadData; +#endif + return S_OK; +} + +HRESULT Tensor::createImmutable( eDataType type, const std::array<int, 4>& size, const void* rsi ) +{ + size_t elts = (uint32_t)size[ 0 ]; + elts *= (uint32_t)size[ 1 ]; + elts *= (uint32_t)size[ 2 ]; + elts *= (uint32_t)size[ 3 ]; + + DXGI_FORMAT format; + size_t cbElement; + switch( type ) + { + case eDataType::FP16: + format = DXGI_FORMAT_R16_FLOAT; + cbElement = 2; + break; + case eDataType::FP32: + format = DXGI_FORMAT_R32_FLOAT; + cbElement = 4; + break; + default: + return E_NOTIMPL; + } + + CComPtr<ID3D11Buffer> buffer; + CHECK( createBuffer( eBufferUse::Immutable, cbElement * elts, &buffer, rsi, nullptr ) ); + CHECK( TensorGpuViews::create( buffer, format, elts, false ) ); + + __m128i v = _mm_loadu_si128( ( const __m128i* )size.data() ); + _mm_storeu_si128( ( __m128i* )ne.data(), v ); + setDenseStrides(); + return S_OK; +} + +HRESULT Tensor::create( eDataType type, std::initializer_list<uint32_t> sizeElements, eBufferUse usage, CComPtr<ID3D11Buffer>& buffer, const void* rsi, ID3D11Buffer** ppStagingBuffer ) +{ + TensorGpuViews::clear(); + + size_t nDims = sizeElements.size(); + if( 0 == nDims || nDims > 4 ) + 
return E_INVALIDARG; + nDims = std::min( nDims, (size_t)4 ); + size_t totalElements = 1; + for( size_t i = 0; i < nDims; i++ ) + { + uint32_t n = sizeElements.begin()[ i ]; + if( n == 0 ) + return E_INVALIDARG; + ne[ i ] = n; + totalElements *= n; + } + + DXGI_FORMAT format; + size_t cbElement; + switch( type ) + { + case eDataType::FP32: + format = DXGI_FORMAT_R32_FLOAT; + cbElement = 4; + break; + case eDataType::FP16: + format = DXGI_FORMAT_R16_FLOAT; + cbElement = 2; + break; + case eDataType::U32: + format = DXGI_FORMAT_R32_UINT; + cbElement = 4; + break; + default: + return E_NOTIMPL; + } + + const size_t totalBytes = cbElement * totalElements; + if( totalBytes > INT_MAX ) + return DISP_E_OVERFLOW; + + for( size_t i = nDims; i < 4; i++ ) + ne[ i ] = 1; + TensorShape::setDenseStrides(); + + CHECK( createBuffer( usage, totalBytes, &buffer, rsi, ppStagingBuffer ) ); + + CHECK( TensorGpuViews::create( buffer, format, totalBytes / cbElement, true ) ); +#ifdef _DEBUG + dbgType.type = type; + dbgType.usage = usage; + dbgType.hasInitialData = ( nullptr != rsi ); +#endif + return S_OK; +} + +HRESULT Tensor::create( eDataType type, std::initializer_list<uint32_t> sizeElements ) +{ + CComPtr<ID3D11Buffer> buffer; + return create( type, sizeElements, eBufferUse::ReadWrite, buffer, nullptr, nullptr ); +} + +HRESULT Tensor::create( eDataType type, const std::array<uint32_t, 4>& sizeElements ) +{ + std::initializer_list<uint32_t> il( sizeElements.data(), sizeElements.data() + 4 ); + return create( type, il ); +} + +eDataType Tensor::getType() const +{ + ID3D11ShaderResourceView* const srv = *this; + if( nullptr == srv ) + throw OLE_E_BLANK; + + D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc; + srv->GetDesc( &viewDesc ); + const DXGI_FORMAT format = viewDesc.Format; + switch( format ) + { + case DXGI_FORMAT_R32_FLOAT: + return eDataType::FP32; + case DXGI_FORMAT_R16_FLOAT: + return eDataType::FP16; + case DXGI_FORMAT_R32_UINT: + return eDataType::U32; + } + throw E_NOTIMPL; +} + 
+CComPtr<ID3D11Buffer> Tensor::getBuffer() const +{ + ID3D11ShaderResourceView* const srv = *this; + if( nullptr == srv ) + throw OLE_E_BLANK; + + CComPtr<ID3D11Resource> res; + srv->GetResource( &res ); + + CComPtr<ID3D11Buffer> buff; + check( res.QueryInterface( &buff ) ); + return buff; +} + +uint32_t Tensor::dxgiSizeof( DXGI_FORMAT format ) +{ + switch( format ) + { + case DXGI_FORMAT_R16_FLOAT: + return 2; + case DXGI_FORMAT_R32_FLOAT: + case DXGI_FORMAT_R32_UINT: + return 4; + } + throw E_INVALIDARG; +} + +void Tensor::downloadImpl( const D3D11_SHADER_RESOURCE_VIEW_DESC& viewDesc, uint32_t countElements, size_t cbElement, void* rdi ) const +{ + assert( viewDesc.ViewDimension == D3D_SRV_DIMENSION_BUFFER ); + const uint32_t idxFirst = viewDesc.Buffer.FirstElement; + + CComPtr<ID3D11Buffer> buff = getBuffer(); + D3D11_BUFFER_DESC desc; + buff->GetDesc( &desc ); + desc.BindFlags = 0; + desc.Usage = D3D11_USAGE_STAGING; + desc.CPUAccessFlags = D3D11_CPU_ACCESS_READ; + + CComPtr<ID3D11Buffer> staging; + check( device()->CreateBuffer( &desc, nullptr, &staging ) ); + context()->CopyResource( staging, buff ); + + MappedResource mapped; + check( mapped.map( staging, true ) ); + const uint8_t* rsi = (const uint8_t*)mapped.data(); + rsi += cbElement * idxFirst; + memcpy( rdi, rsi, cbElement * countElements ); +} + +void Tensor::download( std::vector<float>& vec ) const +{ + ID3D11ShaderResourceView* const srv = *this; + if( nullptr == srv ) + throw OLE_E_BLANK; + + D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc; + srv->GetDesc( &viewDesc ); + if( viewDesc.Format != DXGI_FORMAT_R32_FLOAT ) + throw E_INVALIDARG; + + uint32_t countElements = viewDesc.Buffer.NumElements; + vec.resize( countElements ); + downloadImpl( viewDesc, countElements, 4, vec.data() ); +} + +void Tensor::download( std::vector<uint16_t>& vec ) const +{ + ID3D11ShaderResourceView* const srv = *this; + if( nullptr == srv ) + throw OLE_E_BLANK; + + D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc; + srv->GetDesc( 
&viewDesc ); + if( viewDesc.Format != DXGI_FORMAT_R16_FLOAT ) + throw E_INVALIDARG; + + uint32_t countElements = viewDesc.Buffer.NumElements; + vec.resize( countElements ); + downloadImpl( viewDesc, countElements, 2, vec.data() ); +} + +Tensor Tensor::reshape3d( uint32_t ne0, uint32_t ne1, uint32_t ne2 ) const +{ + if( !isContinuous() ) + throw E_NOTIMPL; + if( countElements() != ne0 * ne1 * ne2 ) + throw E_INVALIDARG; + + Tensor res = *this; + res.ne = { ne0, ne1, ne2, 1 }; + res.setDenseStrides(); + return res; +}
// File: Whisper/ML/Tensor.h
#pragma once
#include "TensorShape.h"
#include "TensorGpuViews.h"
#include "../D3D/enums.h"

namespace DirectCompute
{
	// A minimal tensor object sufficient to compute things on GPU, with compute shaders
	// This class only takes 48 bytes in system memory, and is very cheap to make copies 'coz GPU objects are reference counted.
	class Tensor : public TensorShape, public TensorGpuViews
	{
		// Get the buffer from the shader resource view; throws when the tensor is empty
		CComPtr<ID3D11Buffer> getBuffer() const;

		struct TensorType
		{
			eDataType type;
			eBufferUse usage;
			bool hasInitialData;
		};
#ifdef _DEBUG
		// In debug builds, we include a few pieces of data to this class.
		TensorType dbgType;
#endif
	protected:
		// Create a dense tensor with 1-4 dimensions; the new buffer is returned in the `buffer` argument
		HRESULT create( eDataType type, std::initializer_list<uint32_t> sizeElements, eBufferUse usage, CComPtr<ID3D11Buffer>& buffer, const void* rsi, ID3D11Buffer** ppStagingBuffer );

		// Size in bytes of an element of the format; throws E_INVALIDARG for unsupported formats
		static uint32_t dxgiSizeof( DXGI_FORMAT format );

		// Copy the data to system memory through a temporary staging buffer
		void downloadImpl( const D3D11_SHADER_RESOURCE_VIEW_DESC& viewDesc, uint32_t countElements, size_t cbElement, void* rdi ) const;

	public:
		Tensor() = default;

		// These copy operators don't copy any data, they merely increment ref.counter of the GPU resources
		Tensor( const Tensor& );
		Tensor( Tensor&& that ) noexcept;
		Tensor& operator=( const Tensor& that );
		Tensor& operator=( Tensor&& that ) noexcept;

		// Move the provided buffer views into this newly created tensor, and assign the shape
		// This destroys old values in the smart pointers
		Tensor( const TensorShape& shape, CComPtr<ID3D11ShaderResourceView>& srv, CComPtr<ID3D11UnorderedAccessView>& uav ) noexcept;

		// Assign the shape, and share the views with the argument
		Tensor( const TensorShape& shape, const TensorGpuViews& views );

		// Create a tensor from the GGML's one
		HRESULT create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData );

		// Create a new dense tensor of the specified size in elements, without initial data
		HRESULT create( eDataType type, std::initializer_list<uint32_t> sizeElements );
		HRESULT create( eDataType type, const std::array<uint32_t, 4>& sizeElements );
		// Create a new dense immutable tensor, with the initial data at `rsi`
		HRESULT createImmutable( eDataType type, const std::array<int, 4>& size, const void* rsi );

		// Data type of the elements, detected from the format of the shader resource view; throws when the tensor is empty
		eDataType getType() const;

		// This method should probably only be used to test things
		// TensorEx is better for production usage, because it creates staging buffer in advance.
		void download( std::vector<float>& vec ) const;
		void download( std::vector<uint16_t>& vec ) const;

		// ggml_reshape_3d
		Tensor reshape3d( uint32_t ne0, uint32_t ne1, uint32_t ne2 ) const;

		// Set the debug-only fields; does nothing in release builds
		inline void dbgSetType( eDataType dt, bool hasData = false, eBufferUse use = eBufferUse::ReadWrite )
		{
#ifdef _DEBUG
			dbgType.type = dt;
			dbgType.hasInitialData = hasData;
			dbgType.usage = use;
#endif
		}

		__m128i getMemoryUse() const
		{
			return resourceMemoryUsage( srv );
		}
	};
}
\ No newline at end of file diff --git a/Whisper/ML/TensorEx.cpp b/Whisper/ML/TensorEx.cpp new file mode 100644 index 0000000..97e4e30 --- /dev/null +++ b/Whisper/ML/TensorEx.cpp @@ -0,0 +1,97 @@ +#include "stdafx.h" +#include "TensorEx.h" +#include "../D3D/createBuffer.h" +#include "../source/ggml.h" +#include "../D3D/MappedResource.h" +using namespace DirectCompute; + +HRESULT TensorEx::create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData ) +{ + TensorGpuViews::clear(); + buffer = nullptr; + stagingBuffer = nullptr; + + CHECK( TensorShape::create( ggml ) ); + const ggml_type dataType = ggml.type; + const uint32_t cbElement = (uint32_t)ggml_type_size( dataType ); + + const size_t totalBytes = ggml_nbytes( &ggml ); + if( totalBytes > INT_MAX ) + return DISP_E_OVERFLOW; + const uint32_t countElements = (uint32_t)( totalBytes / cbElement ); + + { + const void* const rsi = uploadData ? ggml.data : nullptr; + ID3D11Buffer** ppStagingBuffer = ( usage == eBufferUse::ReadWriteDownload ) ? &stagingBuffer : nullptr; + CHECK( createBuffer( usage, totalBytes, &buffer, rsi, ppStagingBuffer ) ); + } + + DXGI_FORMAT format; + switch( dataType ) + { + case GGML_TYPE_F16: + format = DXGI_FORMAT_R16_FLOAT; + break; + case GGML_TYPE_F32: + format = DXGI_FORMAT_R32_FLOAT; + break; + default: + return E_NOTIMPL; + } + + const bool makeUav = usage == eBufferUse::ReadWrite || usage == eBufferUse::ReadWriteDownload; + return TensorGpuViews::create( buffer, format, totalBytes / cbElement, makeUav ); +} + +HRESULT TensorEx::create( eDataType type, eBufferUse usage, const std::array<uint32_t, 4>& sizeElements ) +{ + TensorGpuViews::clear(); + buffer = nullptr; + stagingBuffer = nullptr; + std::initializer_list<uint32_t> il( sizeElements.data(), sizeElements.data() + 4 ); + + ID3D11Buffer** ppStaging = ( usage == eBufferUse::ReadWriteDownload ) ? 
&stagingBuffer : nullptr; + return Tensor::create( type, il, usage, buffer, nullptr, ppStaging ); +} + +HRESULT TensorEx::getViewSize( uint32_t& cbElement, uint32_t& countElements ) const +{ + ID3D11ShaderResourceView* const srv = *this; + if( nullptr == srv ) + return OLE_E_BLANK; + + D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc; + srv->GetDesc( &viewDesc ); + + cbElement = dxgiSizeof( viewDesc.Format ); + + assert( viewDesc.ViewDimension == D3D_SRV_DIMENSION_BUFFER ); + assert( viewDesc.Buffer.FirstElement == 0 ); + countElements = viewDesc.Buffer.NumElements; + + return S_OK; +} + +HRESULT TensorEx::download( void* rdi, size_t cb ) const +{ + if( nullptr == stagingBuffer ) + return HRESULT_FROM_WIN32( ERROR_GPIO_OPERATION_DENIED ); // The requested operation is not supported for the specified handle. + + ID3D11DeviceContext* const ctx = context(); + ctx->CopyResource( stagingBuffer, buffer ); + + MappedResource mapped; + CHECK( mapped.map( stagingBuffer, true ) ); + memcpy( rdi, mapped.data(), cb ); + + return S_OK; +} + +HRESULT TensorEx::download( void* rdi ) const +{ + uint32_t cbElement, numElements; + CHECK( getViewSize( cbElement, numElements ) ); + + size_t cb = (size_t)cbElement * numElements; + return download( rdi, cb ); +}
// File: Whisper/ML/TensorEx.h
#pragma once
#include "Tensor.h"

namespace DirectCompute
{
	// A tensor which supports dynamic updates from CPU, or downloads from VRAM to system RAM
	class TensorEx : public Tensor
	{
	protected:
		// The buffer with the data of the tensor
		CComPtr<ID3D11Buffer> buffer;
		// Staging buffer for downloads, only created for eBufferUse::ReadWriteDownload usage
		CComPtr<ID3D11Buffer> stagingBuffer;

		// Get size of the element in bytes, and count of elements in the shader resource view
		HRESULT getViewSize( uint32_t& cbElement, uint32_t& countElements ) const;

	public:

		HRESULT create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData );
		HRESULT create( eDataType type, eBufferUse usage, const std::array<uint32_t, 4>& sizeElements );

		// Download the initial `cb` bytes of the data; fails when the tensor doesn't have a staging buffer
		HRESULT download( void* rdi, size_t cb ) const;

		// Download all elements of the view
		HRESULT download( void* rdi ) const;

		// Resize the vector to the count of elements in the view, and download the data there.
		// The size of the vector element is not verified against the size of the tensor element.
		template<class E>
		HRESULT download( std::vector<E>& vec ) const
		{
			uint32_t cbElement, numElements;
			CHECK( getViewSize( cbElement, numElements ) );

			try
			{
				vec.resize( numElements );
			}
			catch( const std::bad_alloc& )
			{
				return E_OUTOFMEMORY;
			}

			return download( vec.data(), (size_t)cbElement * numElements );
		}
	};
}
\ No newline at end of file diff --git a/Whisper/ML/TensorGpuViews.cpp b/Whisper/ML/TensorGpuViews.cpp new file mode 100644 index 0000000..2788153 --- /dev/null +++ b/Whisper/ML/TensorGpuViews.cpp @@ -0,0 +1,23 @@ +#include "stdafx.h" +#include "TensorGpuViews.h" +using namespace DirectCompute; + +HRESULT TensorGpuViews::create( ID3D11Buffer* gpuBuffer, DXGI_FORMAT format, size_t countElements, bool makeUav ) +{ + srv = nullptr; + uav = nullptr; + + if( countElements > UINT_MAX ) + return DISP_E_OVERFLOW; + + CD3D11_SHADER_RESOURCE_VIEW_DESC viewDesc{ D3D11_SRV_DIMENSION_BUFFER, format, 0, (UINT)countElements }; + CHECK( device()->CreateShaderResourceView( gpuBuffer, &viewDesc, &srv ) ); + + if( makeUav ) + { + CD3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc{ D3D11_UAV_DIMENSION_BUFFER, format , 0, (UINT)countElements }; + CHECK( device()->CreateUnorderedAccessView( gpuBuffer, &uavDesc, &uav ) ); + } + + return S_OK; +}
// File: Whisper/ML/TensorGpuViews.h
#pragma once
#include <stdint.h>
#include "../D3D/device.h"

namespace DirectCompute
{
	// A pair of Direct3D views: the shader resource view, and the optional unordered access view
	class TensorGpuViews
	{
	protected:
		CComPtr<ID3D11ShaderResourceView> srv;
		CComPtr<ID3D11UnorderedAccessView> uav;

	public:

		// These operators return raw pointers without incrementing the reference counters; the pointers are null when the views are missing
		operator ID3D11ShaderResourceView* ( ) const { return srv; }
		operator ID3D11UnorderedAccessView* ( ) const { return uav; }

		// Create views of the initial `countElements` elements of the buffer, the UAV is only created when `makeUav` is true
		HRESULT create( ID3D11Buffer* buffer, DXGI_FORMAT format, size_t countElements, bool makeUav );

		// Release both views
		void clear()
		{
			srv = nullptr;
			uav = nullptr;
		}

		// Replace both views; the write view is reset when the second argument is omitted
		void setGpuViews( ID3D11ShaderResourceView* read, ID3D11UnorderedAccessView* write = nullptr )
		{
			srv = read;
			uav = write;
		}
	};
}
\ No newline at end of file diff --git a/Whisper/ML/TensorShape.cpp b/Whisper/ML/TensorShape.cpp new file mode 100644 index 0000000..7de6fb8 --- /dev/null +++ b/Whisper/ML/TensorShape.cpp @@ -0,0 +1,72 @@ +#include "stdafx.h" +#include "TensorShape.h" +#include "../source/ggml.h" +using namespace DirectCompute; + +TensorShape::TensorShape() +{ + setZero(); +} + +TensorShape::TensorShape( const TensorShape& that ) +{ + _mm_storeu_si128( ( __m128i* )ne.data(), that.sizeVec() ); + _mm_storeu_si128( ( __m128i* )nb.data(), that.stridesVec() ); +} + +void TensorShape::operator=( const TensorShape& that ) +{ + _mm_storeu_si128( ( __m128i* )ne.data(), that.sizeVec() ); + _mm_storeu_si128( ( __m128i* )nb.data(), that.stridesVec() ); +} + +HRESULT TensorShape::create( const ggml_tensor& ggml ) +{ + for( size_t i = 0; i < 4; i++ ) + ne[ i ] = (uint32_t)ggml.ne[ i ]; + + const ggml_type dataType = ggml.type; + // Verify a few things + uint32_t cbElement = (uint32_t)ggml_type_size( dataType ); + for( size_t i = 0; i < 4; i++ ) + { + size_t stride = ggml.nb[ i ]; + if( 0 != stride % cbElement ) + return E_INVALIDARG; + size_t nn = stride / cbElement; + if( nn > UINT_MAX ) + return DISP_E_OVERFLOW; + nb[ i ] = (uint32_t)nn; + } + return S_OK; +} + +TensorShape::TensorShape( const ggml_tensor& ggml ) +{ + HRESULT hr = create( ggml ); + if( FAILED( hr ) ) + throw hr; +} + +void TensorShape::setDenseStrides() +{ + nb[ 0 ] = 1; + nb[ 1 ] = ne[ 0 ]; + const uint32_t p01 = ne[ 0 ] * ne[ 1 ]; + nb[ 2 ] = p01; + nb[ 3 ] = p01 * ne[ 2 ]; +} + +bool DirectCompute::canMulMat( const TensorShape& t0, const TensorShape& t1 ) +{ + /* + return + ( t0.ne[ 0 ] == t1.ne[ 0 ] ) && + ( t0.ne[ 2 ] == t1.ne[ 2 ] ) && + ( t0.ne[ 3 ] == t1.ne[ 3 ] ); */ + __m128i a = t0.sizeVec(); + __m128i b = t1.sizeVec(); + __m128i xx = _mm_xor_si128( a, b ); + xx = _mm_shuffle_epi32( xx, _MM_SHUFFLE( 3, 2, 0, 0 ) ); + return (bool)_mm_testz_si128( xx, xx ); +}
// File: Whisper/ML/TensorShape.h
#pragma once
#include <stdint.h>
#include <array>
#include <smmintrin.h>

struct ggml_tensor;
using HRESULT = long;

namespace DirectCompute
{
	// This POD structure describes the shape of a tensor.
	// It’s used for both GPU tensors in VRAM, and tensors in system memory used by the Hybrid model.
	struct TensorShape
	{
		// Count of elements, up to 4 coordinates
		// The unused coordinates are set to 1
		std::array<uint32_t, 4> ne;

		// Strides of the tensor
		// For a dense row-major tensor, these numbers are [ 1, ne[0], ne[0]*ne[1], ne[0]*ne[1]*ne[2] ]
		// Note that unlike GGML code, these numbers are expressed in elements not bytes, but the meaning is the same
		// GPU matrices reshaped into panels are keeping different values here: [ 0, panelSize, panelSize * panelsCount, panelSize * panelsCount * ne[ 2 ] ]
		std::array<uint32_t, 4> nb;

		// The default constructor sets all fields to zero
		TensorShape();
		TensorShape( const TensorShape& that );
		void operator=( const TensorShape& that );
		// Copy the size from the GGML tensor, and convert the strides from bytes to elements
		HRESULT create( const ggml_tensor& ggml );
		// Same as create(), but throws the HRESULT on failures
		TensorShape( const ggml_tensor& ggml );

		// Load all 4 sizes into a vector
		__m128i sizeVec() const
		{
			return load( ne );
		}
		// Load all 4 strides into a vector
		__m128i stridesVec() const
		{
			return load( nb );
		}

		uint32_t countRows() const
		{
			return ne[ 1 ] * ne[ 2 ] * ne[ 3 ];
		}

		uint32_t countElements() const
		{
			// return ne[ 0 ] * countRows();
			const __m128i a = sizeVec();
			// Shift by 1 lane, for ne[ 1 ] and ne[ 3 ] in the even lanes
			const __m128i b = _mm_srli_si128( a, 4 );
			// Two 64-bit products, [ ne[ 0 ] * ne[ 1 ], ne[ 2 ] * ne[ 3 ] ]
			const __m128i p2 = _mm_mul_epu32( a, b );
			uint64_t res = (uint64_t)_mm_extract_epi64( p2, 1 );
			res *= (uint64_t)_mm_cvtsi128_si64( p2 );
			// The result is expected to fit in 32 bits; only verified in debug builds
			assert( 0 == ( res >> 32 ) );
			return (uint32_t)res;
		}

		// Compute strides from sizes, assuming dense row-major memory layout of the tensor
		void setDenseStrides();

		bool isMatrix() const
		{
			// return ne[ 2 ] == 1 && ne[ 3 ] == 1;
			// Compare both numbers with a single 64-bit load
			const uint64_t num = *(const uint64_t*)&ne[ 2 ];
			return num == 0x100000001ull;
		}
		bool isVector() const
		{
			return 1 == ne[ 1 ] && isMatrix();
		}

		// True if this tensor is dense and row-major
		bool isContinuous() const
		{
			/* return 1 == nb[ 0 ] &&
				nb[ 1 ] == nb[ 0 ] * ne[ 0 ] &&
				nb[ 2 ] == nb[ 1 ] * ne[ 1 ] &&
				nb[ 3 ] == nb[ 2 ] * ne[ 2 ]; */

			const __m128i nbv = stridesVec();
			const __m128i nev = sizeVec();
			__m128i tmp = _mm_mullo_epi32( nbv, nev );	// Vertical product of int32 lanes
			tmp = _mm_shuffle_epi32( tmp, _MM_SHUFFLE( 2, 1, 0, 0 ) );	// Shift left by 1 int32 lane
			tmp = _mm_insert_epi32( tmp, 1, 0 );	// Reset X lane to 1
			return vectorEqual( tmp, nbv );
		}

		// Reset all fields to zero
		void setZero()
		{
			const __m128i z = _mm_setzero_si128();
			_mm_storeu_si128( ( __m128i* )ne.data(), z );
			_mm_storeu_si128( ( __m128i* )nb.data(), z );
		}
	};

	// True when two tensors have equal sizes in all 4 dimensions; the strides are ignored
	inline bool isSameShape( const TensorShape& t0, const TensorShape& t1 )
	{
		__m128i a = t0.sizeVec();
		__m128i b = t1.sizeVec();
		return vectorEqual( a, b );
	}

	// True when two tensors have equal sizes in all 4 dimensions, and equal strides i.e. equal VRAM layout too
	inline bool isSameShapeAndLayout( const TensorShape& t0, const TensorShape& t1 )
	{
		__m128i a, b, x;
		a = t0.sizeVec();
		b = t1.sizeVec();
		x = _mm_xor_si128( a, b );

		a = t0.stridesVec();
		b = t1.stridesVec();
		// Accumulate the differences of both sizes and strides; the result is zero when nothing's different
		x = _mm_or_si128( x, _mm_xor_si128( a, b ) );
		return (bool)_mm_testz_si128( x, x );
	}

	// True when we can multiply two tensors of the provided shapes
	bool canMulMat( const TensorShape& t0, const TensorShape& t1 );
}
\ No newline at end of file diff --git a/Whisper/ML/TensorsArena.cpp b/Whisper/ML/TensorsArena.cpp new file mode 100644 index 0000000..f1a4bed --- /dev/null +++ b/Whisper/ML/TensorsArena.cpp @@ -0,0 +1,117 @@ +#include "stdafx.h" +#include "TensorsArena.h" +#include "../D3D/createBuffer.h" +#include <bit> + +uint32_t DirectCompute::defaultNewCapacity( uint32_t current, uint32_t requested ) +{ + if( 0 == current ) + { + // When the current capacity is 0 this means it's the first resize for the pooled tensor + // Create tensor of the exact requested size, as most tensors on these pools are never actually resized. + return requested; + } + else + { + // Implement some reasonable tactics to grow an old tensor + const uint32_t res = std::max( 1024u, std::bit_ceil( requested ) ); + assert( res >= requested ); + return res; + } +} + +using namespace DirectCompute; + +TensorsArena::ArenaImpl::ArenaImpl( eDataType dataType, const sArenaConfig& config ) : + type( dataType ), + pfnNewCap( nullptr != config.pfnCapInner ? 
config.pfnCapInner : &defaultNewCapacity ) +{ + pool.reserve( config.initialCapOuter ); +} + +Tensor PooledTensor::tensor( eDataType type, const std::array<uint32_t, 4>& ne, pfnNewCapacity pfnNewCap ) +{ + const uint32_t p1 = ne[ 0 ] * ne[ 1 ]; + const uint32_t p2 = ne[ 2 ] * ne[ 3 ]; + const uint32_t count = p1 * p2; + + if( count > capacity ) + { + views.clear(); + const uint32_t newCap = pfnNewCap( capacity, count ); + assert( newCap >= count ); + + const size_t cb = elementSize( type ) * newCap; + CComPtr<ID3D11Buffer> buffer; + check( createBuffer( eBufferUse::ReadWrite, cb, &buffer, nullptr, nullptr ) ); + check( views.create( buffer, viewFormat( type ), newCap, true ) ); + capacity = newCap; + } + + TensorShape shape; + shape.ne = ne; + shape.setDenseStrides(); + Tensor res{ shape, views }; + res.dbgSetType( type ); + return res; +} + +Tensor TensorsArena::ArenaImpl::tensor( const std::array<uint32_t, 4>& ne ) +{ + PooledTensor* res; + if( index >= pool.size() ) + { + assert( index == pool.size() ); + res = &pool.emplace_back(); + } + else + res = &pool[ index ]; + + index++; + return res->tensor( type, ne, pfnNewCap ); +} + +TensorsArena::TensorsArena( const sArenaConfigs& configs ) : + arenas{ ArenaImpl{ eDataType::FP16, configs.fp16 }, ArenaImpl{ eDataType::FP32, configs.fp32 } } +{ + static_assert( 0 == (uint8_t)eDataType::FP16 ); + static_assert( 1 == (uint8_t)eDataType::FP32 ); +} + +Tensor TensorsArena::tensor( eDataType type, const std::array<uint32_t, 4>& ne ) +{ + ArenaImpl& arena = arenas[ (uint8_t)type ]; + return arena.tensor( ne ); +} + +void TensorsArena::reset() +{ + for( ArenaImpl& a : arenas ) + a.reset(); +} + +void TensorsArena::clear() +{ + for( ArenaImpl& a : arenas ) + a.clear(); +} + +__m128i TensorsArena::ArenaImpl::getMemoryUse() const +{ + const size_t cbElement = elementSize( type ); + size_t countElts = 0; + for( const auto& t : pool ) + countElts += t.getCapacity(); + + const size_t cbVideo = cbElement * countElts; + const 
size_t cbSystem = vectorMemoryUse( pool ); + return setr_size( cbSystem, cbVideo ); +} + +__m128i TensorsArena::getMemoryUse() const +{ + __m128i res = _mm_setzero_si128(); + for( const auto& a : arenas ) + res = _mm_add_epi64( res, a.getMemoryUse() ); + return res; +}
// File: Whisper/ML/TensorsArena.h
#pragma once
#include "Tensor.h"

namespace DirectCompute
{
	// Function pointer to compute new capacity of a pooled tensor, in elements
	using pfnNewCapacity = uint32_t( * )( uint32_t current, uint32_t requested );

	// The default implementation of the above, used when sArenaConfig.pfnCapInner is null
	uint32_t defaultNewCapacity( uint32_t current, uint32_t requested );

	// GPU views of a reusable buffer, along with the capacity of that buffer expressed in elements
	class PooledTensor
	{
		TensorGpuViews views;
		uint32_t capacity = 0;
	public:
		// Get a dense tensor of the requested size which uses these views; the buffer is re-created when the capacity is insufficient
		Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne, pfnNewCapacity pfnNewCap );
		size_t getCapacity() const { return capacity; }
	};

	__interface iTensorArena
	{
		Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne );
		void reset();
	};

	// An arena with two pools of tensors, one for FP16 and another one for FP32 data
	class TensorsArena: public iTensorArena
	{
	public:
		struct sArenaConfig
		{
			// Optional function to compute new capacity of the tensors in the pool
			pfnNewCapacity pfnCapInner;
			// Count of pooled tensors to reserve in the vector
			size_t initialCapOuter;
		};

		struct sArenaConfigs
		{
			sArenaConfig fp16, fp32;
		};

		TensorsArena( const sArenaConfigs& configs );

		Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne ) override final;
		// Start reusing the pooled tensors from the start; the buffers are retained
		void reset() override final;

		// Drop all pooled tensors
		void clear();
		__m128i getMemoryUse() const;

	private:

		// A pool of tensors of the same data type
		struct ArenaImpl
		{
			ArenaImpl( eDataType dataType, const sArenaConfig& config );

			void reset()
			{
				index = 0;
			}

			void clear()
			{
				index = 0;
				pool.clear();
			}

			// Get the next tensor from the pool, appending a new entry when the pool is exhausted
			Tensor tensor( const std::array<uint32_t, 4>& ne );
			__m128i getMemoryUse() const;

		private:

			const eDataType type;
			const pfnNewCapacity pfnNewCap;

			std::vector<PooledTensor> pool;
			// Index of the next pooled tensor to use
			size_t index = 0;
		};

		static constexpr size_t countTypes = 2;
		// The array is indexed with eDataType values
		std::array<ArenaImpl, countTypes> arenas;
	};
}
// File: Whisper/ML/mlStartup.cpp
#include "stdafx.h"
#include "mlStartup.h"
#include "../D3D/startup.h"
#include "LookupTables.h"

namespace
{
	// The only instance of the lookup tables; the rest of the code gets read-only access through DirectCompute::lookupTables
	static DirectCompute::LookupTables s_tables;
}

namespace DirectCompute
{
	const LookupTables& lookupTables = s_tables;

	// Initialize Direct3D first, then create the lookup tables.
	// NOTE(review): when s_tables.create() fails, the function returns without undoing d3dStartup().
	// Confirm whether callers are expected to call mlShutdown() after a failed startup.
	HRESULT mlStartup()
	{
		CHECK( d3dStartup() );
		CHECK( s_tables.create() );
		return S_OK;
	}

	// Release resources in the reverse order of mlStartup()
	void mlShutdown()
	{
		s_tables.clear();
		d3dShutdown();
	}
}
// File: Whisper/ML/mlStartup.h
#pragma once
// Same underlying type as in the Windows SDK; presumably declared here so this header can be included without windows.h
using HRESULT = long;

namespace DirectCompute
{
	// Initialize Direct3D and the lookup tables; returns S_OK on success, or the first failed status
	HRESULT mlStartup();
	// Release the lookup tables, then shut down Direct3D
	void mlShutdown();
}
// File: Whisper/ML/reshapedMultiply.h
#pragma once
#include <stdint.h>

namespace DirectCompute
{
	namespace ReshapedMultiply
	{
		// Tile size constant for the reshaped matrix multiplication.
		// NOTE(review): presumably this has to match the value compiled into the corresponding compute shaders - confirm before changing
		constexpr uint32_t TILE_SIZE = 32;
	}
}
\ No newline at end of file diff --git a/Whisper/ML/tensorOpsTests.cpp b/Whisper/ML/tensorOpsTests.cpp new file mode 100644 index 0000000..adc020e --- /dev/null +++ b/Whisper/ML/tensorOpsTests.cpp @@ -0,0 +1,183 @@ +#include "stdafx.h" +#include "tensorOpsTests.h" +#include "MlContext.h" +#include "TensorEx.h" +#include "../D3D/shaders.h" +#include "../D3D/Binder.h" +#include "testUtils.h" +#include "../Whisper/WhisperContext.h" + +void DirectCompute::testMulMat( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst, const void* tempBuffer ) +{ + return; + CaptureRaii capture; + const size_t nb00 = src0->nb[ 0 ]; + const size_t nb01 = src0->nb[ 1 ]; + + if( src0->type != GGML_TYPE_F16 ) + return; // TODO + + if( nb01 < nb00 ) + return; // TODO + + WhisperContext& ctx = WhisperContext::current(); + + Tensor arg0, arg1; + check( arg0.create( *src0, eBufferUse::Immutable, true ) ); + check( arg1.create( *src1, eBufferUse::Immutable, true ) ); + TensorEx res; + check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) ); + + ctx.mulMat( arg0, arg1, res ); + + std::vector<float> tv; + check( res.download( tv ) ); + + const size_t len = tv.size(); + computeDiff( tv.data(), (const float*)dst->data, len ).print( "testMulMat-product" ); + +#if 0 + dbgWriteBinaryFile( L"product-orig.bin", dst->data, len * 4 ); + dbgWriteBinaryFile( L"product-gpu.bin", tv.data(), len * 4 ); + __debugbreak(); +#endif +} + +#if 0 +void DirectCompute::testMulMatReshape( const ggml_tensor* src1, const void* tempBuffer ) +{ + Tensor src; + check( src.create( *src1, eBufferUse::Immutable, true ) ); + + const size_t ne10 = (uint32_t)src1->ne[ 0 ]; + const size_t ne11 = (uint32_t)src1->ne[ 1 ]; + const size_t ne12 = (uint32_t)src1->ne[ 2 ]; + const size_t ne13 = (uint32_t)src1->ne[ 3 ]; + if( 1 != ne13 ) + throw E_UNEXPECTED; + const size_t tempLength = ne10 * ne11 * ne12 * ne13; + + Context& ctx = Context::current(); + const ReadWriteViews& temp = ctx.temp.fp16( tempLength 
); + + { + Binder bind; + ctx.cb.bind(); + + bindShader( eComputeShader::mulMatDotReshape ); + + ctx.cb.update( src ); + bind.bind( src, temp ); + context()->Dispatch( (UINT)ne11, (UINT)ne12, 1 ); + } + + std::vector<uint16_t> reshaped; + check( downloadBuffer( temp, reshaped ) ); + computeDiff( reshaped.data(), (const uint16_t*)tempBuffer, reshaped.size() ).print( "testMulMatReshape" ); + +#if 0 + dbgWriteBinaryFile( L"fp32.bin", src1->data, ggml_nbytes( src1 ) ); + dbgWriteBinaryFile( L"fp16-cpu.bin", tempBuffer, reshaped.size() * 2 ); + dbgWriteBinaryFile( L"fp16-gpu.bin", reshaped.data(), reshaped.size() * 2 ); + __debugbreak(); +#endif +} +#endif + +void DirectCompute::computeMulMat( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst ) +{ + CaptureRaii capture; + const size_t nb00 = src0->nb[ 0 ]; + const size_t nb01 = src0->nb[ 1 ]; + + if( src0->type != GGML_TYPE_F16 ) + throw E_INVALIDARG; + if( nb01 < nb00 ) + throw E_INVALIDARG; + + WhisperContext& ctx = WhisperContext::current(); + + Tensor arg0, arg1; + check( arg0.create( *src0, eBufferUse::Immutable, true ) ); + check( arg1.create( *src1, eBufferUse::Immutable, true ) ); + TensorEx res; + check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) ); + + ctx.mulMat( arg0, arg1, res ); + + check( res.download( dst->data ) ); +} + +void DirectCompute::testFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, const ggml_tensor* dst ) +{ + CaptureRaii capture; + + Tensor Q, K, V; + TensorEx res; + check( Q.create( *q, eBufferUse::Immutable, true ) ); + check( K.create( *k, eBufferUse::Immutable, true ) ); + check( V.create( *v, eBufferUse::Immutable, true ) ); + check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) ); + + WhisperContext& ctx = WhisperContext::current(); + ctx.flashAttention( Q, K, V, res, masked ); + + std::vector<float> tv; + check( res.download( tv ) ); + + const size_t len = tv.size(); + computeDiff( tv.data(), 
(const float*)dst->data, len ).print( "testFlashAttention" ); +} + +void DirectCompute::computeFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, ggml_tensor* dst ) +{ + CaptureRaii capture; + + Tensor Q, K, V; + TensorEx res; + check( Q.create( *q, eBufferUse::Immutable, true ) ); + check( K.create( *k, eBufferUse::Immutable, true ) ); + check( V.create( *v, eBufferUse::Immutable, true ) ); + check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) ); + + WhisperContext& ctx = WhisperContext::current(); + ctx.flashAttention( Q, K, V, res, masked ); + + check( res.download( dst->data ) ); +} + +void DirectCompute::testConvolution( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst ) +{ + CaptureRaii capture; + + Tensor arg0, arg1; + check( arg0.create( *src0, eBufferUse::Immutable, true ) ); + check( arg1.create( *src1, eBufferUse::Immutable, true ) ); + TensorEx res; + check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) ); + + WhisperContext& ctx = WhisperContext::current(); + ctx.convolution( arg0, arg1, res ); + + std::vector<float> tv; + check( res.download( tv ) ); + + const size_t len = tv.size(); + computeDiff( tv.data(), (const float*)dst->data, len ).print( "testConvolution" ); +} + +void DirectCompute::computeConvolution( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst ) +{ + CaptureRaii capture; + + Tensor arg0, arg1; + check( arg0.create( *src0, eBufferUse::Immutable, true ) ); + check( arg1.create( *src1, eBufferUse::Immutable, true ) ); + TensorEx res; + check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) ); + + WhisperContext& ctx = WhisperContext::current(); + ctx.convolution( arg0, arg1, res ); + + res.download( dst->data ); +}
// File: Whisper/ML/tensorOpsTests.h
#pragma once
#include "../source/ggml.h"

// Debug helpers which run individual tensor operations on the GPU.
// test* functions take a const `dst` with the reference result, and print the difference between it and the GPU output.
// compute* functions take a mutable `dst`, and overwrite its data with the GPU output.
namespace DirectCompute
{
	// void testMulMatReshape( const ggml_tensor* src1, const void* tempBuffer );
	void testMulMat( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst, const void* tempBuffer );
	void computeMulMat( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst );

	void testFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, const ggml_tensor* dst );
	void computeFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, ggml_tensor* dst );

	void testConvolution( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst );
	void computeConvolution( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst );
}
\ No newline at end of file diff --git a/Whisper/ML/testUtils.cpp b/Whisper/ML/testUtils.cpp new file mode 100644 index 0000000..8a20e49 --- /dev/null +++ b/Whisper/ML/testUtils.cpp @@ -0,0 +1,334 @@ +#include "stdafx.h" +#include "testUtils.h" +#include <immintrin.h> +#include <atlfile.h> +#include <atlpath.h> + +namespace +{ + using DirectCompute::sTensorDiff; + + __forceinline __m256 load( const float* rsi ) + { + return _mm256_loadu_ps( rsi ); + } + + __forceinline __m256 load( const uint16_t* rsi ) + { + const __m128i iv = _mm_load_si128( ( const __m128i* )rsi ); + return _mm256_cvtph_ps( iv ); + } + + __forceinline void loadPartial( const uint16_t* x, const uint16_t* y, size_t count, __m256& fx, __m256& fy ) + { + __m128i ix, iy; + switch( count ) + { + case 1: // load 2 bytes + ix = _mm_cvtsi32_si128( *x ); + iy = _mm_cvtsi32_si128( *y ); + break; + case 2: // load 4 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + iy = _mm_cvtsi32_si128( *(const int*)y ); + break; + case 3: // load 6 bytes + ix = _mm_cvtsi32_si128( *(const int*)x ); + iy = _mm_cvtsi32_si128( *(const int*)y ); + ix = _mm_insert_epi16( ix, x[ 2 ], 2 ); + iy = _mm_insert_epi16( iy, y[ 2 ], 2 ); + break; + case 4: // load 8 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + break; + case 5: // load 10 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + ix = _mm_insert_epi16( ix, x[ 4 ], 4 ); + iy = _mm_insert_epi16( iy, y[ 4 ], 4 ); + break; + case 6: // load 12 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 ); + break; + case 7: // load 14 bytes + ix = _mm_cvtsi64_si128( *(const int64_t*)x ); + iy = _mm_cvtsi64_si128( *(const int64_t*)y ); + ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 ); + iy = _mm_insert_epi32( 
iy, *(const int*)( y + 4 ), 2 ); + ix = _mm_insert_epi16( ix, x[ 6 ], 6 ); + iy = _mm_insert_epi16( iy, y[ 6 ], 6 ); + break; + default: + fx = fy = _mm256_setzero_ps(); + return; + } + + fx = _mm256_cvtph_ps( ix ); + fy = _mm256_cvtph_ps( iy ); + } + + inline __m128 loadFloat2( const float* rsi ) + { + return _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) ); + } + inline __m128 loadFloat3( const float* rsi ) + { + __m128 f = loadFloat2( rsi ); + f = _mm_insert_ps( f, _mm_load_ss( rsi + 2 ), 0x20 ); + return f; + } + __forceinline void loadPartial( const float* x, const float* y, size_t count, __m256& fx, __m256& fy ) + { + __m128 low1, high1; + __m128 low2, high2; + high1 = high2 = _mm_setzero_ps(); + switch( count ) + { + case 1: + low1 = _mm_load_ss( x ); + low2 = _mm_load_ss( y ); + break; + case 2: + low1 = loadFloat2( x ); + low2 = loadFloat2( y ); + break; + case 3: + low1 = loadFloat3( x ); + low2 = loadFloat3( y ); + break; + case 4: + low1 = _mm_loadu_ps( x ); + low2 = _mm_loadu_ps( y ); + break; + case 5: + low1 = _mm_loadu_ps( x ); + low2 = _mm_loadu_ps( y ); + high1 = _mm_load_ss( x + 4 ); + high2 = _mm_load_ss( y + 4 ); + break; + case 6: + low1 = _mm_loadu_ps( x ); + low2 = _mm_loadu_ps( y ); + high1 = loadFloat2( x + 4 ); + high2 = loadFloat2( y + 4 ); + break; + case 7: // load 14 bytes + low1 = _mm_loadu_ps( x ); + low2 = _mm_loadu_ps( y ); + high1 = loadFloat3( x + 4 ); + high2 = loadFloat3( y + 4 ); + break; + default: + fx = fy = _mm256_setzero_ps(); + return; + } + + fx = _mm256_setr_m128( low1, high1 ); + fy = _mm256_setr_m128( low2, high2 ); + } + + __forceinline float horizontalMaximum( __m256 v ) + { + __m128 s = _mm256_extractf128_ps( v, 1 ); + s = _mm_max_ps( s, _mm256_castps256_ps128( v ) ); + s = _mm_max_ps( s, _mm_movehl_ps( s, s ) ); + s = _mm_max_ss( s, _mm_movehdup_ps( s ) ); + return _mm_cvtss_f32( s ); + } + + __forceinline double horizontalSum( __m256 v ) + { + __m256d d = _mm256_cvtps_pd( _mm256_extractf128_ps( v, 1 ) ); + d 
= _mm256_add_pd( d, _mm256_cvtps_pd( _mm256_castps256_ps128( v ) ) ); + + __m128d s = _mm256_extractf128_pd( d, 1 ); + s = _mm_add_pd( s, _mm256_castpd256_pd128( d ) ); + s = _mm_add_sd( s, _mm_unpackhi_pd( s, s ) ); + return _mm_cvtsd_f64( s ); + } + + __m256 maskInfNan( __m256 diff, __m256 a, __m256 b ) + { + __m256i ai = _mm256_castps_si256( a ); + __m256i bi = _mm256_castps_si256( b ); + __m256i eqi = _mm256_cmpeq_epi32( ai, bi ); + __m256 eq = _mm256_castsi256_ps( eqi ); + return _mm256_andnot_ps( eq, diff ); + } + + class DiffAcc + { + __m256 maxAbs = _mm256_setzero_ps(); + __m256 sumSquares = _mm256_setzero_ps(); + + public: + + __forceinline void add( __m256 a, __m256 b ) + { + const __m256 neg0 = _mm256_set1_ps( -0.0f ); + __m256 diff = _mm256_sub_ps( b, a ); + diff = maskInfNan( diff, a, b ); + sumSquares = _mm256_fmadd_ps( diff, diff, sumSquares ); + const __m256 absDiff = _mm256_andnot_ps( neg0, diff ); + maxAbs = _mm256_max_ps( maxAbs, absDiff ); + } + + __forceinline sTensorDiff reduce( size_t count ) + { + sTensorDiff res; + res.maxAbsDiff = horizontalMaximum( maxAbs ); + res.avgDiffSquared = (float)( horizontalSum( sumSquares ) / (double)(int64_t)count ); + res.length = count; + return res; + } + }; + + template<class E> + static sTensorDiff __declspec( noinline ) diffVectors( const E* a, const E* b, size_t length ) + { + // const E* const aEnd = a + length; + const E* const aEndAligned = a + ( length / 8 ) * 8; + const size_t remainder = length % 8; + + DiffAcc acc; + for( ; a < aEndAligned; a += 8, b += 8 ) + acc.add( load( a ), load( b ) ); + + if( remainder != 0 ) + { + __m256 va, vb; + loadPartial( a, b, remainder, va, vb ); + acc.add( va, vb ); + } + + return acc.reduce( length ); + } +} + +sTensorDiff DirectCompute::computeDiff( const float* a, const float* b, size_t length ) +{ + return diffVectors( a, b, length ); +} + +sTensorDiff DirectCompute::computeDiff( const uint16_t* a, const uint16_t* b, size_t length ) +{ + return diffVectors( a, 
b, length ); +} + +void DirectCompute::sTensorDiff::print( const char* what ) const +{ + logDebug( u8"%s: length %zu, maxAbsDiff = %g, avgDiffSquared = %g", what, length, maxAbsDiff, avgDiffSquared ); +} +void DirectCompute::sTensorDiff::print() const +{ + logDebug( u8"%zu elements, maxAbsDiff = %g, avgDiffSquared = %g", length, maxAbsDiff, avgDiffSquared ); +} + +HRESULT DirectCompute::dbgWriteBinaryFile( LPCTSTR fileName, const void* rsi, size_t cb ) +{ + CPath path; + path.m_strPath = LR"(C:\Temp\2remove\Whisper)"; + path.Append( fileName ); + + CAtlFile file; + CHECK( file.Create( path, GENERIC_WRITE, 0, CREATE_ALWAYS ) ); + CHECK( file.Write( rsi, (DWORD)cb ) ); + CHECK( file.Flush() ); + return S_OK; +} + +#include "Tensor.h" + +sTensorDiff DirectCompute::computeDiff( const Tensor& a, const Tensor& b ) +{ + assert( isSameShapeAndLayout( a, b ) ); + const eDataType dt = a.getType(); + assert( dt == b.getType() ); + switch( dt ) + { + case eDataType::FP32: + { + std::vector<float> v1, v2; + a.download( v1 ); + b.download( v2 ); + assert( v1.size() == v2.size() ); +#if 0 + const size_t firstZero = std::find( v2.begin(), v2.end(), 0.0f ) - v2.begin(); + + std::vector<float> delta; + delta.resize( v1.size() ); + for( size_t i = 0; i < v1.size(); i++ ) + delta[ i ] = std::abs( v1[ i ] - v2[ i ] ); + const size_t maxIndex = std::max_element( delta.begin(), delta.end() ) - delta.begin(); +#endif + return computeDiff( v1.data(), v2.data(), v1.size() ); + } + } + throw E_NOTIMPL; +} + +using namespace DirectCompute; + +void PrintUniqueTensorSizes::printImpl( const std::array<uint32_t, 8>& a ) +{ + auto pair = set.emplace( a ); + if( !pair.second ) + return; // was already there + + const __m128i rhs = _mm_loadu_si128( ( const __m128i* ) ( &a[ 4 ] ) ); + + if( _mm_testz_si128( rhs, rhs ) ) + { + logDebug( u8"%s: [ %i, %i, %i, %i ]", what, + a[ 0 ], a[ 1 ], a[ 2 ], a[ 3 ] ); + } + else + { + logDebug( u8"%s: [ %i, %i, %i, %i ], [ %i, %i, %i, %i ]", what, + a[ 0 ], a[ 1 
], a[ 2 ], a[ 3 ], a[ 4 ], a[ 5 ], a[ 6 ], a[ 7 ] ); + } +} + +void PrintUniqueTensorSizes::print( const Tensor& lhs, const Tensor& rhs ) +{ + std::array<uint32_t, 8> arr; + __m128i* const rdi = ( __m128i* )arr.data(); + _mm_storeu_si128( rdi, lhs.sizeVec() ); + _mm_storeu_si128( rdi + 1, rhs.sizeVec() ); + + printImpl( arr ); +} + +void PrintUniqueTensorSizes::print( const int* lhs, const int* rhs ) +{ + std::array<uint32_t, 8> arr; + __m128i* const rdi = ( __m128i* )arr.data(); + _mm_storeu_si128( rdi, load16( lhs ) ); + _mm_storeu_si128( rdi + 1, load16( rhs ) ); + + printImpl( arr ); +} + +void PrintUniqueTensorSizes::print( const Tensor& lhs ) +{ + std::array<uint32_t, 8> arr; + __m128i* const rdi = ( __m128i* )arr.data(); + _mm_storeu_si128( rdi, lhs.sizeVec() ); + _mm_storeu_si128( rdi + 1, _mm_setzero_si128() ); + + printImpl( arr ); +} + +#include "testUtilsC.h" + +void printUniqueTensorSize( const char* name, const int* lhs, const int* rhs ) +{ + using TS = DirectCompute::PrintUniqueTensorSizes; + static std::unordered_map<std::string, TS> map; + TS& ts = map.try_emplace( name, name ).first->second; + ts.print( lhs, rhs ); +}
// File: Whisper/ML/testUtils.h
#pragma once
#include "../D3D/downloadBuffer.h"
#include "../D3D/RenderDoc/renderDoc.h"
#include <unordered_set>
#include <functional>

// Fun fact: this code was written by ChatGPT
// Hash function for arrays of 8 integers, required by the std::unordered_set in PrintUniqueTensorSizes.
// NOTE(review): this specializes std::hash for a type which does not involve any program-defined type,
// which the C++ standard does not permit; a custom hasher passed to the unordered_set would avoid that.
namespace std
{
	template<>
	struct hash<array<uint32_t, 8>>
	{
		size_t operator()( const array<uint32_t, 8>& arr ) const
		{
			size_t result = 0;
			for( uint32_t element : arr )
				result = ( result * 31 ) ^ element;
			return result;
		}
	};
}

namespace DirectCompute
{
	// Summary of the difference between two vectors of the same length
	struct sTensorDiff
	{
		// maximum( absolute( a - b ) )
		float maxAbsDiff;
		// average( ( a - b )^2 )
		float avgDiffSquared;
		// Count of the compared elements
		size_t length;

		// Write the numbers to the debug log
		void print() const;
		// Same as above, and the message starts with the supplied string
		void print( const char* what ) const;
	};

	// Compute difference between 2 FP32 vectors
	sTensorDiff computeDiff( const float* a, const float* b, size_t length );

	// Compute difference between 2 FP16 vectors
	sTensorDiff computeDiff( const uint16_t* a, const uint16_t* b, size_t length );

	class Tensor;
	// Download both tensors from the GPU, and compute the difference; only implemented for FP32 tensors, throws E_NOTIMPL otherwise
	sTensorDiff computeDiff( const Tensor& a, const Tensor& b );

	// Write a binary blob into a file in the hardcoded debug folder
	HRESULT dbgWriteBinaryFile( LPCTSTR fileName, const void* rsi, size_t cb );

	// Print unique sizes of the two tensors
	class PrintUniqueTensorSizes
	{
		// Combinations of sizes which were already printed
		std::unordered_set<std::array<uint32_t, 8>> set;
		// Prefix of the log messages; the pointer is stored without copying the string, so it must stay valid for the lifetime of this object
		const char* const what;
		void printImpl( const std::array<uint32_t, 8>& a );

	public:
		PrintUniqueTensorSizes( const char* w ) : what( w ) { }

		void print( const Tensor& lhs, const Tensor& rhs );
		void print( const Tensor& lhs );
		// Both pointers are loaded with load16; presumably each one must point to 4 integers - confirm against that helper
		void print( const int* lhs, const int* rhs );
	};
}
// File: Whisper/ML/testUtilsC.h
#pragma once

// C-compatible entry point to the unique tensor sizes printer implemented in testUtils.cpp
#ifdef __cplusplus
extern "C"
{
#endif
	// Print the sizes of the two tensors to the debug log, once for each unique combination of name and sizes
	void printUniqueTensorSize( const char* name, const int* lhs, const int* rhs );
#ifdef __cplusplus
}
#endif
\ No newline at end of file diff --git a/Whisper/Readme.txt b/Whisper/Readme.txt new file mode 100644 index 0000000..fc0e9b8 --- /dev/null +++ b/Whisper/Readme.txt @@ -0,0 +1,9 @@ +This C++ project builds a DLL which actually does the heavy lifting of this project. + +It implements the ML model, handles multimedia files with Media Foundation, captures audio (also with MF), does voice activity detection (custom code running on CPU), and a few smaller things. + +The code requires C++20, and has only been tested with Visual Studio 2022. + +When running the pure GPGPU model, the DLL requires the SSE 4.1 instruction set. + +When running a hybrid model, the DLL requires the AVX1, FMA3, F16C, and BMI1 instruction set extensions.
\ No newline at end of file diff --git a/Whisper/Resource.rc b/Whisper/Resource.rc Binary files differnew file mode 100644 index 0000000..cb93b99 --- /dev/null +++ b/Whisper/Resource.rc diff --git a/Whisper/Utils/CpuProfiler.cpp b/Whisper/Utils/CpuProfiler.cpp new file mode 100644 index 0000000..6161e95 --- /dev/null +++ b/Whisper/Utils/CpuProfiler.cpp @@ -0,0 +1,65 @@ +#include "stdafx.h" +#include "CpuProfiler.h" + +namespace +{ + using namespace Whisper; + + inline int64_t qpcNow() + { + int64_t res; + QueryPerformanceCounter( (LARGE_INTEGER*)&res ); + return res; + } + + class CpuTimescale + { + uint64_t frequency = 0; + const int64_t tscStart; + const int64_t qpcStart; + + uint64_t computeTscFrequency(); + + public: + + CpuTimescale() : + tscStart( tscNow() ), + qpcStart( qpcNow() ) + { } + + inline uint64_t computeTicks( uint64_t tsc ) + { + uint64_t freq = frequency; + if( freq == 0 ) + freq = computeTscFrequency(); + + return makeTime( tsc, freq ); + } + }; + + uint64_t __declspec( noinline ) CpuTimescale::computeTscFrequency() + { + int64_t tsc = tscNow(); + int64_t qpc = qpcNow(); + tsc -= tscStart; + qpc -= qpcStart; + + uint64_t qpcFreq; + QueryPerformanceFrequency( (LARGE_INTEGER*)&qpcFreq ); + + // Seconds = qpc / qpcFreq + // ticks per second = tsc / seconds = tsc * qpcFreq / qpc + uint64_t res = ( (uint64_t)tsc * qpcFreq + ( (uint64_t)qpc / 2 ) - 1 ) / (uint64_t)qpc; + frequency = res; + const double GHz = (double)(int64_t)res * 1.0E-9; + logDebug( u8"Computed CPU base frequency: %g GHz", GHz ); + return res; + } + + static CpuTimescale timescale; +} + +uint64_t Whisper::ticksFromTsc( uint64_t tscDiff ) +{ + return timescale.computeTicks( tscDiff ); +}
// File: Whisper/Utils/CpuProfiler.h
#pragma once

namespace Whisper
{
	// Get current time in CPU clock
	// More specifically, each CPU core has a timestamp counter which runs at the CPU's base frequency, regardless of the frequency scaling of that core.
	inline int64_t tscNow()
	{
		return __rdtsc();
	}

	// Scale the time interval from CPU time stamp counter clock into 100-nanosecond ticks, rounding to nearest
	uint64_t ticksFromTsc( uint64_t tscDiff );

	// Measures the time elapsed since the object was constructed
	class CpuProfiler
	{
		// Time stamp counter captured by the constructor
		const int64_t started = tscNow();

	public:

		// Time since construction, converted with ticksFromTsc()
		uint64_t elapsed() const
		{
			return ticksFromTsc( (uint64_t)( tscNow() - started ) );
		}
	};
}
\ No newline at end of file diff --git a/Whisper/Utils/GpuProfiler.cpp b/Whisper/Utils/GpuProfiler.cpp new file mode 100644 index 0000000..6f19415 --- /dev/null +++ b/Whisper/Utils/GpuProfiler.cpp @@ -0,0 +1,374 @@ +#include "stdafx.h" +#include "GpuProfiler.h" +#include "GpuProfilerSimple.h" +using namespace DirectCompute; + +inline void GpuProfiler::sProfilerData::reset() +{ + _mm_storeu_si128( ( __m128i* ) & callsPending, _mm_setzero_si128() ); +} + +inline void GpuProfiler::sProfilerData::addPending( int64_t time ) +{ + callsPending++; + timePending += time; +} + +inline void GpuProfiler::sProfilerData::dropPending() +{ + callsPending = 0; + timePending = 0; +} + +inline void GpuProfiler::sProfilerData::makeTime( uint64_t freq ) +{ + dest->count += callsPending; + dest->totalTicks += ::makeTime( timePending, freq ); + callsPending = 0; + timePending = 0; +} + +HRESULT GpuProfiler::Queue::create() +{ + ID3D11Device* const dev = device(); + + CD3D11_QUERY_DESC desc{ D3D11_QUERY_TIMESTAMP }; + for( Entry& e : queue ) + { + CHECK( dev->CreateQuery( &desc, &e.query ) ); + e.block = nullptr; + e.event = eEvent::None; + e.shader = EmptyShader; + } + return S_OK; +} + +namespace +{ + static uint64_t getTimestamp( ID3D11Query* query ) + { + ID3D11DeviceContext* const ctx = context(); + + uint64_t res = 0; + while( true ) + { + const HRESULT hr = ctx->GetData( query, &res, sizeof( uint64_t ), 0 ); + check( hr ); + if( S_OK == hr ) + return res; +#if 0 + Sleep( 1 ); +#else + for( size_t i = 0; i < 1024; i++ ) + _mm_pause(); +#endif + } + } + + static D3D11_QUERY_DATA_TIMESTAMP_DISJOINT waitForDisjointData( ID3D11Query* query ) + { + ID3D11DeviceContext* const ctx = context(); + ctx->End( query ); + + D3D11_QUERY_DATA_TIMESTAMP_DISJOINT res; + while( true ) + { + const HRESULT hr = ctx->GetData( query, &res, sizeof( D3D11_QUERY_DATA_TIMESTAMP_DISJOINT ), 0 ); + check( hr ); + if( S_OK == hr ) + return res; + Sleep( 1 ); + } + } +} + +void GpuProfiler::Queue::Entry::join( 
GpuProfiler& owner ) +{ + assert( nullptr != block ); + + uint64_t res = getTimestamp( query ); +#if PROFILER_COLLECT_TAGS + block->haveTimestamp( event, shader, tag, res, owner ); +#else + block->haveTimestamp( event, shader, 0, res, owner ); +#endif + block = nullptr; + event = eEvent::None; + shader = EmptyShader; +} + +void GpuProfiler::Queue::submit( BlockState* block, eEvent evt, uint16_t shader, uint16_t tag ) +{ + // if( evt == GpuProfiler::eEvent::Shader && shader == 0 ) __debugbreak(); + assert( nullptr != block ); + + Entry& e = queue[ nextEntry ]; + if( nullptr != e.block ) + e.join( owner ); + + e.block = block; + e.event = evt; + e.shader = shader; +#if PROFILER_COLLECT_TAGS + e.tag = tag; +#endif + context()->End( e.query ); + nextEntry = ( nextEntry + 1 ) % queueLength; +} + +void GpuProfiler::Queue::join() +{ + while( true ) + { + Entry& e = queue[ nextEntry ]; + if( nullptr == e.block ) + return; + e.join( owner ); + nextEntry = ( nextEntry + 1 ) % queueLength; + } +} + +static inline uint32_t makeTagKey( uint16_t cs, uint16_t tag ) +{ + uint32_t r = cs; + r = r << 16; + r |= tag; + return r; +} + +void GpuProfiler::BlockState::completePrevShader( uint64_t time, GpuProfiler& profiler ) +{ + if( shaderStart == -1 ) + return; + assert( prevShader != EmptyShader ); + const int64_t elapsed = (int64_t)time - shaderStart; + + sProfilerData* dest = nullptr; + auto* p = profiler.results.Lookup( prevShader ); + if( nullptr != p ) + dest = &p->m_value; + else + { + sProfilerData& res = profiler.results[ prevShader ]; + res.dest = &profiler.dest.measure( (eComputeShader)prevShader ); + dest = &res; + } + dest->addPending( elapsed ); + +#if PROFILER_COLLECT_TAGS + if( 0 != prevShaderTag ) + { + const uint32_t key = makeTagKey( prevShader, prevShaderTag ); + auto* pt = profiler.resultsTagged.Lookup( key ); + if( nullptr != pt ) + dest = &pt->m_value; + else + { + sProfilerData& res = profiler.resultsTagged[ key ]; + res.dest = &profiler.dest.measure( 
(eComputeShader)prevShader, prevShaderTag ); + dest = &res; + } + dest->addPending( elapsed ); + } +#endif + prevShader = EmptyShader; + prevShaderTag = 0; + shaderStart = -1; +} + +void GpuProfiler::BlockState::haveTimestamp( eEvent evt, uint16_t cs, uint16_t tag, uint64_t time, GpuProfiler& profiler ) +{ + switch( evt ) + { + case eEvent::BlockStart: + assert( -1 == timeStart ); + assert( -1 == shaderStart ); + assert( cs == EmptyShader ); + timeStart = (int64_t)time; + if( nullptr != parentBlock ) + parentBlock->completePrevShader( time, profiler ); + return; + case eEvent::BlockEnd: + assert( -1 != timeStart ); + assert( cs == EmptyShader ); + completePrevShader( time, profiler ); + destBlock->addPending( (int64_t)time - timeStart ); + timeStart = -1; + return; + case eEvent::Shader: + assert( cs != EmptyShader ); + // if( cs == (uint16_t)0 ) __debugbreak(); + completePrevShader( time, profiler ); + prevShader = cs; + prevShaderTag = tag; + shaderStart = (int64_t)time; + return; + } + assert( false ); +} + +HRESULT GpuProfiler::create( size_t maxDepth ) +{ + CD3D11_QUERY_DESC desc{ D3D11_QUERY_TIMESTAMP_DISJOINT }; + CHECK( device()->CreateQuery( &desc, &disjoint ) ); + CHECK( queries.create() ); + stack.reserve( maxDepth ); + return S_OK; +} + +void GpuProfiler::blockStart( eProfilerBlock which ) +{ + BlockState* parentBlock; + if( stack.empty() ) + { + context()->Begin( disjoint ); + parentBlock = nullptr; + } + else + parentBlock = *stack.rbegin(); + + BlockState* bs = nullptr; + auto p = blockStates.Lookup( which ); + if( nullptr != p ) + bs = &p->m_value; + else + { + BlockState& block = blockStates[ which ]; + block.destBlock = &results[ (uint16_t)which ]; + block.destBlock->dest = &dest.measure( which ); + bs = █ + } + bs->parentBlock = parentBlock; + queries.submit( bs, eEvent::BlockStart ); + stack.push_back( bs ); +} + +void GpuProfiler::blockEnd() +{ + assert( !stack.empty() ); + BlockState* const bs = *stack.rbegin(); + queries.submit( bs, 
eEvent::BlockEnd ); + stack.pop_back(); + + if( !stack.empty() ) + return; + + const D3D11_QUERY_DATA_TIMESTAMP_DISJOINT dtsd = waitForDisjointData( disjoint ); + queries.join(); + + if( !dtsd.Disjoint ) + { + // Fortunately, these timers appear to be relatively high resolution. + // Specifically, on the iGPU inside Ryzen 7 5700G that frequency is 1E+8 = 100 MHz + // On nVidia 1080Ti, that frequency is 1E+9 = 1 GHz + const uint64_t freq = dtsd.Frequency; + resultsMakeTime( freq ); + } + else + { + // Something occurred in between the query's ID3D11DeviceContext::Begin and ID3D11DeviceContext::End calls + // that caused the timestamp counter to become discontinuous or disjoint, such as unplugging the AC cord on a laptop, overheating, or throttling up/down due to laptop savings events. + // The timestamp returned by ID3D11DeviceContext::GetData for a timestamp query is only reliable if Disjoint is FALSE. + resultsDropPending(); + } +} + +void GpuProfiler::computeShader( eComputeShader cs ) +{ + assert( !stack.empty() ); + if( !profileShaders ) + return; + + BlockState* const bs = *stack.rbegin(); +#if PROFILER_COLLECT_TAGS + queries.submit( bs, eEvent::Shader, (uint16_t)cs, m_nextTag ); + m_nextTag = 0; +#else + queries.submit( bs, eEvent::Shader, (uint16_t)cs ); +#endif +} + +void GpuProfiler::resultsDropPending() +{ + for( POSITION pos = results.GetStartPosition(); nullptr != pos; ) + results.GetNextValue( pos ).dropPending(); +#if PROFILER_COLLECT_TAGS + for( POSITION pos = resultsTagged.GetStartPosition(); nullptr != pos; ) + resultsTagged.GetNextValue( pos ).dropPending(); +#endif +} + +void GpuProfiler::resultsMakeTime( uint64_t freq ) +{ + for( POSITION pos = results.GetStartPosition(); nullptr != pos; ) + results.GetNextValue( pos ).makeTime( freq ); +#if PROFILER_COLLECT_TAGS + for( POSITION pos = resultsTagged.GetStartPosition(); nullptr != pos; ) + resultsTagged.GetNextValue( pos ).makeTime( freq ); +#endif +} + +void GpuProfiler::resultsReset() +{ + for( 
POSITION pos = results.GetStartPosition(); nullptr != pos; ) + results.GetNextValue( pos ).reset(); +#if PROFILER_COLLECT_TAGS + for( POSITION pos = resultsTagged.GetStartPosition(); nullptr != pos; ) + resultsTagged.GetNextValue( pos ).reset(); +#endif +} + +#if PROFILER_COLLECT_TAGS +uint16_t __declspec( noinline ) GpuProfiler::setNextTag( const char* name ) +{ + uint16_t tag = dest.makeTagId( name ); + m_nextTag = tag; + return tag; +} +#endif + +HRESULT GpuProfilerSimple::create() +{ + ID3D11Device* const dev = device(); + + CD3D11_QUERY_DESC desc{ D3D11_QUERY_TIMESTAMP_DISJOINT }; + CHECK( dev->CreateQuery( &desc, &disjoint ) ); + + desc.Query = D3D11_QUERY_TIMESTAMP; + CHECK( dev->CreateQuery( &desc, &begin ) ); + CHECK( dev->CreateQuery( &desc, &end ) ); + + context()->Begin( disjoint ); + context()->End( begin ); + return S_OK; +} + +HRESULT GpuProfilerSimple::time( uint64_t& rdi ) const +{ + context()->End( end ); + + try + { + const D3D11_QUERY_DATA_TIMESTAMP_DISJOINT dtsd = waitForDisjointData( disjoint ); + const uint64_t t1 = getTimestamp( begin ); + const uint64_t t2 = getTimestamp( end ); + + if( !dtsd.Disjoint ) + { + rdi = makeTime( t2 - t1, dtsd.Frequency ); + return S_OK; + } + else + { + // Something occurred in between the query's ID3D11DeviceContext::Begin and ID3D11DeviceContext::End calls + // that caused the timestamp counter to become discontinuous or disjoint, such as unplugging the AC cord on a laptop, overheating, or throttling up/down due to laptop savings events. + // The timestamp returned by ID3D11DeviceContext::GetData for a timestamp query is only reliable if Disjoint is FALSE. + rdi = -1; + return S_FALSE; + } + } + catch( HRESULT hr ) + { + return hr; + } +}
\ No newline at end of file diff --git a/Whisper/Utils/GpuProfiler.h b/Whisper/Utils/GpuProfiler.h new file mode 100644 index 0000000..fbc284e --- /dev/null +++ b/Whisper/Utils/GpuProfiler.h @@ -0,0 +1,187 @@ +#pragma once +#include "../D3D/device.h" +#include "ProfileCollection.h" + +namespace DirectCompute +{ + enum struct eProfilerBlock : uint16_t + { + LoadModel = 0x1000, + Run = 0x2000, + Encode = 0x3000, + EncodeLayer = 0x4000, + Decode = 0x5000, + DecodeStep = 0x6000, + DecodeLayer = 0x7000, + }; + + enum struct eComputeShader : uint16_t; + + class GpuProfiler + { + CComPtr<ID3D11Query> disjoint; + + enum struct eEvent + { + None = 0, + BlockStart, + BlockEnd, + Shader + }; + + struct BlockState; + static constexpr uint16_t EmptyShader = ~(uint16_t)0; + + // A circular buffer with in-flight queries which feeds timestamps into the iTimestampSink interface + class Queue + { + static constexpr size_t queueLength = 32; + + // Ring buffer for individual measures + struct Entry + { + CComPtr<ID3D11Query> query; + BlockState* block; + eEvent event; + uint16_t shader; +#if PROFILER_COLLECT_TAGS + uint16_t tag = 0; +#endif + void join( GpuProfiler& owner ); + }; + + GpuProfiler& owner; + std::array<Entry, queueLength> queue; + size_t nextEntry = 0; + + public: + Queue( GpuProfiler& gp ) : owner( gp ) {} + + HRESULT create(); + + // Begin a next query. 
Eventually, this will result in the BlockState.haveTimestamp callback + void submit( BlockState* block, eEvent evt, uint16_t shader = EmptyShader, uint16_t tag = 0 ); + + // Wait for all the pending queries, and call their callbacks + void join(); + }; + Queue queries; + + struct sProfilerData; + struct BlockState + { + int64_t timeStart = -1; + sProfilerData* destBlock = nullptr; + int64_t shaderStart = -1; + uint16_t prevShader = EmptyShader; + uint16_t prevShaderTag = 0; + BlockState* parentBlock = nullptr; + void haveTimestamp( eEvent evt, uint16_t cs, uint16_t tag, uint64_t time, GpuProfiler& profiler ); + private: + void completePrevShader( uint64_t time, GpuProfiler& profiler ); + }; + CAtlMap<eProfilerBlock, BlockState> blockStates; + std::vector<BlockState*> stack; + + struct sProfilerData + { + // Count of accumulated measures + size_t callsPending; + // Total time spent running all instances of that measure, expressed in GPU ticks + uint64_t timePending; + + Whisper::ProfileCollection::Measure* dest; + + inline void makeTime( uint64_t freq ); + inline void addPending( int64_t time ); + inline void reset(); + inline void dropPending(); + + sProfilerData() + { + reset(); + } + }; + + CAtlMap<uint16_t, sProfilerData> results; +#if PROFILER_COLLECT_TAGS + CAtlMap<uint32_t, sProfilerData> resultsTagged; +#endif + void resultsMakeTime( uint64_t freq ); + void resultsDropPending(); + void resultsReset(); + + void blockStart( eProfilerBlock which ); + void blockEnd(); + + Whisper::ProfileCollection& dest; +#if PROFILER_COLLECT_TAGS + uint16_t m_nextTag = 0; +#endif + public: + + GpuProfiler( Whisper::ProfileCollection& pc ) : + dest( pc ), queries( *this ) { } + + HRESULT create( size_t maxDepth = 3 ); + + class BlockRaii + { + GpuProfiler* profiler; + + public: + BlockRaii( GpuProfiler& owner, eProfilerBlock which ) + { + owner.blockStart( which ); + profiler = &owner; + } + ~BlockRaii() + { + if( nullptr != profiler ) + { + profiler->blockEnd(); + profiler = 
nullptr; + } + } + BlockRaii( BlockRaii&& that ) noexcept : + profiler( that.profiler ) + { + that.profiler = nullptr; + } + BlockRaii( const BlockRaii& ) = delete; + void operator=( const BlockRaii& ) = delete; + void operator=( BlockRaii&& ) = delete; + }; + + BlockRaii block( eProfilerBlock which ) + { + return BlockRaii{ *this, which }; + } + + void computeShader( eComputeShader cs ); + + bool profileShaders = false; + // bool profileShaders = true; + + decltype( auto ) cpuBlock( Whisper::eCpuBlock block ) + { + return dest.cpuBlock( block ); + } + Whisper::ProfileCollection& profiler() { return dest; } + + // Set tag string for the next compute shader + // The string should be readonly: for performance reason the implementation doesn’t copy nor compare any strings, it only keeps the pointer +#if PROFILER_COLLECT_TAGS + uint16_t setNextTag( const char* name ); +#else + inline uint16_t setNextTag( const char* name ) { return 0; } +#endif + + void setNextTag( uint16_t tag ) + { +#if PROFILER_COLLECT_TAGS + m_nextTag = tag; +#endif + } + }; +}
\ No newline at end of file diff --git a/Whisper/Utils/GpuProfilerSimple.h b/Whisper/Utils/GpuProfilerSimple.h new file mode 100644 index 0000000..7938b44 --- /dev/null +++ b/Whisper/Utils/GpuProfilerSimple.h @@ -0,0 +1,14 @@ +#pragma once +#include "../D3D/device.h" + +namespace DirectCompute +{ + // A simple profiler which doesn't collect anything, used to measure time it took to load the model + class GpuProfilerSimple + { + CComPtr<ID3D11Query> disjoint, begin, end; + public: + HRESULT create(); + HRESULT time( uint64_t& rdi ) const; + }; +}
\ No newline at end of file diff --git a/Whisper/Utils/Logger.cpp b/Whisper/Utils/Logger.cpp new file mode 100644 index 0000000..1b4b233 --- /dev/null +++ b/Whisper/Utils/Logger.cpp @@ -0,0 +1,240 @@ +#include "stdafx.h" +#include "Logger.h" +#include "../API/iContext.cl.h" +#include <cstdarg> +#include <atlstr.h> + +namespace +{ + wchar_t* formatMessage( HRESULT hr ) + { + wchar_t* err; + if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + hr, + MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ), + (LPTSTR)&err, + 0, + nullptr ) ) + return err; + return nullptr; + } + + class Utf + { + CStringA utf8; + CStringW utf16; + + void appendError( HRESULT hr ) + { + const wchar_t* err = formatMessage( hr ); + if( nullptr != err ) + { + utf16 += err; + LocalFree( (HLOCAL)err ); + utf16.TrimRight(); + } + else + utf16.AppendFormat( L"error code %i (0x%08X)", hr, hr ); + } + + public: + const char* print( const char* pszFormat, std::va_list va ) + { + utf8.FormatV( pszFormat, va ); + return utf8; + } + const wchar_t* print( const wchar_t* pszFormat, std::va_list va ) + { + utf16.FormatV( pszFormat, va ); + return utf16; + } + const wchar_t* upcast( const char* message, int len ) + { + int count = MultiByteToWideChar( CP_UTF8, 0, message, len, nullptr, 0 ); + if( count == 0 ) + return nullptr; + wchar_t* b = utf16.GetBufferSetLength( len + 1 ); + count = MultiByteToWideChar( CP_UTF8, 0, message, len, b, len ); + utf16.ReleaseBuffer( count ); + return utf16; + } + int utf8Length() const + { + return utf8.GetLength(); + } + const wchar_t* printError( HRESULT hr, const char* pszFormat, std::va_list va ) + { + print( pszFormat, va ); + upcast( utf8, utf8.GetLength() ); + utf16 += L": "; + appendError( hr ); + return utf16; + } + const char* downcast() + { + int count = WideCharToMultiByte( CP_UTF8, 0, utf16, utf16.GetLength(), nullptr, 0, nullptr, nullptr ); + char* s = utf8.GetBufferSetLength( count + 1 ); + count = WideCharToMultiByte( CP_UTF8, 0, 
utf16, utf16.GetLength(), s, count, nullptr, nullptr ); + utf8.ReleaseBufferSetLength( count ); + return utf8; + } + }; + thread_local Utf ts_utf; + using Whisper::eLoggerFlags; + + class Logger : Whisper::sLoggerSetup + { + inline bool hasFlag( eLoggerFlags bit ) const + { + return 0 != ( (uint8_t)flags & (uint8_t)bit ); + } + + bool useStdError() const + { + return hasFlag( eLoggerFlags::UseStandardError ); + } + + static void writeStdError( Whisper::eLogLevel lvl, const char* message, int len ) + { + const wchar_t* w = ts_utf.upcast( message, len ); + if( nullptr != w ) + fwprintf( stderr, L"%s\n", w ); + } + + public: + Logger() + { + memset( this, 0, sizeof( Logger ) ); + } + + bool willLog( Whisper::eLogLevel lvl ) const + { + if( (uint8_t)lvl > (uint8_t)level ) + return false; + if( useStdError() ) + return true; + return nullptr != sink; + } + + void message( Whisper::eLogLevel lvl, const char8_t* pszFormat, std::va_list va ) const + { + const char* s = ts_utf.print( (const char*)pszFormat, va ); + auto pfn = sink; + if( nullptr != pfn ) + pfn( context, lvl, s ); + if( useStdError() ) + writeStdError( lvl, s, ts_utf.utf8Length() ); + } + void message( Whisper::eLogLevel lvl, const wchar_t* pszFormat, std::va_list va ) const + { + Utf& u = ts_utf; + const wchar_t* w = u.print( pszFormat, va ); + auto pfn = sink; + if( nullptr != pfn ) + pfn( context, lvl, u.downcast() ); + if( useStdError() ) + fwprintf( stderr, L"%s\n", w ); + } + void message( Whisper::eLogLevel lvl, HRESULT hr, const char* pszFormat, std::va_list va ) const + { + if( hasFlag( eLoggerFlags::SkipFormatMessage ) ) + { + message( lvl, (const char8_t*)pszFormat, va ); + return; + } + Utf& u = ts_utf; + const wchar_t* w = ts_utf.printError( hr, (const char*)pszFormat, va ); + auto pfn = sink; + if( nullptr != pfn ) + pfn( context, lvl, u.downcast() ); + if( useStdError() ) + fwprintf( stderr, L"%s\n", w ); + } + + void operator=( const sLoggerSetup& rsi ) + { + sink = rsi.sink; + context = 
rsi.context; + level = rsi.level; + flags = rsi.flags; + } + }; + + static Logger s_logger; +} + +bool willLogMessage( Whisper::eLogLevel lvl ) +{ + return s_logger.willLog( lvl ); +} + +using Whisper::eLogLevel; + +#define LOG_MESSAGE_IMPL( lvl ) \ + if( !s_logger.willLog( lvl ) ) \ + return; \ + std::va_list args; \ + va_start( args, pszFormat ); \ + s_logger.message( lvl, pszFormat, args ); \ + va_end( args ); + +void logError( const char8_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( eLogLevel::Error ); +} +void logError16( const wchar_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( eLogLevel::Error ); +} +void logWarning( const char8_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( eLogLevel::Warning ); +} +void logWarning16( const wchar_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( eLogLevel::Warning ); +} +void logInfo( const char8_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( eLogLevel::Info ); +} +void logInfo16( const wchar_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( eLogLevel::Info ); +} +void logDebug( const char8_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( eLogLevel::Debug ); +} +void logDebug16( const wchar_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( eLogLevel::Debug ); +} +#undef LOG_MESSAGE_IMPL + +#define LOG_MESSAGE_IMPL( lvl ) \ + if( !s_logger.willLog( lvl ) ) \ + return; \ + std::va_list args; \ + va_start( args, pszFormat ); \ + s_logger.message( lvl, hr, (const char*)pszFormat, args ); \ + va_end( args ); + +void logErrorHr( long hr, const char8_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( eLogLevel::Error ); +} +void logWarningHr( long hr, const char8_t* pszFormat, ... ) +{ + LOG_MESSAGE_IMPL( eLogLevel::Warning ); +} + +#undef LOG_MESSAGE_IMPL + +// DLL entry point +HRESULT COMLIGHTCALL Whisper::setupLogger( const sLoggerSetup& setup ) +{ + s_logger = setup; + return S_OK; +}
\ No newline at end of file diff --git a/Whisper/Utils/Logger.h b/Whisper/Utils/Logger.h new file mode 100644 index 0000000..5e6ec0d --- /dev/null +++ b/Whisper/Utils/Logger.h @@ -0,0 +1,23 @@ +#pragma once +#include "../API/loggerApi.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void logError( const char8_t* pszFormat, ... ); +void logError16( const wchar_t* pszFormat, ... ); +void logErrorHr( long hr, const char8_t* pszFormat, ... ); +void logWarning( const char8_t* pszFormat, ... ); +void logWarning16( const wchar_t* pszFormat, ... ); +void logWarningHr( long hr, const char8_t* pszFormat, ... ); +void logInfo( const char8_t* pszFormat, ... ); +void logInfo16( const wchar_t* pszFormat, ... ); +void logDebug( const char8_t* pszFormat, ... ); +void logDebug16( const wchar_t* pszFormat, ... ); + +bool willLogMessage( Whisper::eLogLevel lvl ); + +#ifdef __cplusplus +} +#endif diff --git a/Whisper/Utils/ProfileCollection.cpp b/Whisper/Utils/ProfileCollection.cpp new file mode 100644 index 0000000..2fc5919 --- /dev/null +++ b/Whisper/Utils/ProfileCollection.cpp @@ -0,0 +1,331 @@ +#include "stdafx.h" +#include "ProfileCollection.h" +#include "GpuProfiler.h" +#include "../Whisper/WhisperModel.h" +#include "../D3D/shaderNames.h" +using namespace Whisper; + +ProfileCollection::Measure& ProfileCollection::measure( DirectCompute::eProfilerBlock which ) +{ + uint32_t key = (uint16_t)which; + key |= 0x20000; + return measures[ key ]; +} + +ProfileCollection::Measure& ProfileCollection::measure( DirectCompute::eComputeShader which ) +{ + uint32_t key = (uint16_t)which; + key |= 0x30000; + return measures[ key ]; +} + +ProfileCollection::Measure& ProfileCollection::measure( eCpuBlock which ) +{ + uint32_t key = (uint8_t)which; + key |= 0x10000; + CComCritSecLock<CComAutoCriticalSection> lock{ critSec }; + return measures[ key ]; +} + +#if PROFILER_COLLECT_TAGS +ProfileCollection::Measure& ProfileCollection::measure( DirectCompute::eComputeShader which, uint16_t tag ) +{ + 
uint32_t key = (uint8_t)which; + key = key << 16; + key |= tag; + CComCritSecLock<CComAutoCriticalSection> lock{ critSec }; + return taggedShaders[ key ]; +} +#endif + +namespace +{ + using pfnPrintEnum = const char* ( * )( uint16_t val ); + + static const char* printCpuBlock( uint16_t id ) + { + const eCpuBlock which = (eCpuBlock)id; + switch( which ) + { +#define V(x) case eCpuBlock::x: return #x + V( LoadModel ); + V( Run ); + V( Spectrogram ); + V( Sample ); + V( VAD ); + V( Decode ); + V( DecodeStep ); + V( DecodeLayer ); +#undef V + } + assert( false ); + return nullptr; + } + + static const char* printGpuBlock( uint16_t id ) + { + using DirectCompute::eProfilerBlock; + const eProfilerBlock which = (eProfilerBlock)id; + + switch( which ) + { +#define V(x) case eProfilerBlock::x: return #x + V( LoadModel ); + V( Run ); + V( Encode ); + V( EncodeLayer ); + V( Decode ); + V( DecodeStep ); + V( DecodeLayer ); +#undef V + } + assert( false ); + return nullptr; + } + + static const char* printShader( uint16_t id ) + { + return DirectCompute::computeShaderName( (DirectCompute::eComputeShader)id ); + } + + static pfnPrintEnum printSectionStart( uint16_t type ) + { + switch( type ) + { + case 1: + logInfo( u8" CPU Tasks" ); + return &printCpuBlock; + case 2: + logInfo( u8" GPU Tasks" ); + return &printGpuBlock; + case 3: + logInfo( u8" Compute Shaders" ); + return &printShader; + default: + return nullptr; + } + } + + struct PrintedTime + { + double value; + const char* unit; + + PrintedTime( uint64_t ticks ) + { + const double dbl = (double)(int64_t)ticks; + if( ticks >= 10'000'000 ) + { + value = dbl / 1.0E+7; + unit = "seconds"; + } + else if( ticks >= 10'000 ) + { + value = dbl / 1.0E+4; + unit = "milliseconds"; + } + else + { + value = dbl / 1.0E+1; + unit = "microseconds"; + } + } + PrintedTime( double dbl ) + { + if( dbl >= 10'000'000 ) + { + value = dbl / 1.0E+7; + unit = "seconds"; + } + else if( dbl >= 10'000 ) + { + value = dbl / 1.0E+4; + unit = 
"milliseconds"; + } + else + { + value = dbl / 1.0E+1; + unit = "microseconds"; + } + } + }; +} + +void ProfileCollection::Measure::print( const char* name ) const +{ + PrintedTime total{ totalTicks }; + if( 1 == count ) + logInfo( u8"%s\t%g %s", name, total.value, total.unit ); + else + { + PrintedTime avg = (double)totalTicks / (double)(int64_t)count; + logInfo( u8"%s\t%g %s, %zu calls, %g %s average", name, total.value, total.unit, count, avg.value, avg.unit ); + } +} + +#if PROFILER_COLLECT_TAGS +struct TaggedShaderCmp +{ + bool operator()( uint16_t cs, uint32_t key ) const + { + return cs < key >> 16; + } + bool operator()( uint32_t key, uint16_t cs ) const + { + return key >> 16 < cs; + } +}; + +void ProfileCollection::TaggedTemp::print() const +{ + PrintedTime total{ ticks }; + if( 1 == count ) + logInfo( u8" %s\t%g %s", name, total.value, total.unit ); + else + { + PrintedTime avg = (double)ticks / (double)(int64_t)count; + logInfo( u8" %s\t%g %s, %zu calls, %g %s average", name, total.value, total.unit, count, avg.value, avg.unit ); + } +} +#endif + +void ProfileCollection::print() +{ + keysTemp.clear(); + for( POSITION pos = measures.GetStartPosition(); nullptr != pos; ) + { + auto* p = measures.GetNext( pos ); + if( p->m_value.count == 0 ) + continue; + keysTemp.push_back( p->m_key ); + } + + std::sort( keysTemp.begin(), keysTemp.end() ); + auto it = std::lower_bound( keysTemp.begin(), keysTemp.end(), 0x30000u ); + if( it != keysTemp.end() ) + { + auto lambda = [ this ]( uint32_t a, uint32_t b ) + { + const uint64_t ta = measures.Lookup( a )->m_value.totalTicks; + const uint64_t tb = measures.Lookup( b )->m_value.totalTicks; + return ta > tb; + }; + std::stable_sort( it, keysTemp.end(), lambda ); + } + +#if PROFILER_COLLECT_TAGS + taggedKeysTemp.clear(); + for( POSITION pos = taggedShaders.GetStartPosition(); nullptr != pos; ) + { + auto* p = taggedShaders.GetNext( pos ); + if( p->m_value.count == 0 ) + continue; + taggedKeysTemp.push_back( p->m_key ); + 
} + std::sort( taggedKeysTemp.begin(), taggedKeysTemp.end() ); +#endif + + uint16_t prevKeyType = 0; + pfnPrintEnum pfn = nullptr; + for( uint32_t k : keysTemp ) + { + const uint16_t type = (uint16_t)( k >> 16 ); + if( type != prevKeyType ) + { + prevKeyType = type; + pfn = printSectionStart( type ); + } + if( pfn == nullptr ) + continue; + const auto* p = measures.Lookup( k ); + assert( nullptr != p ); + p->m_value.print( pfn( (uint16_t)k ) ); + +#if PROFILER_COLLECT_TAGS + if( type == 3 ) + { + // Compute shader + auto range = std::equal_range( taggedKeysTemp.begin(), taggedKeysTemp.end(), (uint16_t)k, TaggedShaderCmp{} ); + if( range.first != range.second ) + { + // We have at least 1 tag for that compute shader + taggedTimes.clear(); + uint64_t totalTicks = 0; + size_t totalCount = 0; + for( auto it = range.first; it != range.second; it++ ) + { + const uint32_t key = *it; + const uint16_t tagId = (uint16_t)key; + assert( 0 != tagId ); + const auto* p = taggedShaders.Lookup( key ); + assert( nullptr != p ); + + auto& rdi = taggedTimes.emplace_back(); + rdi.ticks = p->m_value.totalTicks; + totalTicks += p->m_value.totalTicks; + + rdi.count = p->m_value.count; + totalCount += p->m_value.count; + + rdi.name = tagNames[ tagId ]; + } + + assert( totalCount <= p->m_value.count ); + if( totalCount < p->m_value.count ) + { + auto& rdi = taggedTimes.emplace_back(); + rdi.ticks = p->m_value.totalTicks - totalTicks; + rdi.count = p->m_value.count - totalCount; + rdi.name = tagNames[ 0 ]; + } + std::stable_sort( taggedTimes.begin(), taggedTimes.end() ); + for( const auto& e : taggedTimes ) + e.print(); + } + } +#endif + } +} + +void ProfileCollection::reset() +{ + for( POSITION pos = measures.GetStartPosition(); nullptr != pos; ) + measures.GetNextValue( pos ).reset(); +} + +ProfileCollection::ProfileCollection( const WhisperModel& model ) +{ + const __m128i vals = model.getLoadTimes(); + + uint64_t s = (uint64_t)_mm_cvtsi128_si64( vals ); + measure( eCpuBlock::LoadModel 
).add( s ); + + s = (uint64_t)_mm_extract_epi64( vals, 1 ); + measure( DirectCompute::eProfilerBlock::LoadModel ).add( s ); +#if PROFILER_COLLECT_TAGS + // Tag ID 0 means no tag at all. makeTagId() method returns 0 for nullptr name, and starts numbering with 1 for non-empoty tag names + // Push the tag name corresponding to ID = 0, this way we can index directly with tag IDs. + tagNames.push_back( "<untagged>" ); +#endif +} + +uint16_t ProfileCollection::makeTagId( const char* tag ) +{ +#if PROFILER_COLLECT_TAGS + if( nullptr == tag ) + return 0; + auto p = tagIDs.Lookup( tag ); + if( nullptr != p ) + return p->m_value; + const size_t newTag = tagIDs.GetCount() + 1; + if( newTag <= 0xFFFF ) + { + tagIDs.SetAt( tag, (uint16_t)newTag ); + tagNames.push_back( tag ); + return (uint16_t)newTag; + } + throw DISP_E_OVERFLOW; +#else + return 0; +#endif +}
\ No newline at end of file diff --git a/Whisper/Utils/ProfileCollection.h b/Whisper/Utils/ProfileCollection.h new file mode 100644 index 0000000..732bb48 --- /dev/null +++ b/Whisper/Utils/ProfileCollection.h @@ -0,0 +1,112 @@ +#pragma once +#include <atlcoll.h> +#include "CpuProfiler.h" + +namespace DirectCompute +{ + enum struct eComputeShader : uint16_t; + enum struct eProfilerBlock : uint16_t; +} + +namespace Whisper +{ + struct WhisperModel; + + enum struct eCpuBlock : uint8_t + { + LoadModel, + Run, + Spectrogram, + Sample, + VAD, + Decode, + DecodeStep, + DecodeLayer, + }; + + class ProfileCollection + { + public: + ProfileCollection( const WhisperModel& model ); + + struct Measure + { + size_t count = 0; + // 100-nanosecond ticks + uint64_t totalTicks = 0; + + void reset() + { + count = 0; + totalTicks = 0; + } + + void print( const char* name ) const; + + void add( uint64_t val ) + { + count++; + totalTicks += val; + } + }; + + Measure& measure( DirectCompute::eProfilerBlock which ); + Measure& measure( DirectCompute::eComputeShader which ); + Measure& measure( eCpuBlock which ); +#if PROFILER_COLLECT_TAGS + Measure& measure( DirectCompute::eComputeShader which, uint16_t tag ); +#endif + void print(); + + void reset(); + + class CpuRaii + { + Measure& dest; + const int64_t tsc; + + public: + CpuRaii( Measure& m ) : dest( m ), tsc( tscNow() ) + { } + + ~CpuRaii() + { + const int64_t elapsed = tscNow() - tsc; + dest.add( ticksFromTsc( elapsed ) ); + } + }; + + decltype( auto ) cpuBlock( eCpuBlock which ) + { + return CpuRaii{ measure( which ) }; + } + + uint16_t makeTagId( const char* tag ); + + private: + CAtlMap<uint32_t, Measure> measures; + CComAutoCriticalSection critSec; +#if PROFILER_COLLECT_TAGS + CAtlMap<const char*, uint16_t> tagIDs; + std::vector<const char*> tagNames; + CAtlMap<uint32_t, Measure> taggedShaders; + std::vector<uint32_t> taggedKeysTemp; + struct TaggedTemp + { + uint64_t ticks; + size_t count; + const char* name; + + bool operator<( 
const TaggedTemp& that ) const + { + // Flipping the comparison to sort in descending order + return ticks > that.ticks; + } + + void print() const; + }; + std::vector<TaggedTemp> taggedTimes; +#endif + std::vector<uint32_t> keysTemp; + }; +}
\ No newline at end of file diff --git a/Whisper/Utils/ReadStream.h b/Whisper/Utils/ReadStream.h new file mode 100644 index 0000000..c9da400 --- /dev/null +++ b/Whisper/Utils/ReadStream.h @@ -0,0 +1,37 @@ +#pragma once +#include "../ComLightLib/streams.h" +#include "../ComLightLib/comLightServer.h" +#define WIN32_LEAN_AND_MEAN +#include <atlfile.h> + +class ReadStream : public ComLight::ObjectRoot<ComLight::iReadStream> +{ + CAtlFile file; + // TODO: implement a buffer in this class, at least 256kb + + HRESULT COMLIGHTCALL read( void* lpBuffer, int nNumberOfBytesToRead, int& lpNumberOfBytesRead ) override final + { + return file.Read( lpBuffer, (DWORD)nNumberOfBytesToRead, *(DWORD*)&lpNumberOfBytesRead ); + } + HRESULT COMLIGHTCALL seek( int64_t offset, ComLight::eSeekOrigin origin ) override final + { + return file.Seek( offset, (uint8_t)origin ); + } + HRESULT COMLIGHTCALL getPosition( int64_t& position ) override final + { + return file.GetPosition( *(ULONGLONG*)&position ); + } + HRESULT COMLIGHTCALL getLength( int64_t& length ) override final + { + return file.GetSize( *(ULONGLONG*)&length ); + } + +public: + + HRESULT open( const wchar_t* path ) + { + if( file ) + return HRESULT_CODE( ERROR_ALREADY_INITIALIZED ); + return file.Create( path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN ); + } +};
\ No newline at end of file diff --git a/Whisper/Utils/Trace/TraceStructures.cpp b/Whisper/Utils/Trace/TraceStructures.cpp new file mode 100644 index 0000000..289a534 --- /dev/null +++ b/Whisper/Utils/Trace/TraceStructures.cpp @@ -0,0 +1,31 @@ +#include "stdafx.h" +#include "TraceStructures.h" +using namespace Tracing; + +uint64_t sTraceItem::buffer( uint64_t off, size_t length, eDataType type ) +{ + payloadOffset = off; + payloadSize = length * DirectCompute::elementSize( type ); + *(uint64_t*)( &size[ 0 ] ) = length; + *(uint64_t*)( &size[ 2 ] ) = 0; + _mm_storeu_si128( ( __m128i* )stride.data(), _mm_setzero_si128() ); + itemType = eItemType::Buffer; + dataType = type; + return payloadSize; +} + +uint64_t sTraceItem::tensor( uint64_t off, __m128i ne, __m128i nb, eDataType type ) +{ + payloadOffset = off; + _mm_storeu_si128( ( __m128i* )size.data(), ne ); + _mm_storeu_si128( ( __m128i* )stride.data(), nb ); + uint64_t count = 1; + for( uint32_t i : size ) + if( i != 0 ) + count *= i; + + payloadSize = count * DirectCompute::elementSize( type ); + itemType = eItemType::Tensor; + dataType = type; + return payloadSize; +}
\ No newline at end of file diff --git a/Whisper/Utils/Trace/TraceStructures.h b/Whisper/Utils/Trace/TraceStructures.h new file mode 100644 index 0000000..d1213bf --- /dev/null +++ b/Whisper/Utils/Trace/TraceStructures.h @@ -0,0 +1,55 @@ +#pragma once +#include <array> +#include <emmintrin.h> +#include "../../D3D/enums.h" + +namespace Tracing +{ + using DirectCompute::eDataType; + + // File header of the trace file + struct sFileHeader + { + static constexpr uint32_t correctMagic = 0xE6B4A12Du; // random.org + + uint32_t magic; + uint8_t formatVersion; + uint8_t zzPadding; + uint16_t cbItem; + uint32_t countItems; + uint32_t zzPadding2; + uint64_t bytesPayload; + uint32_t countStrings, bytesStrings; + }; + // Payload data starts immediately after the header, bytesPayload bytes in total. + // Then `bytesStrings` with string names, first countStrings * 4 of them are offsets, then ( bytesStrings - countStrings * 4 ) bytes with the string data. + // The strings in the file are null-terminated. + // Immediately after the strings, the next `cbItem` * `countItems` bytes are actual items (tensors and vectors) saved in the trace. + // The format is weird because optimized for streaming. + // These traces can grow large, we can’t afford memory keeping the payload data in memory. + // Metadata is tiny compared to payload, we accumulate that in memory, and write to the end of the file when closed. + + enum struct eItemType : uint8_t + { + Buffer = 1, + Tensor = 2, + }; + + struct sTraceItem + { + uint64_t payloadOffset; + uint64_t payloadSize; + std::array<uint32_t, 4> size; + std::array<uint32_t, 4> stride; + std::array<uint32_t, 4> formatArgs; + eItemType itemType; + eDataType dataType; + uint8_t countFormatArgs = 0; + uint8_t zzPadding = 0; + uint32_t stringIndex; + + uint64_t buffer( uint64_t off, size_t length, eDataType type ); + + uint64_t tensor( uint64_t off, __m128i ne, __m128i nb, eDataType type ); + }; +}
\ No newline at end of file diff --git a/Whisper/Utils/Trace/TraceWriter.cpp b/Whisper/Utils/Trace/TraceWriter.cpp new file mode 100644 index 0000000..be9946b --- /dev/null +++ b/Whisper/Utils/Trace/TraceWriter.cpp @@ -0,0 +1,263 @@ +#include "stdafx.h" +#include "TraceWriter.h" +#include <atlfile.h> +#include <atlcoll.h> +#include <atlstr.h> +#include "TraceStructures.h" +#include "../../ML/Tensor.h" +#include "../../CPU/Tensor.h" +#include <Shlobj.h> +using namespace Tracing; + +namespace +{ + static HRESULT createDir( LPCTSTR pathFile ) + { + LPCWSTR fn = PathFindFileName( pathFile ); + if( fn == pathFile ) + return E_FAIL; + + const int cc = (int)( fn - pathFile ); + CString dir{ pathFile, cc }; + if( PathIsDirectory( dir ) ) + return S_OK; + const int status = SHCreateDirectoryEx( nullptr, dir, nullptr ); + if( 0 == status ) + return S_OK; + return HRESULT_FROM_WIN32( status ); + } + + class TraceFileWriter + { + CAtlFile file; + // Concatenated strings, including the 0 terminators + std::vector<char> stringsData; + // Index = string ID, value = start offset into stringsData + std::vector<uint32_t> stringsIndex; + // Hash map to unduplicate these strings + CAtlMap<CStringA, uint32_t> stringsHash; + + uint32_t addString( const CStringA& s ) + { + auto p = stringsHash.Lookup( s ); + if( p != nullptr ) + return p->m_value; + + const uint32_t off = (uint32_t)stringsData.size(); + const char* rsi = s; + stringsData.insert( stringsData.end(), rsi, rsi + s.GetLength() + 1 ); + stringsIndex.push_back( off ); + + const uint32_t newId = (uint32_t)stringsHash.GetCount(); + stringsHash.SetAt( s, newId ); + return newId; + } + + void addString( sTraceItem& rdi, const ItemName& name ) + { + rdi.countFormatArgs = name.countArgs; + rdi.stringIndex = addString( name.pointer ); + rdi.formatArgs = name.args; + } + + std::vector<sTraceItem> items; + uint64_t offset = 0; + + public: + + HRESULT create( LPCTSTR path ) + { + CHECK( createDir( path ) ); + CHECK( file.Create( path, 
GENERIC_WRITE, 0, CREATE_ALWAYS ) ); + + constexpr uint64_t cbHeader = sizeof( sFileHeader ); + CHECK( file.SetSize( cbHeader ) ); + CHECK( file.Seek( 0, SEEK_END ) ); + offset = 0; + + return S_OK; + } + + HRESULT buffer( const ItemName& name, const void* rsi, size_t length, eDataType dt ) + { + sTraceItem& rdi = items.emplace_back(); + const uint64_t cb = rdi.buffer( offset, length, dt ); + addString( rdi, name ); + assert( cb <= UINT_MAX ); + CHECK( file.Write( rsi, (DWORD)cb ) ); + offset += cb; + return S_OK; + } + + HRESULT tensor( const ItemName& name, const void* rsi, __m128i size, __m128i strides, eDataType dt ) + { + sTraceItem& rdi = items.emplace_back(); + const uint64_t cb = rdi.tensor( offset, size, strides, dt ); + addString( rdi, name ); + assert( cb <= UINT_MAX ); + CHECK( file.Write( rsi, (DWORD)cb ) ); + offset += cb; + return S_OK; + } + + HRESULT close() + { + if( !file ) + return S_FALSE; + + const uint32_t cbStringsData = (uint32_t)stringsData.size(); + const uint32_t cbStringsIndex = (uint32_t)( stringsIndex.size() * 4 ); + if( !stringsIndex.empty() ) + CHECK( file.Write( stringsIndex.data(), cbStringsIndex ) ); + if( !stringsData.empty() ) + CHECK( file.Write( stringsData.data(), cbStringsData ) ); + + const uint32_t cbItems = (uint32_t)items.size() * (uint32_t)sizeof( sTraceItem ); + if( !items.empty() ) + CHECK( file.Write( items.data(), cbItems ) ); + CHECK( file.Seek( 0, FILE_BEGIN ) ); + + sFileHeader header; + memset( &header, 0, sizeof( header ) ); + header.magic = header.correctMagic; + header.cbItem = sizeof( sTraceItem ); + header.countItems = (uint32_t)items.size(); + header.bytesPayload = offset; + header.countStrings = (uint32_t)stringsIndex.size(); + header.bytesStrings = cbStringsData + cbStringsIndex; + CHECK( file.Write( &header, sizeof( header ) ) ); + CHECK( file.Flush() ); + file.Close(); + + return S_OK; + } + }; + + class TraceWriter : public iTraceWriter + { + TraceFileWriter file; + + HRESULT buffer( const ItemName& 
name, const void* rsi, size_t length, eDataType dt ) override final + { + return file.buffer( name, rsi, length, dt ); + } + + HRESULT tensor( const ItemName& name, const void* rsi, __m128i size, __m128i strides, eDataType dt ) override final + { + return file.tensor( name, rsi, size, strides, dt ); + } + + public: + + TraceWriter( LPCTSTR path ) + { + check( file.create( path ) ); + } + + ~TraceWriter() + { + check( file.close() ); + } + }; +} + +std::unique_ptr<iTraceWriter> iTraceWriter::create( LPCTSTR path ) +{ + return std::make_unique<TraceWriter>( path ); +} + +namespace +{ + static std::vector<float> tempFp32; + static std::vector<uint16_t> tempFp16; + + template<class E> + inline const void* ptr( const std::vector<E>& vec ) + { + return vec.empty() ? nullptr : vec.data(); + } +} + +HRESULT iTraceWriter::tensor( const ItemName& name, const DirectCompute::Tensor& source ) +{ + const __m128i size = source.sizeVec(); + const __m128i strides = source.stridesVec(); + const eDataType dt = source.getType(); + if( dt == eDataType::FP32 ) + { + source.download( tempFp32 ); + return tensor( name, ptr( tempFp32 ), size, strides, eDataType::FP32 ); + } + else if( dt == eDataType::FP16 ) + { + source.download( tempFp16 ); + return tensor( name, ptr( tempFp16 ), size, strides, eDataType::FP16 ); + } + return E_NOTIMPL; +} + +HRESULT iTraceWriter::tensor( const ItemName& name, const CpuCompute::Tensor& source ) +{ + const __m128i size = source.sizeVec(); + const __m128i strides = source.stridesVec(); + const eDataType dt = source.type(); + + if( dt == eDataType::FP32 ) + return tensor( name, source.fp32(), size, strides, eDataType::FP32 ); + else if( dt == eDataType::FP16 ) + return tensor( name, source.fp16(), size, strides, eDataType::FP16 ); + else + return E_NOTIMPL; +} + +#if BUILD_BOTH_VERSIONS +#include "../../source/ggml.h" +HRESULT __declspec( noinline ) iTraceWriter::tensor( const ItemName& name, const ggml_tensor& source ) +{ + __m128i size = load16( source.ne 
); + __m128i strides = _mm_setr_epi32( + (int)(uint32_t)source.nb[ 0 ], + (int)(uint32_t)source.nb[ 1 ], + (int)(uint32_t)source.nb[ 2 ], + (int)(uint32_t)source.nb[ 3 ] ); + + const __m128i ones = _mm_set1_epi32( 1 ); + switch( source.n_dims ) + { + case 0: + size = ones; + break; + case 1: + size = _mm_blend_epi16( size, ones, 0b11111100 ); + break; + case 2: + size = _mm_blend_epi16( size, ones, 0b11110000 ); + break; + case 3: + size = _mm_blend_epi16( size, ones, 0b11000000 ); + break; + case 4: + break; + default: + return E_INVALIDARG; + } + + const ggml_type dt = source.type; + switch( dt ) + { + case GGML_TYPE_F16: + strides = _mm_srli_epi32( strides, 1 ); + return tensor( name, source.data, size, strides, eDataType::FP16 ); + case GGML_TYPE_F32: + strides = _mm_srli_epi32( strides, 2 ); + return tensor( name, source.data, size, strides, eDataType::FP32 ); + default: + return E_NOTIMPL; +} +} +#else +HRESULT iTraceWriter::tensor( const ItemName& name, const ggml_tensor& source ) +{ + return E_NOTIMPL; +} +#endif
\ No newline at end of file diff --git a/Whisper/Utils/Trace/TraceWriter.h b/Whisper/Utils/Trace/TraceWriter.h new file mode 100644 index 0000000..0514c5a --- /dev/null +++ b/Whisper/Utils/Trace/TraceWriter.h @@ -0,0 +1,70 @@ +#pragma once +#include <memory> +#include "../../D3D/enums.h" + +namespace DirectCompute +{ + class Tensor; +} +namespace CpuCompute +{ + class Tensor; +} + +struct ggml_tensor; + +namespace Tracing +{ + using DirectCompute::eDataType; + + /* Name of a traced item: a C string pointer plus four 32-bit argument slots. countArgs records how many slots the constructor filled ( 0 or 1 in the constructors below ); the remaining slots are zero. The string is not copied, only the pointer is stored. */ + struct ItemName + { + const char* pointer; + std::array<uint32_t, 4> args; + uint8_t countArgs; + + /* Name without arguments: all four slots are zeroed. */ + ItemName( const char* str ) + { + pointer = str; + _mm_storeu_si128( ( __m128i* )args.data(), _mm_setzero_si128() ); + countArgs = 0; + } + /* Name with one argument: a0 goes into slot 0, the other slots are zero. */ + ItemName( const char* str, int a0 ) + { + pointer = str; + __m128i v = _mm_cvtsi32_si128( a0 ); + _mm_storeu_si128( ( __m128i* )args.data(), v ); + countArgs = 1; + } + ItemName( const char* str, uint32_t a0 ) + { + pointer = str; + __m128i v = _mm_cvtsi32_si128( (int)a0 ); + _mm_storeu_si128( ( __m128i* )args.data(), v ); + countArgs = 1; + } + /* The (int) cast narrows a0 to 32 bits before it is stored. */ + ItemName( const char* str, size_t a0 ) + { + pointer = str; + __m128i v = _mm_cvtsi32_si128( (int)a0 ); + _mm_storeu_si128( ( __m128i* )args.data(), v ); + countArgs = 1; + } + }; + + /* Abstract sink for trace data. buffer() and the raw-pointer tensor() overload are pure virtual; the three tensor() overloads that take tensor objects are non-virtual members declared here. */ + class iTraceWriter + { + public: + virtual ~iTraceWriter() {} + + static std::unique_ptr<iTraceWriter> create( LPCTSTR path ); + + virtual HRESULT buffer( const ItemName& name, const void* rsi, size_t length, eDataType dt ) = 0; + + virtual HRESULT tensor( const ItemName& name, const void* rsi, __m128i size, __m128i strides, eDataType dt ) = 0; + + HRESULT tensor( const ItemName& name, const DirectCompute::Tensor& tensor ); + HRESULT tensor( const ItemName& name, const CpuCompute::Tensor& tensor ); + HRESULT tensor( const ItemName& name, const ggml_tensor& tensor ); + }; +}
\ No newline at end of file diff --git a/Whisper/Utils/Trace/tracing.cpp b/Whisper/Utils/Trace/tracing.cpp new file mode 100644 index 0000000..976f517 --- /dev/null +++ b/Whisper/Utils/Trace/tracing.cpp @@ -0,0 +1,60 @@ +#include "stdafx.h" +#include "tracing.h" +#include "../../source/ggml.h" + +#if SAVE_DEBUG_TRACE +namespace Tracing +{ + /* The active trace writer; null until traceCreate() and again after traceClose(). */ + std::unique_ptr<iTraceWriter> s_writer; + + /* Console control handler: on Ctrl+C it releases the writer, then returns FALSE so the event is passed on to further handlers. NOTE(review): s_writer is reset here without any synchronisation with getWriter() callers - confirm which thread runs this handler and whether that is acceptable. */ + static BOOL __stdcall consoleHandler( DWORD dwCtrlType ) + { + if( dwCtrlType == CTRL_C_EVENT ) + s_writer = nullptr; + + // Return TRUE if handled this message, further handler functions won't be called. + // Return FALSE to pass this message to further handlers until default handler calls ExitProcess(). + return FALSE; + } + + /* Creates the writer for the given path and registers the Ctrl+C handler. */ + void traceCreate( LPCTSTR path ) + { + s_writer = iTraceWriter::create( path ); + SetConsoleCtrlHandler( &consoleHandler, TRUE ); + } + + /* Releases the writer, if any. */ + void traceClose() + { + s_writer = nullptr; + } + + /* Returns the active writer, or nullptr when there is none. */ + iTraceWriter* getWriter() + { + return s_writer.get(); + } + + using Pair = std::pair<ItemName, ggml_tensor>; + static std::vector<Pair> delayed; + + /* Queues the name and a by-value copy of the ggml_tensor structure. NOTE(review): only the structure is copied; presumably the memory it refers to must stay valid until writeDelayedTensors() - verify against callers. */ + void delayTensor( const ItemName& name, const ggml_tensor* tensor ) + { + delayed.emplace_back( name, *tensor ); + } + + /* Writes every queued tensor and clears the queue. Returns S_FALSE when the queue is empty or no writer is open, S_OK otherwise. NOTE(review): the HRESULT of each w->tensor() call is discarded, so individual write failures are not reported - confirm this is intended. */ + HRESULT writeDelayedTensors() + { + if( delayed.empty() ) + return S_FALSE; + iTraceWriter* w = getWriter(); + if( nullptr == w ) + { + delayed.clear(); + return S_FALSE; + } + for( const Pair& p : delayed ) + w->tensor( p.first, p.second ); + delayed.clear(); + return S_OK; + } +} +#endif
\ No newline at end of file diff --git a/Whisper/Utils/Trace/tracing.h b/Whisper/Utils/Trace/tracing.h new file mode 100644 index 0000000..66a7ac4 --- /dev/null +++ b/Whisper/Utils/Trace/tracing.h @@ -0,0 +1,67 @@ +#pragma once +#include "TraceWriter.h" + +/* Free-function front end for debug tracing. When SAVE_DEBUG_TRACE is enabled each helper forwards to the writer returned by getWriter() and returns S_FALSE when there is no writer. Otherwise each helper is an inline stub that does nothing and returns S_FALSE wherever a HRESULT is expected. */ +namespace Tracing +{ +#if SAVE_DEBUG_TRACE + void traceCreate( LPCTSTR path ); + void traceClose(); + + iTraceWriter* getWriter(); + + /* tensor() overloads: forward to the writer's matching overload, or return S_FALSE when there is no writer. */ + inline HRESULT tensor( const ItemName& name, const DirectCompute::Tensor& tensor ) + { + iTraceWriter* w = getWriter(); + if( w ) + return w->tensor( name, tensor ); + return S_FALSE; + } + inline HRESULT tensor( const ItemName& name, const CpuCompute::Tensor& tensor ) + { + iTraceWriter* w = getWriter(); + if( w ) + return w->tensor( name, tensor ); + return S_FALSE; + } + + /* The pointer is dereferenced without a null check whenever a writer exists. */ + inline HRESULT tensor( const ItemName& name, const ggml_tensor* tensor ) + { + iTraceWriter* w = getWriter(); + if( w ) + return w->tensor( name, *tensor ); + return S_FALSE; + } + + void delayTensor( const ItemName& name, const ggml_tensor* tensor ); + HRESULT writeDelayedTensors(); + + /* Forwards a raw buffer of the given length and element type to the writer, or returns S_FALSE when there is no writer. */ + inline HRESULT buffer( const ItemName& name, const void* rsi, size_t length, eDataType dt ) + { + iTraceWriter* w = getWriter(); + if( w ) + return w->buffer( name, rsi, length, dt ); + return S_FALSE; + } + + /* Traces a float vector as an FP32 buffer; an empty vector is passed as a null pointer with length 0. */ + inline HRESULT vector( const ItemName& name, const std::vector<float>& vec ) + { + const float* rsi = vec.empty() ? nullptr : vec.data(); + return buffer( name, rsi, vec.size(), eDataType::FP32 ); + } + /* Traces length floats starting at rsi as an FP32 buffer. */ + inline HRESULT vector( const ItemName& name, const float* rsi, size_t length ) + { + return buffer( name, rsi, length, eDataType::FP32 ); + } +#else + /* Tracing disabled: no-op stubs with the same signatures as above. */ + inline void traceCreate( LPCTSTR path ) { } + inline void traceClose() { } + inline HRESULT tensor( const ItemName& name, const DirectCompute::Tensor& tensor ) { return S_FALSE; } + inline HRESULT tensor( const ItemName& name, const CpuCompute::Tensor& tensor ) { return S_FALSE; } + inline HRESULT tensor( const ItemName& name, const ggml_tensor* tensor ) { return S_FALSE; } + inline HRESULT buffer( const ItemName& name, const void* rsi, size_t length, eDataType dt ) { return S_FALSE; } + inline HRESULT vector( const ItemName& name, const std::vector<float>& vec ) { return S_FALSE; } + inline void delayTensor( const ItemName& name, const ggml_tensor* tensor ) { } + inline HRESULT writeDelayedTensors() { return S_FALSE; } + /* Returns S_FALSE like every other stub; a function declared to return HRESULT must not fall off its end. */ + inline HRESULT vector( const ItemName& name, const float* rsi, size_t length ) { return S_FALSE; } +#endif +}
\ No newline at end of file diff --git a/Whisper/Utils/miscUtils.cpp b/Whisper/Utils/miscUtils.cpp new file mode 100644 index 0000000..c3f7dd1 --- /dev/null +++ b/Whisper/Utils/miscUtils.cpp @@ -0,0 +1,33 @@ +#include "stdafx.h" +#include "miscUtils.h" + +/* Publishes a name for the calling thread by raising exception 0x406D1388 with a THREADNAME_INFO payload ( type 0x1000, the name pointer, the current thread ID, flags 0 ). The empty __except block swallows the exception, so the call has no other visible effect. The string is passed by pointer and not copied here. */ +void setCurrentThreadName( const char* threadName ) +{ + const DWORD dwThreadID = GetCurrentThreadId(); + + // https://stackoverflow.com/a/10364541/126995 +#pragma pack(push,8) + typedef struct tagTHREADNAME_INFO + { + DWORD dwType; // Must be 0x1000. + LPCSTR szName; // Pointer to name (in user addr space). + DWORD dwThreadID; // Thread ID (-1=caller thread). + DWORD dwFlags; // Reserved for future use, must be zero. + } THREADNAME_INFO; +#pragma pack(pop) + + THREADNAME_INFO info; + info.dwType = 0x1000; + info.szName = threadName; + info.dwThreadID = dwThreadID; + info.dwFlags = 0; + + constexpr DWORD MS_VC_EXCEPTION = 0x406D1388; + /* The argument count is the structure size expressed in ULONG_PTR units. */ + __try + { + RaiseException( MS_VC_EXCEPTION, 0, sizeof( info ) / sizeof( ULONG_PTR ), (ULONG_PTR*)&info ); + } + __except( EXCEPTION_EXECUTE_HANDLER ) + { + } +}
\ No newline at end of file diff --git a/Whisper/Utils/miscUtils.h b/Whisper/Utils/miscUtils.h new file mode 100644 index 0000000..d665cbc --- /dev/null +++ b/Whisper/Utils/miscUtils.h @@ -0,0 +1,81 @@ +#pragma once + +/* CHECK evaluates the expression once and returns its HRESULT from the enclosing function when it is a failure code. CHECK_LOG does the same but first calls logErrorHr with the stringized expression. */ +#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } +#define CHECK_LOG( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) { logErrorHr(__hr, u8"%s failed", #hr ); return __hr; } } + +/* Throws the HRESULT value itself when hr is a failure code; does nothing on success. */ +inline void check( HRESULT hr ) +{ + if( SUCCEEDED( hr ) ) + return; + throw hr; +} + +/* Unaligned 16-byte loads from four 32-bit integers. */ +inline __m128i load16( const int* rsi ) +{ + return _mm_loadu_si128( ( const __m128i* )rsi ); +} +inline __m128i load16( const uint32_t* rsi ) +{ + return _mm_loadu_si128( ( const __m128i* )rsi ); +} +inline __m128i load( const std::array<uint32_t, 4>& arr ) +{ + return load16( arr.data() ); +} +/* Unaligned 16-byte store. */ +inline void store16( void* rdi, __m128i v ) +{ + _mm_storeu_si128( ( __m128i* )rdi, v ); +} +/* Stores only the first 12 bytes of v: the low 8 bytes, then 32-bit lane 2. */ +inline void store12( void* rdi, __m128i v ) +{ + _mm_storel_epi64( ( __m128i* )rdi, v ); + ( (int*)rdi )[ 2 ] = _mm_extract_epi32( v, 2 ); +} +inline void store( std::array<uint32_t, 4>& arr, __m128i v ) +{ + store16( arr.data(), v ); +} +/* True when all 128 bits of a and b are equal: a XOR b is tested for all-zero. */ +inline bool vectorEqual( __m128i a, __m128i b ) +{ + __m128i xx = _mm_xor_si128( a, b ); + return (bool)_mm_testz_si128( xx, xx ); +} + +/* Build a vector of two 64-bit lanes from size_t values; a lane that is not supplied is zero. */ +inline __m128i setLow_size( size_t low ) +{ + return _mm_cvtsi64_si128( (int64_t)low ); +} +inline __m128i setr_size( size_t low, size_t high ) +{ + __m128i v = setLow_size( low ); + v = _mm_insert_epi64( v, (int64_t)high, 1 ); + return v; +} +inline __m128i setHigh_size( size_t high ) +{ + __m128i v = _mm_setzero_si128(); + v = _mm_insert_epi64( v, (int64_t)high, 1 ); + return v; +} + +void setCurrentThreadName( const char* name ); + +/* Converts the current GetLastError() code into an HRESULT. */ +inline HRESULT getLastHr() +{ + return HRESULT_FROM_WIN32( GetLastError() ); +} + +// Scale time in seconds from unsigned 64 bit rational number ( mul / div ) into 100-nanosecond ticks +// These 100-nanosecond ticks are used in NTFS, FILETIME, .NET standard 
library, media foundation, and quite a few other places +/* Returns ( mul * 10'000'000 + bias ) / div, where bias is ( div / 2 ) - 1 when div >= 2 and 0 when div == 1. The unsigned expression ( div / 2 ) - 1 wraps around to UINT64_MAX for div == 1, which would make the result one tick too small, so the bias is only applied when div / 2 is non-zero. div == 0 is not handled and divides by zero; mul * 10'000'000 is not checked for overflow. */ +inline uint64_t makeTime( uint64_t mul, uint64_t div ) +{ + mul *= 10'000'000; + const uint64_t half = div / 2; + if( half > 0 ) + mul += ( half - 1 ); + return mul / div; +} + +/* Bytes reserved by the vector's storage: element size times capacity ( not size ). */ +template<class E> +inline size_t vectorMemoryUse( const std::vector<E>& vec ) +{ + return sizeof( E ) * vec.capacity(); +}
\ No newline at end of file diff --git a/Whisper/Utils/parallelFor.cpp b/Whisper/Utils/parallelFor.cpp new file mode 100644 index 0000000..c2b324b --- /dev/null +++ b/Whisper/Utils/parallelFor.cpp @@ -0,0 +1,144 @@ +#include "stdafx.h" +#include "parallelFor.h" + +namespace +{ + /* Shared state for one Whisper::parallelFor() call: the next index to hand out, the first failure code reported by a pool callback, and the user callback with its context pointer. */ + class alignas( 64 ) ParallelForContext + { + volatile long threadIndex; + volatile HRESULT status; + + alignas( 64 ) void* const context; + const Whisper::pfnParallelForCallback pfn; + + static void __stdcall callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work ); + + public: + + ParallelForContext( void* ctx, Whisper::pfnParallelForCallback pfn ); + + PTP_WORK createWork(); + + HRESULT getStatus() const; + }; + + /* threadIndex starts at 1 because index 0 is run directly by parallelFor(); status starts at S_FALSE, meaning no failure recorded yet. */ + ParallelForContext::ParallelForContext( void* ctx, Whisper::pfnParallelForCallback callback ) : + threadIndex( 1 ), + status( S_FALSE ), + context( ctx ), + pfn( callback ) + { } + + /* Creates a thread pool work object bound to this context; the result is nullptr on failure. */ + PTP_WORK ParallelForContext::createWork() + { + return CreateThreadpoolWork( &callbackStatic, this, nullptr ); + } + + /* Pool callback: takes the next index ( incremented value minus one ), runs the user callback, and on failure stores the code only if status is still S_FALSE, so the first recorded failure is kept. */ + void __stdcall ParallelForContext::callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work ) + { + ParallelForContext& context = *(ParallelForContext*)pv; + int ith = InterlockedIncrement( &context.threadIndex ); + ith--; + const HRESULT hr = context.pfn( ith, context.context ); + if( SUCCEEDED( hr ) ) + return; + InterlockedCompareExchange( &context.status, hr, S_FALSE ); + } + + /* Maps any success status, including the initial S_FALSE, to S_OK; failure codes are returned unchanged. */ + HRESULT ParallelForContext::getStatus() const + { + const HRESULT hr = status; + if( SUCCEEDED( hr ) ) + return S_OK; + return hr; + } +} + +namespace Whisper +{ + /* Calls pfn once for every index in [ 0, threadsCount ): index 0 on the calling thread, the others through thread pool work items, and waits for all of them. Returns E_BOUNDS when threadsCount < 1, the error from CreateThreadpoolWork when the work object cannot be created, the failure of index 0 if it failed, otherwise the status recorded by the pool callbacks. With threadsCount == 1 the result of pfn( 0, ctx ) is returned as is. */ + HRESULT parallelFor( pfnParallelForCallback pfn, int threadsCount, void* ctx ) + { + if( threadsCount < 1 ) + return E_BOUNDS; + if( threadsCount == 1 ) + return pfn( 0, ctx ); + + ParallelForContext context{ ctx, pfn }; + + PTP_WORK const pw = context.createWork(); + if( nullptr == pw ) + return getLastHr(); + + for( int i = 1; i < threadsCount; i++ ) + SubmitThreadpoolWork( pw ); + + const HRESULT hr0 = pfn( 0, 
ctx ); + + /* The context lives on this stack frame, so all callbacks must finish before returning. */ + WaitForThreadpoolWorkCallbacks( pw, FALSE ); + CloseThreadpoolWork( pw ); + + if( FAILED( hr0 ) ) + return hr0; + return context.getStatus(); + } +} + +using namespace Whisper; + +/* Closes the cached work object, if one was created. */ +ThreadPoolWork::~ThreadPoolWork() +{ + if( nullptr != work ) + { + CloseThreadpoolWork( work ); + work = nullptr; + } +} + +/* Creates the work object once. Returns S_OK on success, the last Win32 error as HRESULT when creation fails, and ERROR_ALREADY_INITIALIZED as HRESULT when a work object already exists. */ +HRESULT ThreadPoolWork::create() +{ + if( nullptr == work ) + { + work = CreateThreadpoolWork( &callbackStatic, this, nullptr ); + if( nullptr != work ) + return S_OK; + return getLastHr(); + } + return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); +} + +/* Runs threadPoolCallback for indices [ 0, threadsCount ): index 0 on the calling thread, the rest on the pool, then waits for all of them. Returns OLE_E_BLANK when create() has not produced a work object. Unlike Whisper::parallelFor(), any threadsCount <= 1 simply runs index 0 on the calling thread. */ +HRESULT ThreadPoolWork::parallelFor( int threadsCount ) noexcept +{ + if( nullptr != work ) + { + if( threadsCount <= 1 ) + return threadPoolCallback( 0 ); + + /* Reset the shared counters before submitting work for this run. */ + threadIndex = 1; + status = S_FALSE; + for( int i = 1; i < threadsCount; i++ ) + SubmitThreadpoolWork( work ); + + const HRESULT hr0 = threadPoolCallback( 0 ); + + WaitForThreadpoolWorkCallbacks( work, FALSE ); + + if( FAILED( hr0 ) ) + return hr0; + if( SUCCEEDED( status ) ) + return S_OK; + return status; + } + + return OLE_E_BLANK; +} + +/* Pool callback: same index hand-out and first-failure recording as ParallelForContext::callbackStatic, but dispatching to the virtual threadPoolCallback. */ +void __stdcall ThreadPoolWork::callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work ) +{ + ThreadPoolWork* tpw = (ThreadPoolWork*)pv; + int ith = InterlockedIncrement( &tpw->threadIndex ); + ith--; + const HRESULT hr = tpw->threadPoolCallback( ith ); + if( SUCCEEDED( hr ) ) + return; + InterlockedCompareExchange( &tpw->status, hr, S_FALSE ); +}
\ No newline at end of file diff --git a/Whisper/Utils/parallelFor.h b/Whisper/Utils/parallelFor.h new file mode 100644 index 0000000..15cd603 --- /dev/null +++ b/Whisper/Utils/parallelFor.h @@ -0,0 +1,38 @@ +#pragma once + +namespace Whisper +{ + /* ith is the zero-based index of the invocation, ctx is the pointer passed to parallelFor(). */ + // A callback to offload to the thread pool + using pfnParallelForCallback = HRESULT( * )( int ith, void* ctx ) noexcept; + + // A simple parallel for implementation; Windows includes a decent thread pool since Vista (2006) + HRESULT parallelFor( pfnParallelForCallback pfn, int threadsCount, void* ctx ); + + // Use this version when you want to use the thread pool repeatedly, for the same work. + // This class caches native work handle, saving a couple of WinAPI calls. + class alignas( 64 ) ThreadPoolWork + { + PTP_WORK work = nullptr; + + // We want these volatile fields in another cache line from the rest of the data of this class. + // threadIndex field is concurrently modified by different CPU cores, and these cache coherency protocols are slow. + // OTOH, work and callback fields of this class only change when created / destroyed, that cache line is shared by CPU cores without any performance penalty. + alignas( 64 ) volatile long threadIndex = 0; + volatile HRESULT status = E_UNEXPECTED; + + static void __stdcall callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work ); + + protected: + /* Implemented by derived classes; called once per index by parallelFor(). */ + virtual HRESULT threadPoolCallback( int ith ) noexcept = 0; + + public: + ThreadPoolWork() = default; + ThreadPoolWork( const ThreadPoolWork& ) = delete; + + ~ThreadPoolWork(); + + /* Creates the native work object; must be called before parallelFor(). */ + HRESULT create(); + + /* Runs threadPoolCallback for each index below threadsCount and waits for completion. */ + HRESULT parallelFor( int threadsCount ) noexcept; + }; +}
\ No newline at end of file diff --git a/Whisper/Whisper.vcxproj b/Whisper/Whisper.vcxproj new file mode 100644 index 0000000..f270440 --- /dev/null +++ b/Whisper/Whisper.vcxproj @@ -0,0 +1,347 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup Label="ProjectConfigurations"> + <ProjectConfiguration Include="Debug|x64"> + <Configuration>Debug</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|x64"> + <Configuration>Release</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + </ItemGroup> + <PropertyGroup Label="Globals"> + <VCProjectVersion>16.0</VCProjectVersion> + <Keyword>Win32Proj</Keyword> + <ProjectGuid>{701df8c8-e4a5-43ec-9c6b-747bbf4d8e71}</ProjectGuid> + <RootNamespace>Whisper</RootNamespace> + <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration"> + <ConfigurationType>DynamicLibrary</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration"> + <ConfigurationType>DynamicLibrary</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" /> + <ImportGroup Label="ExtensionSettings"> + </ImportGroup> + <ImportGroup Label="Shared"> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <Import 
Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <PropertyGroup Label="UserMacros" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <IncludePath>$(ProjectDir);$(IncludePath)</IncludePath> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <IncludePath>$(ProjectDir);$(IncludePath)</IncludePath> + </PropertyGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>_DEBUG;WHISPER_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + <LanguageStandard>stdcpp20</LanguageStandard> + <PrecompiledHeader>Use</PrecompiledHeader> + <MultiProcessorCompilation>true</MultiProcessorCompilation> + </ClCompile> + <Link> + <SubSystem>Windows</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + <EnableUAC>false</EnableUAC> + <ModuleDefinitionFile>whisper.def</ModuleDefinitionFile> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>NDEBUG;WHISPER_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + 
<LanguageStandard>stdcpp20</LanguageStandard> + <PrecompiledHeader>Use</PrecompiledHeader> + <MultiProcessorCompilation>true</MultiProcessorCompilation> + <RuntimeLibrary>MultiThreaded</RuntimeLibrary> + </ClCompile> + <Link> + <SubSystem>Windows</SubSystem> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + <GenerateDebugInformation>true</GenerateDebugInformation> + <EnableUAC>false</EnableUAC> + <ModuleDefinitionFile>whisper.def</ModuleDefinitionFile> + <LinkTimeCodeGeneration>UseLinkTimeCodeGeneration</LinkTimeCodeGeneration> + </Link> + </ItemDefinitionGroup> + <ItemGroup> + <ProjectReference Include="..\ComLightLib\ComLightLib.vcxproj"> + <Project>{52f486e7-830c-45d8-be47-e76b5aab2772}</Project> + </ProjectReference> + </ItemGroup> + <ItemGroup> + <ClCompile Include="CPU\BufferAllocator.cpp" /> + <ClCompile Include="CPU\DecoderTensors.cpp" /> + <ClCompile Include="CPU\mulMatImpl.avx2.cpp"> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions2</EnableEnhancedInstructionSet> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions2</EnableEnhancedInstructionSet> + </ClCompile> + <ClCompile Include="CPU\mulMatImpl.panel.cpp"> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + </ClCompile> + <ClCompile Include="CPU\TensorCpu.cpp"> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + </ClCompile> + <ClCompile 
Include="CPU\HybridLoader.cpp" /> + <ClCompile Include="Hybrid\HybridContext.cpp" /> + <ClCompile Include="CPU\ParallelForRunner.cpp"> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + </ClCompile> + <ClCompile Include="CPU\LargeBuffer.cpp" /> + <ClCompile Include="CPU\simdUtils.cpp"> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + </ClCompile> + <ClCompile Include="CPU\mulMat.cpp"> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + </ClCompile> + <ClCompile Include="CPU\MlContextCpu.cpp"> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + </ClCompile> + <ClCompile Include="CPU\KvTensorsCpu.cpp" /> + <ClCompile Include="Hybrid\KeyValueDownloader.cpp" /> + <ClCompile Include="CPU\mulMatImpl.cpp"> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet> + </ClCompile> + <ClCompile 
Include="ML\Reshaper.cpp" /> + <ClCompile Include="Utils\Logger.cpp" /> + <ClCompile Include="MF\AudioCapture.cpp" /> + <ClCompile Include="Utils\miscUtils.cpp" /> + <ClCompile Include="Whisper\voiceActivityDetection.cpp" /> + <ClCompile Include="Whisper\ContextImpl.capture.cpp" /> + <ClCompile Include="Whisper\MelStreamer.cpp" /> + <ClCompile Include="Whisper\melSpectrogram.cpp" /> + <ClCompile Include="modelFactory.cpp" /> + <ClCompile Include="MF\AudioBuffer.cpp" /> + <ClCompile Include="MF\PcmReader.cpp" /> + <ClCompile Include="Utils\Trace\tracing.cpp" /> + <ClCompile Include="Utils\Trace\TraceStructures.cpp" /> + <ClCompile Include="Utils\Trace\TraceWriter.cpp" /> + <ClCompile Include="source.compat\convertThings.cpp" /> + <ClCompile Include="D3D\shaderNames.cpp" /> + <ClCompile Include="MF\mfUtils.cpp" /> + <ClCompile Include="MF\loadAudioFile.cpp" /> + <ClCompile Include="MF\MediaFoundation.cpp" /> + <ClCompile Include="MF\mfStartup.cpp" /> + <ClCompile Include="source.compat\ggmlMsvc.c"> + <PrecompiledHeader>NotUsing</PrecompiledHeader> + <EnableEnhancedInstructionSet>AdvancedVectorExtensions</EnableEnhancedInstructionSet> + </ClCompile> + <ClCompile Include="Whisper\ContextImpl.misc.cpp" /> + <ClCompile Include="Utils\ProfileCollection.cpp" /> + <ClCompile Include="Utils\CpuProfiler.cpp" /> + <ClCompile Include="D3D\enums.cpp" /> + <ClCompile Include="Utils\GpuProfiler.cpp" /> + <ClCompile Include="ML\TensorsArena.cpp" /> + <ClCompile Include="Whisper\Languages.cpp" /> + <ClCompile Include="Whisper\ContextImpl.cpp" /> + <ClCompile Include="Whisper\ModelImpl.cpp" /> + <ClCompile Include="Utils\parallelFor.cpp" /> + <ClCompile Include="Whisper\Spectrogram.cpp" /> + <ClCompile Include="Whisper\WhisperModel.cpp" /> + <ClCompile Include="Whisper\Vocabulary.cpp" /> + <ClCompile Include="Whisper\DecoderResultBuffer.cpp" /> + <ClCompile Include="Whisper\DecoderInputBuffers.cpp" /> + <ClCompile Include="ML\mlStartup.cpp" /> + <ClCompile 
Include="Whisper\KeyValueBuffers.cpp" /> + <ClCompile Include="D3D\Binder.cpp" /> + <ClCompile Include="ML\LookupTables.cpp" /> + <ClCompile Include="ML\LookupTablesData.cpp" /> + <ClCompile Include="ML\Context.ops.cpp" /> + <ClCompile Include="ML\MlContext.dbg.cpp" /> + <ClCompile Include="ML\MlContext.cpp" /> + <ClCompile Include="ML\ConstantBuffer.cpp" /> + <ClCompile Include="D3D\createBuffer.cpp" /> + <ClCompile Include="D3D\device.cpp" /> + <ClCompile Include="D3D\MappedResource.cpp" /> + <ClCompile Include="Whisper\WhisperContext.cpp" /> + <ClCompile Include="Whisper\ModelBuffers.cpp" /> + <ClCompile Include="D3D\shaders.cpp" /> + <ClCompile Include="ML\TensorShape.cpp" /> + <ClCompile Include="DllMain.cpp" /> + <ClCompile Include="D3D\downloadBuffer.cpp" /> + <ClCompile Include="D3D\RenderDoc\renderDoc.cpp" /> + <ClCompile Include="Whisper\MelInputTensor.cpp" /> + <ClCompile Include="source\ggml.c"> + <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild> + <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild> + </ClCompile> + <ClCompile Include="source\whisper.cpp"> + <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild> + <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild> + </ClCompile> + <ClCompile Include="D3D\startup.cpp" /> + <ClCompile Include="ML\TempBuffers.cpp" /> + <ClCompile Include="ML\tensorOpsTests.cpp" /> + <ClCompile Include="stdafx.cpp"> + <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader> + <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader> + </ClCompile> + <ClCompile Include="ML\testUtils.cpp" /> + <ClCompile Include="ML\Tensor.cpp" /> + <ClCompile Include="ML\TensorGpuViews.cpp" /> + <ClCompile Include="ML\TensorEx.cpp" /> + <ClCompile 
Include="whisperCom.cpp" /> + </ItemGroup> + <ItemGroup> + <ClInclude Include="API\iContext.h" /> + <ClInclude Include="API\iMediaFoundation.h" /> + <ClInclude Include="API\iTranscribeResult.h" /> + <ClInclude Include="API\loggerApi.h" /> + <ClInclude Include="API\MfStructs.h" /> + <ClInclude Include="API\iContext.cl.h" /> + <ClInclude Include="API\iMediaFoundation.cl.h" /> + <ClInclude Include="API\iTranscribeResult.cl.h" /> + <ClInclude Include="API\sLanguageList.h" /> + <ClInclude Include="API\sLoadModelCallbacks.h" /> + <ClInclude Include="API\SpecialTokens.h" /> + <ClInclude Include="API\sFullParams.h" /> + <ClInclude Include="API\whisperComLight.h" /> + <ClInclude Include="API\whisperWindows.h" /> + <ClInclude Include="CPU\BufferAllocator.h" /> + <ClInclude Include="CPU\mulMatUtils.hpp" /> + <ClInclude Include="CPU\Tensor.h" /> + <ClInclude Include="CPU\DecoderTensors.h" /> + <ClInclude Include="CPU\HybridLoader.h" /> + <ClInclude Include="Hybrid\HybridContext.h" /> + <ClInclude Include="CPU\ParallelForRunner.h" /> + <ClInclude Include="CPU\LargeBuffer.h" /> + <ClInclude Include="CPU\simdUtils.h" /> + <ClInclude Include="CPU\MlContext.h" /> + <ClInclude Include="CPU\KvTensors.h" /> + <ClInclude Include="Hybrid\KeyValueDownloader.h" /> + <ClInclude Include="ML\reshapedMultiply.h" /> + <ClInclude Include="ML\testUtilsC.h" /> + <ClInclude Include="CPU\mulMat.h" /> + <ClInclude Include="CPU\mulMatImpl.h" /> + <ClInclude Include="ML\Reshaper.h" /> + <ClInclude Include="Utils\Logger.h" /> + <ClInclude Include="MF\AudioCapture.h" /> + <ClInclude Include="Whisper\sModelParams.h" /> + <ClInclude Include="Whisper\voiceActivityDetection.h" /> + <ClInclude Include="Whisper\MelStreamer.h" /> + <ClInclude Include="Whisper\melSpectrogram.h" /> + <ClInclude Include="modelFactory.h" /> + <ClInclude Include="MF\AudioBuffer.h" /> + <ClInclude Include="MF\PcmReader.h" /> + <ClInclude Include="Utils\miscUtils.h" /> + <ClInclude Include="Utils\Trace\tracing.h" /> + <ClInclude 
Include="Utils\Trace\TraceStructures.h" /> + <ClInclude Include="Utils\Trace\TraceWriter.h" /> + <ClInclude Include="source.compat\convertThings.h" /> + <ClInclude Include="MF\mfUtils.h" /> + <ClInclude Include="MF\loadAudioFile.h" /> + <ClInclude Include="MF\mfStartup.h" /> + <ClInclude Include="API\TranscribeStructs.h" /> + <ClInclude Include="resource.h" /> + <ClInclude Include="Utils\ReadStream.h" /> + <ClInclude Include="Whisper\audioConstants.h" /> + <ClInclude Include="Whisper\iSpectrogram.h" /> + <ClInclude Include="Whisper\sTokenData.h" /> + <ClInclude Include="Whisper\TranscribeResult.h" /> + <ClInclude Include="Utils\ProfileCollection.h" /> + <ClInclude Include="Utils\CpuProfiler.h" /> + <ClInclude Include="Utils\GpuProfiler.h" /> + <ClInclude Include="ML\TensorsArena.h" /> + <ClInclude Include="Utils\GpuProfilerSimple.h" /> + <ClInclude Include="Whisper\Languages.h" /> + <ClInclude Include="Whisper\ContextImpl.h" /> + <ClInclude Include="Whisper\ModelImpl.h" /> + <ClInclude Include="Utils\parallelFor.h" /> + <ClInclude Include="Whisper\Spectrogram.h" /> + <ClInclude Include="Whisper\loaderUtils.h" /> + <ClInclude Include="Whisper\WhisperModel.h" /> + <ClInclude Include="Whisper\Vocabulary.h" /> + <ClInclude Include="Whisper\DecoderResultBuffer.h" /> + <ClInclude Include="Whisper\DecoderInputBuffers.h" /> + <ClInclude Include="ML\mlStartup.h" /> + <ClInclude Include="Whisper\KeyValueBuffers.h" /> + <ClInclude Include="D3D\Binder.h" /> + <ClInclude Include="ML\LookupTables.h" /> + <ClInclude Include="ML\LookupTablesData.h" /> + <ClInclude Include="ML\MlContext.h" /> + <ClInclude Include="ML\ConstantBuffer.h" /> + <ClInclude Include="D3D\device.h" /> + <ClInclude Include="D3D\createBuffer.h" /> + <ClInclude Include="D3D\enums.h" /> + <ClInclude Include="D3D\MappedResource.h" /> + <ClInclude Include="Whisper\sEncodeParams.h" /> + <ClInclude Include="Whisper\WhisperContext.h" /> + <ClInclude Include="Whisper\ModelBuffers.h" /> + <ClInclude 
Include="Whisper\ModelLoader.h" /> + <ClInclude Include="D3D\RenderDoc\renderdoc_app.h" /> + <ClInclude Include="D3D\shaderNames.h" /> + <ClInclude Include="D3D\shaders.h" /> + <ClInclude Include="D3D\downloadBuffer.h" /> + <ClInclude Include="D3D\RenderDoc\renderDoc.h" /> + <ClInclude Include="Whisper\MelInputTensor.h" /> + <ClInclude Include="ML\TensorShape.h" /> + <ClInclude Include="source\ggml.h" /> + <ClInclude Include="source\whisper.h" /> + <ClInclude Include="D3D\startup.h" /> + <ClInclude Include="ML\TempBuffers.h" /> + <ClInclude Include="ML\tensorOpsTests.h" /> + <ClInclude Include="stdafx.h" /> + <ClInclude Include="ML\testUtils.h" /> + <ClInclude Include="ML\Tensor.h" /> + <ClInclude Include="ML\TensorGpuViews.h" /> + <ClInclude Include="ML\TensorEx.h" /> + </ItemGroup> + <ItemGroup> + <None Include="D3D\shaderData-Debug.inl" /> + <None Include="D3D\shaderData-Release.inl" /> + <None Include="CPU\mulMat.kernel.hpp" /> + <None Include="source\LICENSE" /> + <None Include="whisper.def" /> + <None Include="Whisper\languageCodez.inl" /> + <None Include="Whisper\languageCodez.tsv" /> + </ItemGroup> + <ItemGroup> + <Natvis Include="misc.natvis" /> + </ItemGroup> + <ItemGroup> + <ResourceCompile Include="Resource.rc" /> + </ItemGroup> + <ItemGroup> + <Text Include="API\Readme.txt" /> + <Text Include="CPU\Readme.txt" /> + <Text Include="Hybrid\Readme.txt" /> + <Text Include="Readme.txt" /> + <Text Include="source.compat\Readme.txt" /> + <Text Include="source\Readme.txt" /> + </ItemGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> + <ImportGroup Label="ExtensionTargets"> + </ImportGroup> +</Project>
\ No newline at end of file diff --git a/Whisper/Whisper.vcxproj.filters b/Whisper/Whisper.vcxproj.filters new file mode 100644 index 0000000..193fbe7 --- /dev/null +++ b/Whisper/Whisper.vcxproj.filters @@ -0,0 +1,214 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup> + <ClCompile Include="DllMain.cpp" /> + <ClCompile Include="whisperCom.cpp" /> + <ClCompile Include="source\ggml.c" /> + <ClCompile Include="source\whisper.cpp" /> + <ClCompile Include="D3D\device.cpp" /> + <ClCompile Include="ML\ConstantBuffer.cpp" /> + <ClCompile Include="D3D\MappedResource.cpp" /> + <ClCompile Include="D3D\shaders.cpp" /> + <ClCompile Include="D3D\startup.cpp" /> + <ClCompile Include="D3D\createBuffer.cpp" /> + <ClCompile Include="ML\TempBuffers.cpp" /> + <ClCompile Include="ML\MlContext.cpp" /> + <ClCompile Include="ML\tensorOpsTests.cpp" /> + <ClCompile Include="D3D\Binder.cpp" /> + <ClCompile Include="stdafx.cpp" /> + <ClCompile Include="ML\testUtils.cpp" /> + <ClCompile Include="D3D\downloadBuffer.cpp" /> + <ClCompile Include="D3D\RenderDoc\renderDoc.cpp" /> + <ClCompile Include="Whisper\MelInputTensor.cpp" /> + <ClCompile Include="Whisper\ModelBuffers.cpp" /> + <ClCompile Include="ML\Context.ops.cpp" /> + <ClCompile Include="ML\TensorShape.cpp" /> + <ClCompile Include="ML\Tensor.cpp" /> + <ClCompile Include="ML\TensorGpuViews.cpp" /> + <ClCompile Include="ML\TensorEx.cpp" /> + <ClCompile Include="Whisper\WhisperContext.cpp" /> + <ClCompile Include="ML\MlContext.dbg.cpp" /> + <ClCompile Include="ML\LookupTablesData.cpp" /> + <ClCompile Include="ML\LookupTables.cpp" /> + <ClCompile Include="Whisper\KeyValueBuffers.cpp" /> + <ClCompile Include="ML\mlStartup.cpp" /> + <ClCompile Include="Whisper\DecoderInputBuffers.cpp" /> + <ClCompile Include="Whisper\DecoderResultBuffer.cpp" /> + <ClCompile Include="Whisper\Vocabulary.cpp" /> + <ClCompile Include="Whisper\WhisperModel.cpp" /> + 
<ClCompile Include="Whisper\Spectrogram.cpp" /> + <ClCompile Include="Utils\parallelFor.cpp" /> + <ClCompile Include="Whisper\ModelImpl.cpp" /> + <ClCompile Include="Whisper\ContextImpl.cpp" /> + <ClCompile Include="Whisper\Languages.cpp" /> + <ClCompile Include="ML\TensorsArena.cpp" /> + <ClCompile Include="D3D\enums.cpp" /> + <ClCompile Include="Utils\GpuProfiler.cpp" /> + <ClCompile Include="Utils\CpuProfiler.cpp" /> + <ClCompile Include="Utils\ProfileCollection.cpp" /> + <ClCompile Include="D3D\shaderNames.cpp" /> + <ClCompile Include="MF\mfStartup.cpp" /> + <ClCompile Include="MF\MediaFoundation.cpp" /> + <ClCompile Include="MF\loadAudioFile.cpp" /> + <ClCompile Include="MF\mfUtils.cpp" /> + <ClCompile Include="source.compat\convertThings.cpp" /> + <ClCompile Include="source.compat\ggmlMsvc.c" /> + <ClCompile Include="Whisper\ContextImpl.misc.cpp" /> + <ClCompile Include="Utils\Trace\TraceWriter.cpp" /> + <ClCompile Include="Utils\Trace\TraceStructures.cpp" /> + <ClCompile Include="Utils\Trace\tracing.cpp" /> + <ClCompile Include="MF\AudioBuffer.cpp" /> + <ClCompile Include="modelFactory.cpp" /> + <ClCompile Include="MF\PcmReader.cpp" /> + <ClCompile Include="Whisper\melSpectrogram.cpp" /> + <ClCompile Include="Whisper\MelStreamer.cpp" /> + <ClCompile Include="Utils\miscUtils.cpp" /> + <ClCompile Include="MF\AudioCapture.cpp" /> + <ClCompile Include="Utils\Logger.cpp" /> + <ClCompile Include="Whisper\ContextImpl.capture.cpp" /> + <ClCompile Include="Whisper\voiceActivityDetection.cpp" /> + <ClCompile Include="CPU\LargeBuffer.cpp" /> + <ClCompile Include="CPU\ParallelForRunner.cpp" /> + <ClCompile Include="CPU\simdUtils.cpp" /> + <ClCompile Include="CPU\mulMat.cpp" /> + <ClCompile Include="CPU\TensorCpu.cpp" /> + <ClCompile Include="CPU\MlContextCpu.cpp" /> + <ClCompile Include="CPU\BufferAllocator.cpp" /> + <ClCompile Include="CPU\HybridLoader.cpp" /> + <ClCompile Include="CPU\DecoderTensors.cpp" /> + <ClCompile Include="Hybrid\HybridContext.cpp" /> + 
<ClCompile Include="CPU\KvTensorsCpu.cpp" /> + <ClCompile Include="Hybrid\KeyValueDownloader.cpp" /> + <ClCompile Include="CPU\mulMatImpl.cpp" /> + <ClCompile Include="CPU\mulMatImpl.avx2.cpp" /> + <ClCompile Include="CPU\mulMatImpl.panel.cpp" /> + <ClCompile Include="ML\Reshaper.cpp" /> + </ItemGroup> + <ItemGroup> + <ClInclude Include="source\ggml.h" /> + <ClInclude Include="source\whisper.h" /> + <ClInclude Include="API\iContext.cl.h" /> + <ClInclude Include="API\sFullParams.h" /> + <ClInclude Include="D3D\device.h" /> + <ClInclude Include="ML\ConstantBuffer.h" /> + <ClInclude Include="D3D\MappedResource.h" /> + <ClInclude Include="D3D\shaderNames.h" /> + <ClInclude Include="D3D\shaders.h" /> + <ClInclude Include="D3D\startup.h" /> + <ClInclude Include="D3D\createBuffer.h" /> + <ClInclude Include="ML\TempBuffers.h" /> + <ClInclude Include="ML\MlContext.h" /> + <ClInclude Include="ML\tensorOpsTests.h" /> + <ClInclude Include="D3D\Binder.h" /> + <ClInclude Include="stdafx.h" /> + <ClInclude Include="ML\testUtils.h" /> + <ClInclude Include="D3D\downloadBuffer.h" /> + <ClInclude Include="D3D\RenderDoc\renderdoc_app.h" /> + <ClInclude Include="D3D\RenderDoc\renderDoc.h" /> + <ClInclude Include="Whisper\MelInputTensor.h" /> + <ClInclude Include="Whisper\ModelBuffers.h" /> + <ClInclude Include="Whisper\ModelLoader.h" /> + <ClInclude Include="ML\TensorShape.h" /> + <ClInclude Include="ML\Tensor.h" /> + <ClInclude Include="ML\TensorGpuViews.h" /> + <ClInclude Include="D3D\enums.h" /> + <ClInclude Include="ML\TensorEx.h" /> + <ClInclude Include="Whisper\WhisperContext.h" /> + <ClInclude Include="ML\LookupTablesData.h" /> + <ClInclude Include="ML\LookupTables.h" /> + <ClInclude Include="Whisper\sEncodeParams.h" /> + <ClInclude Include="Whisper\KeyValueBuffers.h" /> + <ClInclude Include="ML\mlStartup.h" /> + <ClInclude Include="Whisper\DecoderInputBuffers.h" /> + <ClInclude Include="Whisper\DecoderResultBuffer.h" /> + <ClInclude Include="Whisper\Vocabulary.h" /> + 
<ClInclude Include="Whisper\WhisperModel.h" /> + <ClInclude Include="Whisper\loaderUtils.h" /> + <ClInclude Include="Whisper\Spectrogram.h" /> + <ClInclude Include="Utils\parallelFor.h" /> + <ClInclude Include="Whisper\ModelImpl.h" /> + <ClInclude Include="Whisper\ContextImpl.h" /> + <ClInclude Include="Whisper\Languages.h" /> + <ClInclude Include="ML\TensorsArena.h" /> + <ClInclude Include="Utils\GpuProfiler.h" /> + <ClInclude Include="Utils\GpuProfilerSimple.h" /> + <ClInclude Include="Utils\CpuProfiler.h" /> + <ClInclude Include="Utils\ProfileCollection.h" /> + <ClInclude Include="MF\mfStartup.h" /> + <ClInclude Include="API\iMediaFoundation.cl.h" /> + <ClInclude Include="MF\loadAudioFile.h" /> + <ClInclude Include="MF\mfUtils.h" /> + <ClInclude Include="API\TranscribeStructs.h" /> + <ClInclude Include="API\iTranscribeResult.cl.h" /> + <ClInclude Include="Whisper\TranscribeResult.h" /> + <ClInclude Include="Utils\ReadStream.h" /> + <ClInclude Include="API\SpecialTokens.h" /> + <ClInclude Include="Whisper\sTokenData.h" /> + <ClInclude Include="resource.h" /> + <ClInclude Include="source.compat\convertThings.h" /> + <ClInclude Include="Utils\Trace\TraceWriter.h" /> + <ClInclude Include="Utils\Trace\TraceStructures.h" /> + <ClInclude Include="Utils\Trace\tracing.h" /> + <ClInclude Include="Utils\miscUtils.h" /> + <ClInclude Include="MF\AudioBuffer.h" /> + <ClInclude Include="modelFactory.h" /> + <ClInclude Include="Whisper\iSpectrogram.h" /> + <ClInclude Include="MF\PcmReader.h" /> + <ClInclude Include="Whisper\audioConstants.h" /> + <ClInclude Include="Whisper\melSpectrogram.h" /> + <ClInclude Include="Whisper\MelStreamer.h" /> + <ClInclude Include="API\MfStructs.h" /> + <ClInclude Include="MF\AudioCapture.h" /> + <ClInclude Include="API\loggerApi.h" /> + <ClInclude Include="Utils\Logger.h" /> + <ClInclude Include="Whisper\voiceActivityDetection.h" /> + <ClInclude Include="CPU\LargeBuffer.h" /> + <ClInclude Include="API\iContext.h" /> + <ClInclude 
Include="API\iMediaFoundation.h" /> + <ClInclude Include="API\iTranscribeResult.h" /> + <ClInclude Include="API\whisperComLight.h" /> + <ClInclude Include="API\whisperWindows.h" /> + <ClInclude Include="API\sLanguageList.h" /> + <ClInclude Include="CPU\ParallelForRunner.h" /> + <ClInclude Include="CPU\simdUtils.h" /> + <ClInclude Include="ML\testUtilsC.h" /> + <ClInclude Include="CPU\mulMat.h" /> + <ClInclude Include="CPU\Tensor.h" /> + <ClInclude Include="CPU\MlContext.h" /> + <ClInclude Include="CPU\BufferAllocator.h" /> + <ClInclude Include="CPU\DecoderTensors.h" /> + <ClInclude Include="CPU\HybridLoader.h" /> + <ClInclude Include="Whisper\sModelParams.h" /> + <ClInclude Include="Hybrid\HybridContext.h" /> + <ClInclude Include="CPU\KvTensors.h" /> + <ClInclude Include="Hybrid\KeyValueDownloader.h" /> + <ClInclude Include="CPU\mulMatUtils.hpp" /> + <ClInclude Include="CPU\mulMatImpl.h" /> + <ClInclude Include="API\sLoadModelCallbacks.h" /> + <ClInclude Include="ML\Reshaper.h" /> + <ClInclude Include="ML\reshapedMultiply.h" /> + </ItemGroup> + <ItemGroup> + <None Include="whisper.def" /> + <None Include="D3D\shaderData-Debug.inl" /> + <None Include="D3D\shaderData-Release.inl" /> + <None Include="Whisper\languageCodez.inl" /> + <None Include="Whisper\languageCodez.tsv" /> + <None Include="CPU\mulMat.kernel.hpp" /> + <None Include="source\LICENSE" /> + </ItemGroup> + <ItemGroup> + <Natvis Include="misc.natvis" /> + </ItemGroup> + <ItemGroup> + <ResourceCompile Include="Resource.rc" /> + </ItemGroup> + <ItemGroup> + <Text Include="Readme.txt" /> + <Text Include="API\Readme.txt" /> + <Text Include="source.compat\Readme.txt" /> + <Text Include="source\Readme.txt" /> + <Text Include="Hybrid\Readme.txt" /> + <Text Include="CPU\Readme.txt" /> + </ItemGroup> +</Project>
\ No newline at end of file diff --git a/Whisper/Whisper/ContextImpl.capture.cpp b/Whisper/Whisper/ContextImpl.capture.cpp new file mode 100644 index 0000000..86dc0d2 --- /dev/null +++ b/Whisper/Whisper/ContextImpl.capture.cpp @@ -0,0 +1,418 @@ +#include "stdafx.h" +#include "ContextImpl.h" +#include "../API/iMediaFoundation.cl.h" +#include "../MF/AudioBuffer.h" +#include "../MF/mfUtils.h" +#include <mfidl.h> +#include <mfapi.h> +#include <mfreadwrite.h> +#include "voiceActivityDetection.h" + +namespace +{ + using namespace Whisper; + + class TranscribeBuffer : public ComLight::ObjectRoot<iAudioBuffer> + { + // ==== iAudioBuffer ==== + uint32_t COMLIGHTCALL countSamples() const override final + { + return (uint32_t)pcm.mono.size(); + } + const float* COMLIGHTCALL getPcmMono() const override final + { + if( !pcm.mono.empty() ) + return pcm.mono.data(); + return nullptr; + } + const float* COMLIGHTCALL getPcmStereo() const override final + { + if( !pcm.stereo.empty() ) + return pcm.stereo.data(); + return nullptr; + } + HRESULT COMLIGHTCALL getTime( int64_t& rdi ) const override final + { + rdi = MFllMulDiv( currentOffset, 10'000'000, SAMPLE_RATE, 0 ); + return S_OK; + } + public: + AudioBuffer pcm; + int64_t currentOffset = 0; + }; + + class TranscribeBufferObj : public ComLight::Object<TranscribeBuffer> + { + uint32_t Release() override final + { + return RefCounter::implRelease(); + } + }; + + struct CaptureParams + { + uint32_t minDuration, maxDuration, dropStartSilence, pauseDuration; + uint32_t flags; + + CaptureParams( const sCaptureParams& cp ) + { + // Convert these floats from seconds to samples + __m128 floats = _mm_loadu_ps( &cp.minDuration ); + floats = _mm_mul_ps( floats, _mm_set1_ps( (float)SAMPLE_RATE ) ); + floats = _mm_round_ps( floats, _MM_FROUND_NINT ); + __m128i ints = _mm_cvtps_epi32( floats ); + store16( &minDuration, ints ); + + flags = cp.flags; + } + }; + + class Capture + { + CComPtr<IMFSourceReader> reader; + const CaptureParams 
captureParams; + const sCaptureCallbacks callbacks; + // Count of channels delivered from the source reader + uint8_t readerChannels = 0; + volatile char stateFlags = 0; + + PTP_WORK work = nullptr; + volatile HRESULT workStatus = S_OK; + + TranscribeBufferObj buffer; + CComAutoCriticalSection critSec; + AudioBuffer pcm; + AudioBuffer::pfnAppendSamples pfnAppendSamples = nullptr; + int64_t pcmStartTime = 0; + int64_t nextSampleTime = 0; + VAD vad; + sFullParams fullParams; + ProfileCollection& profiler; + iContext* const whisperContext; + + HRESULT setStateFlag( eCaptureStatus newBit ) noexcept + { + const uint8_t bit = (uint8_t)newBit; + const uint8_t oldVal = (uint8_t)InterlockedOr8( &stateFlags, (char)bit ); + if( nullptr == callbacks.captureStatus ) + return S_OK; // no callbacks + if( 0 != ( oldVal & bit ) ) + return S_OK; // The bit was already set + return callbacks.captureStatus( callbacks.pv, (eCaptureStatus)( oldVal | bit ) ); + } + + HRESULT clearStateFlag( eCaptureStatus clearBit ) noexcept + { + const uint8_t bit = (uint8_t)clearBit; + const uint8_t mask = ~bit; + const uint8_t oldVal = (uint8_t)InterlockedAnd8( &stateFlags, (char)mask ); + if( nullptr == callbacks.captureStatus ) + return S_OK; // no callbacks + if( 0 == ( oldVal & bit ) ) + return S_OK; // The bit wasn't there + return callbacks.captureStatus( callbacks.pv, (eCaptureStatus)( oldVal & mask ) ); + } + + bool hasStateFlag( eCaptureStatus testBit ) const + { + const uint8_t bit = (uint8_t)testBit; + return 0 != ( (uint8_t)stateFlags & bit ); + } + + HRESULT workCallback(); + static void __stdcall callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work ); + + HRESULT readSample( bool discard ); + + // Run voice detection on the data in pcm.mono vector. + // When not detected, return 0. When detected, return last frame index where it is detected. 
+ size_t detectVoice(); + + HRESULT postPoolWork() + { + assert( workStatus == S_OK ); + CHECK( setStateFlag( eCaptureStatus::Transcribing ) ); + + workStatus = S_FALSE; + buffer.currentOffset = pcmStartTime; + buffer.pcm.mono = pcm.mono; + buffer.pcm.stereo = pcm.stereo; + SubmitThreadpoolWork( work ); + pcmStartTime = nextSampleTime; + pcm.clear(); + vad.clear(); + return S_OK; + } + + public: + Capture( const sCaptureCallbacks& cb, const iAudioCapture* ac, const sFullParams& sfp, iContext* wc, ProfileCollection& pc ) : + callbacks( cb ), + captureParams( ac->getParams() ), + fullParams( sfp ), whisperContext( wc ), profiler( pc ) + { + } + + ~Capture() + { + if( workStatus == S_FALSE && nullptr != work ) + WaitForThreadpoolWorkCallbacks( work, FALSE ); + + if( nullptr != work ) + { + CloseThreadpoolWork( work ); + work = nullptr; + } + } + + HRESULT startup( const iAudioCapture* ac ); + + HRESULT checkCancel() noexcept + { + if( nullptr == callbacks.shouldCancel ) + return S_FALSE; + return callbacks.shouldCancel( callbacks.pv ); + } + + HRESULT run(); + }; + + HRESULT Capture::startup( const iAudioCapture* ac ) + { + // Initialize the MF source reader + CHECK( ac->getReader( &reader ) ); + work = CreateThreadpoolWork( &callbackStatic, this, nullptr ); + if( nullptr == work ) + return HRESULT_FROM_WIN32( GetLastError() ); + + // Set up media type, and figure out sample handler + CHECK( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) ); + CHECK( reader->SetStreamSelection( MF_SOURCE_READER_FIRST_AUDIO_STREAM, TRUE ) ); + + CComPtr<IMFMediaType> mtNative; + CHECK( reader->GetNativeMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, MF_SOURCE_READER_CURRENT_TYPE_INDEX, &mtNative ) ); + UINT32 numChannels; + CHECK( mtNative->GetUINT32( MF_MT_AUDIO_NUM_CHANNELS, &numChannels ) ); + + const bool sourceMono = numChannels < 2; + const bool wantStereo = 0 != ( captureParams.flags & (uint32_t)eCaptureFlags::Stereo ); + pfnAppendSamples = 
AudioBuffer::appendSamplesFunc( sourceMono, wantStereo ); + + CComPtr<IMFMediaType> mt; + this->readerChannels = ( !sourceMono && wantStereo ) ? 2 : 1; + CHECK( createMediaType( !sourceMono, &mt ) ); + CHECK( reader->SetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, nullptr, mt ) ); + + CHECK( setStateFlag( eCaptureStatus::Listening ) ); + return S_OK; + } + + HRESULT Capture::run() + { + HRESULT hr; + if( hasStateFlag( eCaptureStatus::Stalled ) ) + { + hr = workStatus; + CHECK( hr ); + if( S_OK != hr ) + { + // Still stalled, discard the upcoming sample + return readSample( true ); + } + else + { + // The postponed task has completed by now, no longer stalled + // Move the current PCM buffer to the transcribe thread + CHECK( clearStateFlag( eCaptureStatus::Stalled ) ); + return postPoolWork(); + } + } + + const size_t oldSamples = pcm.mono.size(); + CHECK( readSample( false ) ); + const size_t newSamples = pcm.mono.size(); + + const size_t lastVoiceFrame = detectVoice(); + if( lastVoiceFrame == 0 ) + { + // No voice is detected in the entire buffered audio + clearStateFlag( eCaptureStatus::Voice ); + if( newSamples < captureParams.dropStartSilence ) + return S_OK; + + pcm.clear(); + vad.clear(); + pcmStartTime = nextSampleTime; + return S_OK; + } + + const bool newFrameVoice = lastVoiceFrame + captureParams.pauseDuration >= oldSamples; + + if( newFrameVoice ) + { + // A voice is detected in the buffer, and it was fairly recently + setStateFlag( eCaptureStatus::Voice ); + if( newSamples < captureParams.maxDuration ) + return S_OK; // While voice is continuously detected, we allow to grow the buffer up to `maxDuration` time + } + else + { + // A voice is detected in the buffer, but it was a while ago + clearStateFlag( eCaptureStatus::Voice ); + if( newSamples < captureParams.minDuration ) + return S_OK; // When detected pause in the voice, we fire the transcribe task right away. + } + + // Hopefully, we have enough captured PCM data to run the ASR model. 
+ // Check the background task status first. + hr = workStatus; + CHECK( hr ); + if( hr == S_OK ) + return postPoolWork(); + + // workStatus = S_FALSE means the previous task has not finished yet. + // We don't want concurrency here because it's not implemented, and will simply crash. + // The "Stalled" flag which will cause capture to drop further samples. + setStateFlag( eCaptureStatus::Stalled ); + return S_OK; + } + + HRESULT Capture::readSample( bool discard ) + { + while( true ) + { + DWORD dwFlags = 0; + CComPtr<IMFSample> sample; + + // Read the next sample + HRESULT hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample ); + if( FAILED( hr ) ) + { + logErrorHr( hr, u8"IMFSourceReader.ReadSample" ); + return hr; + } + + if( dwFlags & MF_SOURCE_READERF_CURRENTMEDIATYPECHANGED ) + { + logError( u8"Media type changes ain’t supported by the library." ); + return E_UNEXPECTED; + } + + if( dwFlags & MF_SOURCE_READERF_ENDOFSTREAM ) + return E_EOF; + + if( !sample ) + continue; + + // Get a pointer to the audio data in the sample. 
+ CComPtr<IMFMediaBuffer> buffer; + hr = sample->ConvertToContiguousBuffer( &buffer ); + if( FAILED( hr ) ) + return hr; + + const float* pAudioData = nullptr; + DWORD cbBuffer; + hr = buffer->Lock( (BYTE**)&pAudioData, nullptr, &cbBuffer ); + if( FAILED( hr ) ) + return hr; + + try + { + assert( 0 == ( cbBuffer % sizeof( float ) ) ); + const size_t countFloats = cbBuffer / sizeof( float ); + if( !discard ) + { + const size_t prevSize = pcm.mono.size(); + ( pcm.*pfnAppendSamples )( pAudioData, countFloats ); + const size_t newSize = pcm.mono.size(); + this->nextSampleTime += ( newSize - prevSize ); + } + else + { + this->nextSampleTime += countFloats / readerChannels; + } + } + catch( const std::bad_alloc& ) + { + buffer->Unlock(); + return E_OUTOFMEMORY; + } + + // Unlock the buffer + hr = buffer->Unlock(); + if( FAILED( hr ) ) + return hr; + + return S_OK; + } + } + + HRESULT Capture::workCallback() + { + CHECK( whisperContext->runFull( fullParams, &buffer ) ); + CHECK( clearStateFlag( eCaptureStatus::Transcribing ) ); + return S_OK; + } + + void __stdcall Capture::callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work ) + { + Capture* pThis = (Capture*)pv; + HRESULT status = E_UNEXPECTED; + try + { + status = pThis->workCallback(); + } + catch( HRESULT hr ) + { + status = hr; + } + catch( const std::bad_alloc& ) + { + status = E_OUTOFMEMORY; + } + catch( const std::exception& ) + { + status = E_FAIL; + } + assert( S_OK == status || FAILED( status ) ); + pThis->workStatus = status; + } + + size_t Capture::detectVoice() + { + auto pf = profiler.cpuBlock( eCpuBlock::VAD ); + return vad.detect( pcm.mono.data(), pcm.mono.size() ); + } +} + +HRESULT COMLIGHTCALL ContextImpl::runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ) +{ + if( nullptr == reader ) + return E_POINTER; + + // Validate a few things + { + const auto& cp = reader->getParams(); + if( cp.minDuration < 0.125f || cp.minDuration > 
30.0f ) + { + logError( u8"%s parameter %g is out of range", "minDuration", cp.minDuration ); + return E_INVALIDARG; + } + if( cp.maxDuration < 0.125f || cp.maxDuration > 30.0f ) + { + logError( u8"%s parameter %g is out of range", "maxDuration", cp.maxDuration ); + return E_INVALIDARG; + } + } + + Capture capture{ callbacks, reader, params, this, profiler }; + CHECK( capture.startup( reader ) ); + + while( true ) + { + HRESULT hr = capture.checkCancel(); + CHECK( hr ); + if( hr == S_OK ) + return S_OK; + CHECK( capture.run() ); + } +}
\ No newline at end of file diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp new file mode 100644 index 0000000..a8e16f5 --- /dev/null +++ b/Whisper/Whisper/ContextImpl.cpp @@ -0,0 +1,528 @@ +#include "stdafx.h" +#include "ContextImpl.h" +#include "Languages.h" +#include "../Utils/Trace/tracing.h" +using namespace Whisper; + +ContextImpl::ContextImpl( const WhisperModel& modelData, iModel* modelPointer ) : + model( modelData ), + modelPtr( modelPointer ), + context( modelData, profiler ), + profiler( modelData ) +{ } + +#define WHISPER_CHUNK_SIZE 30 + +HRESULT ContextImpl::encode( iSpectrogram& mel, int seek ) +{ + // whisper_encode + using namespace DirectCompute; + + sEncodeParams ep; + ep.n_ctx = ( exp_n_audio_ctx > 0 ) ? exp_n_audio_ctx : model.parameters.n_audio_ctx; + ep.n_mels = model.parameters.n_mels; + ep.mel_offset = seek; + ep.layersCount = model.parameters.n_audio_layer; + ep.n_state = model.parameters.n_audio_state; + ep.n_head = model.parameters.n_audio_head; + ep.n_audio_ctx = model.parameters.n_audio_ctx; + ep.n_text_state = model.parameters.n_text_state; + ep.n_text_layer = model.parameters.n_text_layer; + ep.n_text_ctx = model.parameters.n_text_ctx; + try + { + auto cur = context.encode( mel, ep ); + Tracing::tensor( "encode-out", cur ); + return S_OK; + } + catch( HRESULT hr ) + { + return hr; + } +} + +HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int threads ) +{ + // whisper_decode + using namespace DirectCompute; + sDecodeParams dp; + dp.n_state = model.parameters.n_audio_state; + dp.n_head = model.parameters.n_audio_head; + dp.n_ctx = model.parameters.n_text_ctx; + dp.n_past = n_past; + dp.M = exp_n_audio_ctx > 0 ? 
exp_n_audio_ctx : model.parameters.n_audio_ctx; + dp.n_text_layer = model.parameters.n_text_layer; + dp.n_vocab = model.parameters.n_vocab; + + try + { + context.decode( tokens, (int)length, dp, probs, threads ); + return S_OK; + } + catch( HRESULT hr ) + { + return hr; + } +} + +// the most basic sampling scheme - select the top token +sTokenData ContextImpl::sampleBest( const float* probs, bool force_timestamp, bool is_initial ) +{ + // whisper_sample_best + const Vocabulary& vocab = model.vocab; + sTokenData result = { 0 }; + + size_t n_logits = vocab.size(); + + probs_id.clear(); + probs_id.reserve( n_logits ); + + for( size_t i = 0; i < n_logits; i++ ) + probs_id.emplace_back( probs[ i ], (int)i ); + + { + double sum_ts = 0.0; + double max_ts = -1.0; + double max_tx = -1.0; + + for( int i = 0; i < vocab.token_beg; i++ ) + max_tx = std::max( max_tx, probs_id[ i ].first ); + + const int i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg; + const int i1 = is_initial ? vocab.token_beg + 101 : (int)n_logits; + + // the initial timestamp cannot be larger than 100 + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if( is_initial ) + { + for( int i = i0; i < n_logits; i++ ) + probs_id[ i ].first = -INFINITY; + } + + for( int i = vocab.token_beg; i < i1; i++ ) + { + sum_ts += probs_id[ i ].first; + if( probs_id[ i ].first > max_ts ) + { + max_ts = probs_id[ i ].first; + result.tid = probs_id[ i ].second; + } + } + + // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a + // timestamp token + if( sum_ts > max_tx || force_timestamp ) + { + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 + for( int i = 0; i < vocab.token_beg; i++ ) + probs_id[ i ].first = -INFINITY; + } + + result.pt = (float)( max_ts / ( sum_ts + 1e-10 ) ); + result.ptsum = (float)sum_ts; + } + + // 
find the top K tokens + const int top_k = 4; + + std::partial_sort( + probs_id.begin(), + probs_id.begin() + top_k, probs_id.end(), + []( const std::pair<double, Vocabulary::id>& a, const std::pair<double, Vocabulary::id>& b ) { + return a.first > b.first; + } ); + + probs_id.resize( top_k ); + + //printf("\n"); + //for (int i = 0; i < (int) probs_id.size(); i++) { + // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); + //} + + int res = 0; + while( ( probs_id[ res ].second == vocab.token_sot || + probs_id[ res ].second == vocab.token_solm || + probs_id[ res ].second == vocab.token_not ) && + res < (int)probs_id.size() - 1 ) + { + res++; + } + + result.id = probs_id[ res ].second; + result.p = (float)probs_id[ res ].first; + + return result; +} + +sTokenData ContextImpl::sampleBest() +{ + const int n_vocab = model.vocab.n_vocab; + return sampleBest( probs.data() + ( probs.size() - n_vocab ), false, false ); +} + +sTokenData ContextImpl::sampleTimestamp( bool initial ) +{ + const int n_vocab = model.vocab.n_vocab; + return sampleBest( probs.data() + ( probs.size() - n_vocab ), true, initial ); +} + +void ContextImpl::expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum ) +{ + // whisper_exp_compute_token_level_timestamps + throw E_NOTIMPL; +} + +static std::string to_timestamp( int64_t t, bool comma = false ) +{ + int64_t msec = t * 10; + int64_t hr = msec / ( 1000 * 60 * 60 ); + msec = msec - hr * ( 1000 * 60 * 60 ); + int64_t min = msec / ( 1000 * 60 ); + msec = msec - min * ( 1000 * 60 ); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[ 32 ]; + snprintf( buf, sizeof( buf ), "%02d:%02d:%02d%s%03d", (int)hr, (int)min, (int)sec, comma ? 
"," : ".", (int)msec ); + + return std::string( buf ); +} + +HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const sProgressSink& progress, iSpectrogram& mel ) +{ + // Ported from whisper_full() function + result_all.clear(); + if( params.flag( eFullParamsFlags::SpeedupAudio ) ) + { + logError( u8"GPU model doesn't implement the SpeedupAudio flag" ); + return E_NOTIMPL; + } + + const int seek_start = params.offset_ms / 10; + const int seek_end = seek_start + ( params.duration_ms == 0 ? (int)mel.getLength() : params.duration_ms / 10 ); + + // if length of spectrogram is less than 1s (100 samples), then return + // basically don't process anything that is less than 1s + // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 + if( seek_end < 100 + seek_start ) + return S_FALSE; + + // the accumulated text context so far + if( params.flag( eFullParamsFlags::NoContext ) ) + prompt_past.clear(); + + // prepend the prompt tokens to the prompt_past + if( params.prompt_tokens && params.prompt_n_tokens > 0 ) + { + // parse tokens from the pointer + for( int i = 0; i < params.prompt_n_tokens; i++ ) + prompt_past.push_back( params.prompt_tokens[ i ] ); + std::rotate( prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end() ); + } + + // overwrite audio_ctx + exp_n_audio_ctx = params.audio_ctx; + + // these tokens determine the task that will be performed + std::vector<whisper_token> prompt_init = { model.vocab.token_sot }; + if( model.vocab.is_multilingual() ) + { + int langId = lookupLanguageId( params.language ); + if( langId < 0 ) + { + char lang[ 5 ]; + *(uint32_t*)( &lang[ 0 ] ) = params.language; + lang[ 4 ] = '\0'; + logError( u8"%s: unknown language '%s'", __func__, lang ); + return E_INVALIDARG; + } + + prompt_init.push_back( model.vocab.token_sot + 1 + langId ); + if( params.flag( eFullParamsFlags::Translate ) ) + prompt_init.push_back( model.vocab.token_translate ); + else + prompt_init.push_back( 
model.vocab.token_transcribe ); + } + + // int progress_prev = 0; + // int progress_step = 5; + + std::vector<sTokenData> tokens_cur; + tokens_cur.reserve( model.parameters.n_text_ctx ); + std::vector<whisper_token> prompt; + prompt.reserve( model.parameters.n_text_ctx ); + + // main loop + int seek = seek_start; + auto prof = context.completeProfiler(); + while( true ) + { + if( nullptr != progress.pfn ) + { + const int pos = seek - seek_start; + const int total = seek_end - seek_start; + const double percentage = (double)pos / (double)total; + CHECK( progress.pfn( percentage, this, progress.pv ) ); + } + /* + const int progress_cur = ( 100 * ( seek - seek_start ) ) / ( seek_end - seek_start ); + while( progress_cur >= progress_prev + progress_step ) + { + progress_prev += progress_step; + if( params.flag( eFullParamsFlags::PrintProgress ) ) + logInfo( u8"%s: progress = %3d%%", __func__, progress_prev ); + } + */ + + if( seek + 100 >= seek_end ) + break; + + if( nullptr != params.encoder_begin_callback ) + { + HRESULT hr = params.encoder_begin_callback( this, params.encoder_begin_callback_user_data ); + if( FAILED( hr ) ) + return hr; + if( hr != S_OK ) + break; + } + + // encode audio features starting at offset seek + CHECK( encode( mel, seek ) ); + + int n_past = 0; + prompt.clear(); + + // if we have already generated some text, use it as a prompt to condition the next generation + if( !prompt_past.empty() ) + { + int n_take = std::min( std::min( params.n_max_text_ctx, model.parameters.n_text_ctx / 2 ), int( prompt_past.size() ) ); + + prompt = { model.vocab.token_prev }; + prompt.insert( prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end() ); + + prompt_past.clear(); + prompt_past.insert( prompt_past.end(), prompt.begin() + 1, prompt.end() ); + } + + prompt.insert( prompt.end(), prompt_init.begin(), prompt_init.end() ); + + int seek_delta = 100 * WHISPER_CHUNK_SIZE; + + // print the prompt + //printf("\n\n"); + //for (int i = 0; i < 
prompt.size(); i++) { + // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); + //} + //printf("\n\n"); + + // the accumulated transcription in the current iteration + int result_len = 0; + tokens_cur.clear(); + + bool failed = false; + bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? + + { + auto prof = context.decodeProfiler(); + for( int i = 0, n_max = model.parameters.n_text_ctx / 2 - 4; i < n_max; i++ ) + { + CHECK( decode( prompt.data(), prompt.size(), n_past, params.cpuThreads ) ); + + n_past += (int)prompt.size(); + prompt.clear(); + + // very basic greedy sampling strategy: + // + // - always take the most probable token + // + // more sophisticated sampling strategies could be implemented here, but we keep it simple + // feel free to experiment! + // + { + auto p = profiler.cpuBlock( eCpuBlock::Sample ); + const sTokenData token = ( i == 0 ) ? sampleTimestamp( true ) : sampleBest(); + + // timestamp token - update sliding window + if( token.id > model.vocab.token_beg ) + { + const int seek_delta_new = 2 * ( token.id - model.vocab.token_beg ); + + // do not allow to go back in time + if( has_ts && seek_delta > seek_delta_new && result_len < i ) + break; + + seek_delta = seek_delta_new; + result_len = i + 1; + has_ts = true; + } + + // add it to the context + prompt.push_back( token.id ); + tokens_cur.push_back( token ); + + //{ + // const auto tt = token.pt > 0.10 ? 
ctx->vocab.id_to_token[token.tid] : "[?]"; + // printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); + //} + + // end of segment + if( token.id == model.vocab.token_eot || // end of text token + ( params.max_tokens > 0 && i >= params.max_tokens ) || // max tokens per segment reached + ( has_ts && seek + seek_delta + 100 >= seek_end ) // end of audio reached + ) + { + if( result_len == 0 ) + { + if( seek + seek_delta + 100 >= seek_end ) + result_len = i + 1; + else + { + failed = true; + break; + } + } + + if( params.flag( eFullParamsFlags::SingleSegment ) ) + { + result_len = i + 1; + seek_delta = 100 * WHISPER_CHUNK_SIZE; + } + + break; + } + } + + // sometimes, the decoding can get stuck in a repetition loop + // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance + // the sliding window by 1 second + if( i == n_max - 1 && ( result_len == 0 || seek_delta < 100 * WHISPER_CHUNK_SIZE / 2 ) ) + { + failed = true; + break; + } + } + } + + if( failed ) + { + logError( u8"%s: failed to generate timestamp token - skipping one second", __func__ ); + seek += 100; + continue; + } + + // shrink down to result_len + tokens_cur.resize( result_len ); + + for( const auto& r : tokens_cur ) + prompt_past.push_back( r.id ); + + // store the text from this iteration + if( !tokens_cur.empty() ) + { + int i0 = 0; + int t0 = seek + 2 * ( tokens_cur.front().tid - model.vocab.token_beg ); + std::string text = ""; + + for( int i = 0; i < (int)tokens_cur.size(); i++ ) + { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + if( params.flag( eFullParamsFlags::PrintSpecial ) || tokens_cur[ i ].id < model.vocab.token_eot ) + text += model.vocab.string( tokens_cur[ i ].id ); + + if( tokens_cur[ i ].id > model.vocab.token_beg && 
!params.flag( eFullParamsFlags::SingleSegment ) ) + { + const int t1 = seek + 2 * ( tokens_cur[ i ].tid - model.vocab.token_beg ); + if( !text.empty() ) + { + const bool speedUp = params.flag( eFullParamsFlags::SpeedupAudio ); + const int tt0 = speedUp ? 2 * t0 : t0; + const int tt1 = speedUp ? 2 * t1 : t1; + + if( params.flag( eFullParamsFlags::PrintRealtime ) ) + { + if( params.flag( eFullParamsFlags::PrintTimestamps ) ) + printf( "[%s --> %s] %s\n", to_timestamp( tt0 ).c_str(), to_timestamp( tt1 ).c_str(), text.c_str() ); + else + { + printf( "%s", text.c_str() ); + fflush( stdout ); + } + } + + result_all.push_back( { tt0, tt1, text, {} } ); + for( int j = i0; j <= i; j++ ) + result_all.back().tokens.push_back( tokens_cur[ j ] ); + + int n_new = 1; + + if( params.flag( eFullParamsFlags::TokenTimestamps ) ) + { + expComputeTokenLevelTimestamps( (int)result_all.size() - 1, params.thold_pt, params.thold_ptsum ); + if( params.max_len > 0 ) + n_new = wrapSegment( params.max_len ); + } + if( nullptr != params.new_segment_callback ) + { + HRESULT hr = params.new_segment_callback( this, n_new, params.new_segment_callback_user_data ); + if( FAILED( hr ) ) + return hr; + } + } + text = ""; + while( i < (int)tokens_cur.size() && tokens_cur[ i ].id > model.vocab.token_beg ) + i++; + i--; + t0 = t1; + i0 = i + 1; + } + } + + if( !text.empty() ) + { + const int t1 = seek + seek_delta; + + const bool speedUp = params.flag( eFullParamsFlags::SpeedupAudio ); + const int tt0 = speedUp ? 2 * t0 : t0; + const int tt1 = speedUp ? 
2 * t1 : t1; + + if( params.flag( eFullParamsFlags::PrintRealtime ) ) + { + if( params.flag( eFullParamsFlags::PrintTimestamps ) ) + printf( "[%s --> %s] %s\n", to_timestamp( tt0 ).c_str(), to_timestamp( tt1 ).c_str(), text.c_str() ); + else + { + printf( "%s", text.c_str() ); + fflush( stdout ); + } + } + + result_all.push_back( { tt0, tt1, text, {} } ); + for( int j = i0; j < (int)tokens_cur.size(); j++ ) + result_all.back().tokens.push_back( tokens_cur[ j ] ); + + int n_new = 1; + if( params.flag( eFullParamsFlags::TokenTimestamps ) ) + { + expComputeTokenLevelTimestamps( (int)result_all.size() - 1, params.thold_pt, params.thold_ptsum ); + if( params.max_len > 0 ) + n_new = wrapSegment( params.max_len ); + } + if( nullptr != params.new_segment_callback ) + { + HRESULT hr = params.new_segment_callback( this, n_new, params.new_segment_callback_user_data ); + if( FAILED( hr ) ) + return hr; + } + } + } + seek += seek_delta; + } + + if( nullptr != progress.pfn ) + { + CHECK( progress.pfn( 1.0, this, progress.pv ) ); + } + return S_OK; +}
// File: Whisper/Whisper/ContextImpl.h
#pragma once
#include "../API/iContext.cl.h"
#include "../ComLightLib/comLightServer.h"
#include "WhisperContext.h"
#include "Spectrogram.h"
#include "TranscribeResult.h"
#include "sTokenData.h"

namespace Whisper
{
	// Implementation of the iContext COM-style interface.
	// Holds per-context state: the compute context, the spectrogram, profiler data,
	// and the segments/tokens accumulated while running the model.
	class ContextImpl : public ComLight::ObjectRoot<iContext>
	{
		// Model data this context runs against
		const WhisperModel& model;
		// Interface pointer returned by getModel(); may be empty, in which case getModel() fails with OLE_E_BLANK
		ComLight::CComPtr<iModel> modelPtr;
		DirectCompute::WhisperContext context;
		Spectrogram spectrogram;
		// Offset added to segment and token times when the results are built by makeResults()
		int64_t mediaTimeOffset = 0;
		ProfileCollection profiler;

		// iContext interface methods
		HRESULT COMLIGHTCALL getModel( iModel** pp ) override final;
		HRESULT COMLIGHTCALL timingsPrint() override final;
		HRESULT COMLIGHTCALL timingsReset() override final;
		HRESULT COMLIGHTCALL fullDefaultParams( eSamplingStrategy strategy, sFullParams* rdi ) override final;
		// Shared implementation called by both runFull() and runStreamed(); not part of the interface
		HRESULT COMLIGHTCALL runFullImpl( const sFullParams& params, const sProgressSink& progress, iSpectrogram& mel );
		HRESULT COMLIGHTCALL runFull( const sFullParams& params, const iAudioBuffer* buffer ) override final;
		HRESULT COMLIGHTCALL runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ) override final;
		HRESULT COMLIGHTCALL runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ) override final;

		// One transcribed segment: time range, text, and the tokens which produced the text
		struct Segment
		{
			int64_t t0;
			int64_t t1;
			std::string text;
			std::vector<sTokenData> tokens;
			// Heap memory used by the text and the tokens of this segment, in bytes
			size_t memoryUsage() const;
		};
		// All segments produced so far; makeResults() converts them into the public result structures
		std::vector<Segment> result_all;

		// Previously generated tokens, used as a prompt to condition the next generation
		std::vector<whisper_token> prompt_past;

		// [EXPERIMENTAL] token-level timestamps data
		int64_t t_beg = 0;
		int64_t t_last = 0;
		whisper_token tid_last = 0;
		std::vector<float> energy; // PCM signal energy

		// [EXPERIMENTAL] speed-up techniques
		int32_t exp_n_audio_ctx = 0; // 0 - use default

		HRESULT encode( iSpectrogram& mel, int seek );
		HRESULT decode( const int* tokens, size_t length, int n_past, int threads );
		sTokenData sampleBest( const float* probs, bool force_timestamp, bool is_initial );
		sTokenData sampleBest();
		sTokenData sampleTimestamp( bool initial );
		// Split the last segment of result_all so the text pieces fit into max_len characters; returns count of resulting segments
		int wrapSegment( int max_len );
		void expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum );

		std::vector<float> probs;
		std::vector<std::pair<double, Vocabulary::id>> probs_id;

		// Result object reused by getResults() when the caller doesn't ask for a new object;
		// mutable because getResults() is a const method
		mutable TranscribeResultStatic results;

		// Fill the result object with data from result_all, according to the flags
		HRESULT COMLIGHTCALL makeResults( eResultFlags flags, TranscribeResult& res ) const noexcept;

		HRESULT COMLIGHTCALL getResults( eResultFlags flags, iTranscribeResult** pp ) const noexcept override final;

		// Default value for sFullParams.cpuThreads
		int defaultThreadsCount() const;

		// Memory usage of this context; the lower 64-bit lane is system RAM, the upper lane is VRAM
		__m128i getMemoryUse() const;

	public:

		ContextImpl( const WhisperModel& modelData, iModel* modelPointer );
	};
}
\ No newline at end of file diff --git a/Whisper/Whisper/ContextImpl.misc.cpp b/Whisper/Whisper/ContextImpl.misc.cpp new file mode 100644 index 0000000..98d2164 --- /dev/null +++ b/Whisper/Whisper/ContextImpl.misc.cpp @@ -0,0 +1,408 @@ +#include "stdafx.h" +#include "ContextImpl.h" +#include <mfapi.h> +#include "MelStreamer.h" +#include "../API/iMediaFoundation.cl.h" +#include "../Utils/Trace/tracing.h" +using namespace Whisper; + +static int getCpuCoresCount() +{ + DWORD bufferSize = 0; + GetLogicalProcessorInformation( NULL, &bufferSize ); + + // The SYSTEM_LOGICAL_PROCESSOR_INFORMATION structure has a uint64_t field + // Ideally need to align by 8 bytes, and that's why uint64_t type for the storage + std::unique_ptr<uint64_t[]> buffer = std::make_unique<uint64_t[]>( ( bufferSize + 7 ) / 8 ); + + SYSTEM_LOGICAL_PROCESSOR_INFORMATION* ptr = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION*)buffer.get(); + if( !GetLogicalProcessorInformation( ptr, &bufferSize ) ) + { + HRESULT hr = getLastHr(); + logWarningHr( hr, u8"GetLogicalProcessorInformation" ); + return 0; + } + + DWORD byteOffset = 0; + int physicalCores = 0; + while( byteOffset < bufferSize ) + { + if( ptr->Relationship == RelationProcessorCore ) + physicalCores++; + byteOffset += sizeof( SYSTEM_LOGICAL_PROCESSOR_INFORMATION ); + ptr++; + } + return physicalCores; +} + +int ContextImpl::defaultThreadsCount() const +{ +#if BUILD_HYBRID_VERSION + const bool isHybrid = !model.hybridTensors.layers.empty(); +#else + constexpr bool isHybrid = false; +#endif + + SYSTEM_INFO si; + GetSystemInfo( &si ); + const int hardwareThreads = (int)si.dwNumberOfProcessors; + + if( !isHybrid ) + return std::min( hardwareThreads, 4 ); + + // It seems the CPU decoder in the hybrid context doesn’t scale well with count of hardware threads, but it does scale with count of physical cores. 
+ int cores = getCpuCoresCount(); + if( cores > 1 ) + return cores; + + return hardwareThreads; +} + +HRESULT COMLIGHTCALL ContextImpl::fullDefaultParams( eSamplingStrategy strategy, sFullParams* rdi ) +{ + // whisper_full_default_params + if( nullptr == rdi ) + return E_POINTER; + memset( rdi, 0, sizeof( sFullParams ) ); + + rdi->strategy = strategy; + rdi->cpuThreads = defaultThreadsCount(); + rdi->n_max_text_ctx = 16384; + rdi->flags = eFullParamsFlags::PrintProgress | eFullParamsFlags::PrintTimestamps; + rdi->thold_pt = 0.01f; + rdi->thold_ptsum = 0.01f; + rdi->language = makeLanguageKey( "en" ); + + switch( strategy ) + { + case eSamplingStrategy::Greedy: + rdi->beam_search.n_past = -1; + rdi->beam_search.beam_width = -1; + rdi->beam_search.n_best = -1; + break; + case eSamplingStrategy::BeamSearch: + rdi->greedy.n_past = -1; + rdi->beam_search.beam_width = 10; + rdi->beam_search.n_best = 5; + break; + default: + logError( u8"Unknown sampling strategy %i", (int)strategy ); + return E_INVALIDARG; + } + return S_OK; +} + +HRESULT COMLIGHTCALL ContextImpl::getModel( iModel** pp ) +{ + if( nullptr == pp ) + return E_POINTER; + if( !modelPtr ) + return OLE_E_BLANK; + *pp = modelPtr; + modelPtr->AddRef(); + return S_OK; +} + +size_t ContextImpl::Segment::memoryUsage() const +{ + return text.capacity() + vectorMemoryUse( tokens ); +} + +__m128i ContextImpl::getMemoryUse() const +{ + // Misc. 
system RAM + size_t cb = vectorMemoryUse( result_all ); + for( const auto& r : result_all ) + cb += r.memoryUsage(); + cb += vectorMemoryUse( prompt_past ); + cb += vectorMemoryUse( energy ); + cb += vectorMemoryUse( probs ); + cb += vectorMemoryUse( probs_id ); + cb += vectorMemoryUse( results.segments ); + cb += vectorMemoryUse( results.tokens ); + cb += spectrogram.memoryUsage(); + + __m128i res = setLow_size( cb ); + // Add all the VRAM in the temporary buffers + res = _mm_add_epi64( res, context.getMemoryUse() ); + return res; +} + +namespace +{ + struct PrintedSize + { + double val; + const char* unit; + PrintedSize( int64_t cb ) + { + if( cb < ( 1 << 10 ) ) + { + val = (double)cb; + unit = "bytes"; + } + else if( cb < ( 1 << 20 ) ) + { + val = (double)cb * ( 1.0 / ( 1 << 10 ) ); + unit = "KB"; + } + else if( cb < ( 1 << 30 ) ) + { + val = (double)cb * ( 1.0 / ( 1 << 20 ) ); + unit = "MB"; + } + else + { + val = (double)cb * ( 1.0 / ( 1 << 30 ) ); + unit = "GB"; + } + } + }; + + static void __declspec( noinline ) logMemoryUse( const char* what, __m128i cb ) + { + PrintedSize sys{ _mm_cvtsi128_si64( cb ) }; + PrintedSize vram{ _mm_extract_epi64( cb, 1 ) }; + logInfo( u8"%s\t%g %s RAM, %g %s VRAM", what, sys.val, sys.unit, vram.val, vram.unit ); + } +} + +HRESULT COMLIGHTCALL ContextImpl::timingsPrint() +{ + profiler.print(); + + const __m128i memModel = model.getMemoryUse(); + const __m128i memContext = getMemoryUse(); + logInfo( u8" Memory Usage" ); + logMemoryUse( "Model", memModel ); + logMemoryUse( "Context", memContext ); + logMemoryUse( "Total", _mm_add_epi64( memModel, memContext ) ); + return S_OK; +} + +HRESULT COMLIGHTCALL ContextImpl::timingsReset() +{ + profiler.reset(); + return S_OK; +} + +HRESULT COMLIGHTCALL ContextImpl::getResults( eResultFlags flags, iTranscribeResult** pp ) const noexcept +{ + if( nullptr == pp ) + return E_POINTER; + + if( flags & eResultFlags::NewObject ) + { + ComLight::CComPtr<ComLight::Object<TranscribeResult>> obj; + 
CHECK( ComLight::Object<TranscribeResult>::create( obj ) ); + CHECK( makeResults( flags, *obj ) ); + obj.detach( pp ); + return S_OK; + } + else + { + CHECK( makeResults( flags, results ) ); + iTranscribeResult* res = &results; + res->AddRef(); + *pp = res; + return S_OK; + } +} + +inline int64_t scaleTime( int64_t wisperTicks ) +{ + return MFllMulDiv( wisperTicks, 10'000'000, 100, 0 ); +} + +HRESULT COMLIGHTCALL ContextImpl::makeResults( eResultFlags flags, TranscribeResult& res ) const noexcept +{ + const size_t segments = result_all.size(); + // Resize both vectors + try + { + res.segments.resize( segments ); + if( flags & eResultFlags::Tokens ) + { + size_t tc = 0; + for( const auto& s : result_all ) + tc += s.tokens.size(); + res.tokens.resize( tc ); + } + else + res.tokens.clear(); + } + catch( const std::bad_alloc& ) + { + return E_OUTOFMEMORY; + } + + const Vocabulary::id tokenEot = model.vocab.token_eot; + + size_t tokensSoFar = 0; + for( size_t i = 0; i < segments; i++ ) + { + sSegment& rdi = res.segments[ i ]; + const auto& rsi = result_all[ i ]; + rdi.text = rsi.text.c_str(); + if( flags & eResultFlags::Timestamps ) + { + // Offset the time relative to the start of the media + rdi.time.begin = scaleTime( rsi.t0 ) + mediaTimeOffset; + rdi.time.end = scaleTime( rsi.t1 ) + mediaTimeOffset; + } + else + store16( &rdi.time, _mm_setzero_si128() ); + + rdi.firstToken = (uint32_t)tokensSoFar; + const size_t tc = rsi.tokens.size(); + rdi.countTokens = (uint32_t)tc; + + if( flags & eResultFlags::Tokens ) + { + for( size_t i = 0; i < tc; i++ ) + { + sToken& rdi = res.tokens[ tokensSoFar + i ]; + const auto& src = rsi.tokens[ i ]; + rdi.text = model.vocab.string( src.id ); + + if( flags & eResultFlags::Timestamps ) + { + // Offset the time relative to the start of the media + rdi.time.begin = scaleTime( src.t0 ) + mediaTimeOffset; + rdi.time.end = scaleTime( src.t1 ) + mediaTimeOffset; + } + else + store16( &rdi.time, _mm_setzero_si128() ); + + // Copy 4 floats 
with unaligned load and store instructions + _mm_storeu_ps( &rdi.probability, _mm_loadu_ps( &src.p ) ); + + rdi.id = src.id; + + uint32_t flags = 0; + if( src.id >= tokenEot ) + flags |= (uint32_t)eTokenFlags::Special; + rdi.flags = (eTokenFlags)flags; + } + } + tokensSoFar += tc; + } + return S_OK; +} + +int ContextImpl::wrapSegment( int max_len ) +{ + // whisper_wrap_segment + auto segment = result_all.back(); + int res = 1; + int acc = 0; + std::string text; + const int tokenEot = model.vocab.token_eot; + + for( int i = 0; i < (int)segment.tokens.size(); i++ ) + { + const auto& token = segment.tokens[ i ]; + if( token.id >= tokenEot ) + continue; + + const char* txt = model.vocab.string( token.id ); + const int cur = (int)strlen( txt ); + + if( acc + cur > max_len && i > 0 ) + { + // split here + result_all.back().text = std::move( text ); + result_all.back().t1 = token.t0; + result_all.back().tokens.resize( i ); + + result_all.push_back( {} ); + result_all.back().t0 = token.t0; + result_all.back().t1 = segment.t1; + + // add tokens [i, end] to the new segment + result_all.back().tokens.insert( result_all.back().tokens.end(), segment.tokens.begin() + i, segment.tokens.end() ); + + acc = 0; + text = ""; + + segment = result_all.back(); + i = -1; + + res++; + } + else + { + acc += cur; + text += txt; + } + } + + result_all.back().text = std::move( text ); + return res; +} + +HRESULT COMLIGHTCALL ContextImpl::runFull( const sFullParams& params, const iAudioBuffer* buffer ) +{ +#if SAVE_DEBUG_TRACE + Tracing::vector( "runFull.pcm.in", buffer->getPcmMono(), buffer->countSamples() ); +#endif + CHECK( buffer->getTime( mediaTimeOffset ) ); + + auto profCompleteCpu = profiler.cpuBlock( eCpuBlock::Run ); + { + auto p = profiler.cpuBlock( eCpuBlock::Spectrogram ); + CHECK( spectrogram.pcmToMel( buffer, model.filters, params.cpuThreads ) ); + } + + if( params.flag( eFullParamsFlags::TokenTimestamps ) ) + { + t_beg = 0; + t_last = 0; + tid_last = 0; + computeSignalEnergy( 
energy, buffer, 32 ); + } + + try + { + sProgressSink progressSink{ nullptr, nullptr }; + return runFullImpl( params, progressSink, spectrogram ); + } + catch( HRESULT hr ) + { + return hr; + } +} + +HRESULT COMLIGHTCALL ContextImpl::runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ) +{ + if( params.flag( eFullParamsFlags::TokenTimestamps ) ) + { + logError( u8"eFullParamsFlags.TokenTimestamps flag is not supported in streaming mode" ); + return E_NOTIMPL; + } + + mediaTimeOffset = 0; + auto profCompleteCpu = profiler.cpuBlock( eCpuBlock::Run ); + + CComPtr<IMFSourceReader> mfReader; + CHECK( reader->getReader( &mfReader ) ); + const bool stereo = reader->requestedStereo() == S_OK; + + try + { + if( params.cpuThreads > 1 ) + { + MelStreamerThread mel{ model.filters, profiler, mfReader, params.cpuThreads, stereo }; + return runFullImpl( params, progress, mel ); + } + else + { + MelStreamerSimple mel{ model.filters, profiler, mfReader, stereo }; + return runFullImpl( params, progress, mel ); + } + } + catch( HRESULT hr ) + { + return hr; + } +}
\ No newline at end of file diff --git a/Whisper/Whisper/DecoderInputBuffers.cpp b/Whisper/Whisper/DecoderInputBuffers.cpp new file mode 100644 index 0000000..68d3cec --- /dev/null +++ b/Whisper/Whisper/DecoderInputBuffers.cpp @@ -0,0 +1,66 @@ +#include "stdafx.h" +#include "DecoderInputBuffers.h" +#include "../D3D/createBuffer.h" +#include "../D3D/MappedResource.h" +using namespace DirectCompute; + +void DecoderInputBuffers::resize( uint32_t size ) +{ + if( 0 == size ) + throw E_INVALIDARG; + + if( size <= m_capacity ) + { + m_size = size; + return; + } + + embd = nullptr; + + // Round up by 256, mostly for lulz + const uint32_t newCapacity = ( size + 0xFFu ) & ( ~( 0xFFu ) ); + const size_t totalBytes = (size_t)4 * newCapacity; + + check( createBuffer( eBufferUse::Dynamic, totalBytes, &embd, nullptr, nullptr ) ); + + m_capacity = newCapacity; + m_size = size; +} + +namespace +{ + static Tensor createView( ID3D11Buffer* buffer, uint32_t length ) + { + Tensor res; + + TensorGpuViews& views = res; + check( views.create( buffer, DXGI_FORMAT_R32_UINT, length, false ) ); + + res.ne = { length, 1, 1, 1 }; + res.setDenseStrides(); + return res; + } +} + +Tensor DecoderInputBuffers::embedding( const int* rsi ) const +{ + if( nullptr == embd || m_size == 0 ) + throw OLE_E_BLANK; + + // Upload the data + { + MappedResource mapped; + check( mapped.map( embd, false ) ); + int* const rdi = (int*)mapped.data(); + memcpy( rdi, rsi, m_size * (size_t)4 ); + } + + return createView( embd, m_size ); +} + +void DecoderInputBuffers::clear() +{ + embd = nullptr; + m_size = 0; + m_capacity = 0; +}
// File: Whisper/Whisper/DecoderInputBuffers.h
#pragma once
#include "../ML/Tensor.h"

namespace DirectCompute
{
	// A dynamic buffer with 32-bit integer elements, which delivers input data of the decoder to the GPU.
	// The buffer only grows: resize() re-creates it when the requested size exceeds the capacity.
	class DecoderInputBuffers
	{
		CComPtr<ID3D11Buffer> embd;
		// Count of elements to upload and view
		uint32_t m_size = 0;
		// Count of elements allocated in the buffer
		uint32_t m_capacity = 0;

	public:

		// Set the count of elements; throws E_INVALIDARG for zero size
		void resize( uint32_t size );

		// Create 1D tensor with R32_UINT elements, upload the source data
		Tensor embedding( const int* rsi ) const;

		// Release the buffer
		void clear();

		// Size of the buffer in bytes, in the upper 64-bit lane of the vector; the lower lane is zero.
		// NOTE(review): the upper lane is presumably the VRAM counter, same as setHigh_size() elsewhere; confirm
		__m128i getMemoryUse() const
		{
			size_t i = m_capacity;
			i *= sizeof( uint32_t );
			return _mm_set_epi64x( (int64_t)i, 0 );
		}
	};
}
\ No newline at end of file diff --git a/Whisper/Whisper/DecoderResultBuffer.cpp b/Whisper/Whisper/DecoderResultBuffer.cpp new file mode 100644 index 0000000..3a26d51 --- /dev/null +++ b/Whisper/Whisper/DecoderResultBuffer.cpp @@ -0,0 +1,48 @@ +#include "stdafx.h" +#include "DecoderResultBuffer.h" +#include "../D3D/MappedResource.h" +using namespace DirectCompute; + +void DecoderResultBuffer::copyFromVram( const Tensor& rsi ) +{ + ID3D11ShaderResourceView* srv = rsi; + if( nullptr == srv ) + throw OLE_E_BLANK; + if( !rsi.isContinuous() ) + throw E_INVALIDARG; + + const uint32_t len = rsi.countElements(); + if( len > m_capacity ) + { + buffer = nullptr; + CD3D11_BUFFER_DESC desc{ len * 4, 0, D3D11_USAGE_STAGING, D3D11_CPU_ACCESS_READ }; + check( device()->CreateBuffer( &desc, nullptr, &buffer ) ); + m_capacity = len; + } + + CComPtr<ID3D11Resource> source; + srv->GetResource( &source ); + // Coordinates of a box are in bytes for buffers + D3D11_BOX box; + store16( &box, _mm_setr_epi32( 0, 0, 0, (int)( len * 4 ) ) ); + *(uint64_t*)&box.bottom = 0x100000001ull; + context()->CopySubresourceRegion( buffer, 0, 0, 0, 0, source, 0, &box ); + m_size = len; +} + +void DecoderResultBuffer::copyToVector( std::vector<float>& vec ) const +{ + vec.resize( m_size ); + if( vec.empty() ) + throw OLE_E_BLANK; + + MappedResource mapped; + check( mapped.map( buffer, true ) ); + memcpy( vec.data(), mapped.data(), (size_t)4 * m_size ); +} + +void DecoderResultBuffer::clear() +{ + buffer = nullptr; + m_size = m_capacity = 0; +}
// File: Whisper/Whisper/DecoderResultBuffer.h
#pragma once
#include "../ML/Tensor.h"

namespace DirectCompute
{
	// A staging buffer with CPU read access, to download tensors of 4-byte elements from the GPU.
	// The buffer only grows: copyFromVram() re-creates it when the source tensor exceeds the capacity.
	class DecoderResultBuffer
	{
		CComPtr<ID3D11Buffer> buffer;
		// Count of elements copied by the last copyFromVram() call
		uint32_t m_size = 0;
		// Count of elements allocated in the buffer
		uint32_t m_capacity = 0;

	public:

		// Copy the complete tensor from VRAM into the staging buffer
		void copyFromVram( const Tensor& rsi );

		// Copy the data from the staging buffer into the vector; throws OLE_E_BLANK when there's no data
		void copyToVector( std::vector<float>& vec ) const;

		// Count of elements copied by the last copyFromVram() call
		uint32_t size() const
		{
			return m_size;
		}
		// Release the buffer
		void clear();

		__m128i getMemoryUse() const
		{
			return bufferMemoryUsage( buffer );
		}
	};
}
\ No newline at end of file diff --git a/Whisper/Whisper/KeyValueBuffers.cpp b/Whisper/Whisper/KeyValueBuffers.cpp new file mode 100644 index 0000000..b932fdb --- /dev/null +++ b/Whisper/Whisper/KeyValueBuffers.cpp @@ -0,0 +1,42 @@ +#include "stdafx.h" +#include "KeyValueBuffers.h" +#include "../D3D/createBuffer.h" +using namespace DirectCompute; + +void AttentionBuffer::resize( uint32_t size ) +{ + if( size <= m_size ) + return; + + buffer = nullptr; + check( createBuffer( eBufferUse::ReadWrite, (size_t)2 * size, &buffer, nullptr, nullptr ) ); + m_size = size; +} + +Tensor AttentionBuffer::view( uint32_t length, uint32_t offset ) const +{ + if( length + offset > m_size ) + throw E_BOUNDS; + if( 0 == length ) + throw E_INVALIDARG; + + CComPtr<ID3D11ShaderResourceView> srv; + CComPtr<ID3D11UnorderedAccessView> uav; + + CD3D11_SHADER_RESOURCE_VIEW_DESC srvDesc{ D3D11_SRV_DIMENSION_BUFFER, DXGI_FORMAT_R16_FLOAT, offset, length }; + check( device()->CreateShaderResourceView( buffer, &srvDesc, &srv ) ); + + CD3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc{ D3D11_UAV_DIMENSION_BUFFER, DXGI_FORMAT_R16_FLOAT, offset, length }; + check( device()->CreateUnorderedAccessView( buffer, &uavDesc, &uav ) ); + + TensorShape shape; + shape.ne = { length, 1, 1, 1 }; + shape.setDenseStrides(); + return Tensor( shape, srv, uav ); +} + +void KeyValueBuffers::resize( uint32_t size ) +{ + keys.resize( size ); + values.resize( size ); +}
// File: Whisper/Whisper/KeyValueBuffers.h
#pragma once
#include "../ML/Tensor.h"

namespace DirectCompute
{
	// FP16 buffer for self-attention and cross-attention layers
	class AttentionBuffer
	{
		CComPtr<ID3D11Buffer> buffer;
		// Count of elements in the buffer
		uint32_t m_size = 0;

	public:
		// Create buffer for the specified count of elements
		// Does nothing when the buffer already has at least that many elements
		void resize( uint32_t size );

		// Create an 1D tensor which references a slice of that buffer
		Tensor view( uint32_t length, uint32_t offset ) const;

		// Release the buffer
		void clear()
		{
			buffer = nullptr;
			m_size = 0;
		}

		ID3D11Buffer* getBuffer() const { return buffer; }

		// Count of elements in the buffer
		uint32_t getSize() const { return m_size; }
	};

	// A pair of attention buffers, for keys and for values
	struct KeyValueBuffers
	{
		AttentionBuffer keys, values;

		// Resize both buffers to the same count of elements
		void resize( uint32_t size );

		// Release both buffers
		void clear()
		{
			keys.clear();
			values.clear();
		}

		// Combined size of both buffers in bytes, 2 bytes per element
		__m128i getMemoryUse() const
		{
			size_t i = keys.getSize();
			i += values.getSize();
			i *= sizeof( uint16_t );
			return setHigh_size( (int64_t)i ); // They both are in VRAM
		}
	};
}
\ No newline at end of file diff --git a/Whisper/Whisper/Languages.cpp b/Whisper/Whisper/Languages.cpp new file mode 100644 index 0000000..2081d6a --- /dev/null +++ b/Whisper/Whisper/Languages.cpp @@ -0,0 +1,122 @@ +#include "stdafx.h" +#include "Languages.h" +#include <atlcoll.h> +#include "../API/iContext.cl.h" + +namespace +{ + // These structures are compiled into the DLL, in read only data section + using Lang = Whisper::sLanguageEntry; + + static const Lang s_languageData[] = + { +#include "languageCodez.inl" + }; + + using Whisper::makeLanguageKey; + + // Values for the hash map + struct sLanguage + { + int id; + const char* name; + }; + + class LanguageIDs + { + CAtlMap<uint32_t, sLanguage> map; + + void add( const char* code, int id, const char* name ) + { + assert( strlen( code ) <= 4 ); + const uint16_t key = makeLanguageKey( code ); + map.SetAt( key, sLanguage{ id, name } ); + } + + public: + + LanguageIDs() : + map( 103u, 0.75f, 0.25f, 2.25f, 99 ) + { + for( const Lang& e : s_languageData ) + map.SetAt( e.key, sLanguage{ e.id, e.name } ); + }; + + int lookupId( const char* code ) const + { + const uint32_t key = makeLanguageKey( code ); + auto p = map.Lookup( key ); + return ( nullptr != p ) ? p->m_value.id : -1; + } + + int lookupKey( uint32_t key ) const + { + auto p = map.Lookup( key ); + return ( nullptr != p ) ? p->m_value.id : -1; + } + + const char* lookupName( const char* code ) const + { + const uint32_t key = makeLanguageKey( code ); + auto p = map.Lookup( key ); + return ( nullptr != p ) ? 
p->m_value.name : nullptr; + } + }; + + static const LanguageIDs g_table; +} + +namespace Whisper +{ + int lookupLanguageId( const char* code ) + { + return g_table.lookupId( code ); + } + int lookupLanguageId( uint32_t key ) + { + return g_table.lookupKey( key ); + } + const char* lookupLanguageName( const char* code ) + { + return g_table.lookupName( code ); + } + int COMLIGHTCALL getLanguageId( const char* lang ) + { + return lookupLanguageId( lang ); + } + + uint32_t COMLIGHTCALL findLanguageKeyW( const wchar_t* lang ) + { + uint32_t key = 0; + uint32_t shift = 0; + for( size_t i = 0; i < 4; i++, lang++, shift += 8 ) + { + const wchar_t c = *lang; + if( c == L'\0' ) + break; + if( c >= 0x80 ) + return UINT_MAX; + uint32_t u32 = (uint8_t)c; + u32 = u32 << shift; + key |= u32; + } + if( g_table.lookupKey( key ) >= 0 ) + return key; + return UINT_MAX; + } + + uint32_t COMLIGHTCALL findLanguageKeyA( const char* lang ) + { + const uint32_t key = makeLanguageKey( lang ); + if( g_table.lookupKey( key ) >= 0 ) + return key; + return UINT_MAX; + } + + HRESULT COMLIGHTCALL getSupportedLanguages( sLanguageList& rdi ) + { + rdi.length = sizeof( s_languageData ) / sizeof( s_languageData[ 0 ] ); + rdi.pointer = s_languageData; + return S_OK; + } +}
// File: Whisper/Whisper/Languages.h
#pragma once
#include "../../ComLightLib/comLightCommon.h"

namespace Whisper
{
	// Language ID by the language code, or -1 when not found
	int lookupLanguageId( const char* code );
	// Language ID by the 32-bit language key, or -1 when not found
	int lookupLanguageId( uint32_t key );

	// Language name by the language code, or nullptr when not found
	const char* lookupLanguageName( const char* code );

	// Same as lookupLanguageId for the code, with COMLIGHTCALL calling convention
	int COMLIGHTCALL getLanguageId( const char* lang );
}
// File: Whisper/Whisper/MelInputTensor.cpp
#include "stdafx.h"
#include "MelInputTensor.h"
#include "../D3D/MappedResource.h"
#include "../D3D/createBuffer.h"
#include <mfapi.h>	// MFCopyImage
using namespace DirectCompute;

// Upload a slice of the spectrogram into the dynamic FP32 buffer, and shape the tensor as [ 2 * n_ctx, n_mels ].
// The slice starts at encParams.mel_offset; the portion of the buffer past the end of the spectrogram is filled with zeros.
// The buffer is re-created when the current capacity is too small.
HRESULT MelInputTensor::create( Whisper::iSpectrogram& spectrogram, const sEncodeParams& encParams )
{
	// Ported from the initial portion of whisper_encode() function
	const size_t ne0 = encParams.n_ctx * 2;
	const size_t ne1 = encParams.n_mels;
	const size_t totalElts = ne0 * ne1;
	const size_t totalBytes = totalElts * 4;

	if( capacity < (uint32_t)totalElts )
	{
		// The old buffer is too small: drop the old one, and create a larger buffer with SRV
		buffer = nullptr;
		TensorGpuViews::clear();

		CHECK( createBuffer( eBufferUse::Dynamic, totalBytes, &buffer, nullptr, nullptr ) );
		CHECK( TensorGpuViews::create( buffer, DXGI_FORMAT_R32_FLOAT, totalElts, false ) );

		capacity = (uint32_t)totalElts;
	}

	// Upload data to VRAM using D3D11_MAP_WRITE_DISCARD, that's why we made a dynamic buffer
	{
		// Ported from whisper_encode() function
		MappedResource mapped;
		CHECK( mapped.map( buffer, false ) );
		float* const dst = (float*)mapped.data();
		// Zero the whole buffer first: the copy below may cover only the initial part of every row
		memset( dst, 0, totalBytes );

		// Clamp the source range to the length of the spectrogram
		const size_t n_len = spectrogram.getLength();
		const size_t i0 = std::min( (size_t)encParams.mel_offset, n_len );
		const size_t i1 = std::min( (size_t)encParams.mel_offset + 2 * encParams.n_ctx, n_len );

		// Whisper::MelBufferRaii sourceBuffer{ spectrogram, i0, i1 - i0 };
		// NOTE(review): the copy below uses the compile-time Whisper::N_MEL for the count of rows,
		// while the buffer was sized with encParams.n_mels; presumably they're always equal, confirm
		constexpr DWORD n_mel = Whisper::N_MEL;
		const size_t rowBytes = ( i1 - i0 ) * 4;
		/*
		for( size_t j = 0; j < n_mel; j++ )
		{
			float* rdi = dst + j * 2 * encParams.n_ctx;
			const float* rsi = sourceBuffer[ j ];
			memcpy( rdi, rsi, rowBytes );
		} */

		// Copy n_mel rows of rowBytes each; the destination stride is the complete row of the tensor
		Whisper::MelBufferRaii sourceBuffer;
		CHECK( sourceBuffer.make( spectrogram, i0, i1 - i0 ) );
		CHECK( MFCopyImage(
			(BYTE*)dst, (LONG)( 2 * encParams.n_ctx * sizeof( float ) ),
			sourceBuffer.bytePtr(), sourceBuffer.strideBytes(),
			(DWORD)rowBytes, n_mel ) );
	}

	// Shape the tensor
	ne = { 2 * encParams.n_ctx, encParams.n_mels, 1, 1 };
	TensorShape::setDenseStrides();
	return S_OK;
}
\ No newline at end of file diff --git a/Whisper/Whisper/MelInputTensor.h b/Whisper/Whisper/MelInputTensor.h new file mode 100644 index 0000000..3923ad6 --- /dev/null +++ b/Whisper/Whisper/MelInputTensor.h @@ -0,0 +1,22 @@ +#pragma once +#include "../ML/TensorEx.h" +#include "sEncodeParams.h" +#include "iSpectrogram.h" + +namespace DirectCompute +{ + // Input tensor in VRAM, in a dynamic FP32 buffer + class MelInputTensor : public TensorEx + { + uint32_t capacity; + + public: + + HRESULT create( Whisper::iSpectrogram& spectrogram, const sEncodeParams& encParams ); + + __m128i getMemoryUse() const + { + return setHigh_size( (size_t)capacity * 4 ); + } + }; +}
\ No newline at end of file diff --git a/Whisper/Whisper/MelStreamer.cpp b/Whisper/Whisper/MelStreamer.cpp new file mode 100644 index 0000000..54268f2 --- /dev/null +++ b/Whisper/Whisper/MelStreamer.cpp @@ -0,0 +1,493 @@ +#include "stdafx.h" +#include "MelStreamer.h" +#include "../Utils/parallelFor.h" +using namespace Whisper; + +MelStreamer::MelStreamer( const Filters& filters, ProfileCollection& prof, IMFSourceReader* source, bool stereo ) : + reader( source, stereo ), + melContext( filters ), + profiler( prof ) +{ } + +void MelStreamer::dropOldChunks( size_t off ) +{ + const bool stereo = reader.outputsStereo(); + for( size_t i = streamStartOffset; i < off; i++ ) + { + queuePcmMono.pop_front(); + queueMel.pop_front(); + if( stereo ) + queuePcmStereo.pop_front(); + } + streamStartOffset = off; +} + +HRESULT MelStreamer::ensurePcmChunks( size_t len ) +{ + if( readerEof ) + return queuePcmMono.empty() ? E_EOF : S_FALSE; + + const bool loadStereo = reader.outputsStereo(); + + const size_t neededChunks = len + FFT_SIZE / FFT_STEP; + while( true ) + { + if( queuePcmMono.size() >= neededChunks ) + return S_OK; + + PcmMonoChunk& mono = queuePcmMono.emplace_back(); + PcmStereoChunk* stereo = loadStereo ? 
&queuePcmStereo.emplace_back() : nullptr; + HRESULT hr = reader.readChunk( mono, stereo ); + if( SUCCEEDED( hr ) ) + continue; + + queuePcmMono.pop_back(); + if( loadStereo ) + queuePcmStereo.pop_back(); + + if( hr == E_EOF ) + { + readerEof = true; + return S_FALSE; + } + + return hr; + } +} + +size_t MelStreamer::serializePcm( size_t startOffset ) +{ + const ptrdiff_t chunks = (ptrdiff_t)queuePcmMono.size() - (ptrdiff_t)startOffset; + assert( chunks > 0 ); + + tempPcm.resize( chunks * FFT_STEP ); + float* rdi = tempPcm.data(); + + for( auto it = queuePcmMono.begin() + startOffset; it != queuePcmMono.end(); it++ ) + { + memcpy( rdi, it->mono.data(), FFT_STEP * 4 ); + rdi += FFT_STEP; + } + return chunks; +} + +namespace +{ + __forceinline __m128 transpose4x80( __m128 vmax, const float* c0, const float* c1, const float* c2, const float* c3, float* rdi, size_t stride ) + { + const float* const c0End = c0 + 80; + for( ; c0 < c0End; c0 += 4, c1 += 4, c2 += 4, c3 += 4, rdi += stride * 4 ) + { + __m128 r0 = _mm_loadu_ps( c0 ); + __m128 r1 = _mm_loadu_ps( c1 ); + __m128 r2 = _mm_loadu_ps( c2 ); + __m128 r3 = _mm_loadu_ps( c3 ); + + __m128 ax01 = _mm_max_ps( r0, r1 ); + __m128 ax02 = _mm_max_ps( r2, r3 ); + __m128 ax = _mm_max_ps( ax01, ax02 ); + vmax = _mm_max_ps( vmax, ax ); + + _MM_TRANSPOSE4_PS( r0, r1, r2, r3 ); + + _mm_storeu_ps( rdi, r0 ); + _mm_storeu_ps( rdi + stride, r1 ); + _mm_storeu_ps( rdi + stride * 2, r2 ); + _mm_storeu_ps( rdi + stride * 3, r3 ); + } + return vmax; + } + + __forceinline __m128 transpose80( __m128 vmax, const float* c0, float* rdi, size_t stride ) + { + const float* const c0End = c0 + 80; + for( ; c0 < c0End; c0 += 4, rdi += stride * 4 ) + { + __m128 r0 = _mm_loadu_ps( c0 ); + vmax = _mm_max_ps( vmax, r0 ); + + _mm_store_ss( rdi, r0 ); + *(int*)( rdi + stride ) = _mm_extract_ps( r0, 1 ); + *(int*)( rdi + stride * 2 ) = _mm_extract_ps( r0, 2 ); + *(int*)( rdi + stride * 3 ) = _mm_extract_ps( r0, 3 ); + } + return vmax; + } + + __forceinline 
float horizontalMaximum( __m128 v ) + { + v = _mm_max_ps( v, _mm_movehl_ps( v, v ) ); + v = _mm_max_ss( v, _mm_movehdup_ps( v ) ); + return _mm_cvtss_f32( v ); + } +} + +void MelStreamer::makeTransposedBuffer( size_t off, size_t len ) +{ + // Resize the output + assert( len <= queueMel.size() ); + outputMel.resize( len * N_MEL ); // N_MEL = 80 + + // First pass, copy transposed MEL data, and compute the maximum + const size_t lengthAligned = ( len / 4 ) * 4; + __m128 vMax = _mm_set1_ps( 1e-20f ); + float* rdi = outputMel.data(); + + size_t i; + for( i = 0; i < lengthAligned; i += 4, rdi += 4 ) + { + vMax = transpose4x80( vMax, + queueMel[ i ].data(), + queueMel[ i + 1 ].data(), + queueMel[ i + 2 ].data(), + queueMel[ i + 3 ].data(), + rdi, len ); + } + for( ; i < len; i++, rdi++ ) + vMax = transpose80( vMax, queueMel[ i ].data(), rdi, len ); + + // Second pass, clamping and normalization + float mmax; + const size_t bufferEnd = off + len; + if( lastBufferEnd != bufferEnd ) + { + // Store maximum value in this class, along with the end sample index + mmax = horizontalMaximum( vMax ); + lastBufferEnd = bufferEnd; + lastBufferMax = mmax; + } + else + { + // We're probably at the and of the stream, the caller asked for a smalled slice of the samples with the same end as the last time. 
+ // Discard the computed maximum value, and instead use the number stored in this class + mmax = lastBufferMax; + } + + mmax -= 8.0f; + vMax = _mm_set1_ps( mmax ); + + rdi = outputMel.data(); + float* const rdiEnd = rdi + outputMel.size(); + const __m128 add = _mm_set1_ps( 4 ); + const __m128 mul = _mm_set1_ps( 1.0f / 4.0f ); + for( ; rdi < rdiEnd; rdi += 4 ) + { + __m128 v = _mm_loadu_ps( rdi ); + v = _mm_max_ps( v, vMax ); + v = _mm_add_ps( v, add ); + v = _mm_mul_ps( v, mul ); + _mm_storeu_ps( rdi, v ); + } +} + +HRESULT MelStreamerSimple::makeBuffer( size_t off, size_t len, const float** buffer, size_t& stride ) noexcept +{ + if( off < streamStartOffset ) + { + logError( u8"MelStreamer doesn't support backwards seeks" ); + return E_UNEXPECTED; + } + + if( off > streamStartOffset ) + { + // The model wants to advance forward, drop now irrelevant chunks of data + dropOldChunks( off ); + } + + // Compute all these MEL chunks + const size_t availableMel = queueMel.size(); + if( availableMel < len ) + { + CHECK( ensurePcmChunks( len ) ); + + const size_t pcmChunks = serializePcm( availableMel ); + const size_t missingMelChunks = len - availableMel; + size_t i; + const size_t loop1 = std::min( missingMelChunks, pcmChunks ); + { + auto profilerBlock = profiler.cpuBlock( eCpuBlock::Spectrogram ); + for( i = 0; i < loop1; i++ ) + { + // if( readerEof && i + 1 == loop1 ) __debugbreak(); + auto& arr = queueMel.emplace_back(); + const float* sourcePcm = tempPcm.data() + i * FFT_STEP; + size_t availableChunks = pcmChunks - i; + size_t availableFloats = availableChunks * FFT_STEP; + melContext.fft( arr, sourcePcm, availableFloats ); + } + } + for( ; i < missingMelChunks; i++ ) + { + assert( readerEof ); + auto& arr = queueMel.emplace_back(); + memset( arr.data(), 0, N_MEL * 4 ); + } + } + + // Produce the result + makeTransposedBuffer( off, len ); + stride = len; + *buffer = outputMel.data(); + return S_OK; +} + +MelStreamerThread::MelStreamerThread( const Filters& filters, 
ProfileCollection& profiler, IMFSourceReader* source, int countThreads, bool stereo ) : + MelStreamer( filters, profiler, source, stereo ), + workerThreads( countThreads ) +{ + if( workerThreads > 1 ) + { + check( ThreadPoolWork::create() ); + melContextsWorkers.reserve( workerThreads - 1 ); + for( int i = 1; i < workerThreads; i++ ) + melContextsWorkers.emplace_back( filters ); + } + + InitializeConditionVariable( &wakeMain ); + InitializeConditionVariable( &wakeBackground ); + threadStatus = eThreadStatus::NotStarted; + const HANDLE h = CreateThread( nullptr, 0, &threadProcStatic, this, 0, nullptr ); + if( nullptr == h ) + throw HRESULT_FROM_WIN32( GetLastError() ); + threadHandle.Attach( h ); +} + +using Lock = CComCritSecLock<CComAutoCriticalSection>; + +constexpr ptrdiff_t prebufferChunks = 3000 * 2; +constexpr ptrdiff_t chunksPerWakeup = 512; +constexpr ptrdiff_t minChunksPerThread = 64; + +HRESULT MelStreamerThread::threadMain() +{ + pendingChunks.reserve( chunksPerWakeup ); + + EnterCriticalSection( &m_cs.m_sec ); + threadStatus = eThreadStatus::Working; + + while( true ) + { + if( shuttingDown ) + { + LeaveCriticalSection( &m_cs.m_sec ); + return S_FALSE; + } + + // Count of available MEL chunks + const ptrdiff_t availableMel = queueMel.size(); + if( availableMel >= prebufferChunks ) + { + threadStatus = eThreadStatus::Idle; + SleepConditionVariableCS( &wakeBackground, &m_cs.m_sec, INFINITE ); + threadStatus = eThreadStatus::Working; + continue; + } + // Count of MEL chunks remaining in the whole stream + // availableMel of them are already on the queue + const ptrdiff_t remainingMel = (ptrdiff_t)getLength() - (ptrdiff_t)streamStartOffset; + LeaveCriticalSection( &m_cs.m_sec ); + + const ptrdiff_t missingChunks = prebufferChunks - availableMel; + ptrdiff_t chunks = std::min( missingChunks, chunksPerWakeup ); + chunks = std::min( chunks, remainingMel - availableMel ); + if( chunks <= 0 ) + return S_OK; // This thread has produced all chunks of the stream + 
+ CHECK( ensurePcmChunks( availableMel + chunks ) ); + const size_t pcmChunks = serializePcm( availableMel ); + if( 0 == pcmChunks ) + return S_OK; + + pendingChunks.clear(); + + chunks = std::min( chunks, (ptrdiff_t)pcmChunks ); + { + auto profilerBlock = profiler.cpuBlock( eCpuBlock::Spectrogram ); + + if( this->workerThreads <= 1 || chunks < minChunksPerThread * 2 ) + { + // Thread pool disabled with a setting, or not enough work for the thread pool + for( ptrdiff_t i = 0; i < chunks; i++ ) + { + MelChunk& arr = pendingChunks.emplace_back(); + const float* sourcePcm = tempPcm.data() + i * FFT_STEP; + size_t availableChunks = pcmChunks - i; + size_t availableFloats = availableChunks * FFT_STEP; + melContext.fft( arr, sourcePcm, availableFloats ); + } + } + else + { + // Use thread pool for these FFTs + pendingChunks.resize( chunks ); + int nth = (int)( ( chunks + minChunksPerThread - 1 ) / minChunksPerThread ); + nth = std::min( nth, this->workerThreads ); + assert( nth > 1 ); + this->fftChunks = (int)chunks; + this->fftThreads = nth; + CHECK( ThreadPoolWork::parallelFor( nth ) ); + } + } + + EnterCriticalSection( &m_cs.m_sec ); + if( shuttingDown ) + { + LeaveCriticalSection( &m_cs.m_sec ); + return S_FALSE; + } + + for( const auto& a : pendingChunks ) + queueMel.push_back( a ); + + LeaveCriticalSection( &m_cs.m_sec ); + + WakeAllConditionVariable( &wakeMain ); + pendingChunks.clear(); + + EnterCriticalSection( &m_cs.m_sec ); + } +} + +HRESULT MelStreamerThread::threadPoolCallback( int ith ) noexcept +{ + SpectrogramContext& ctx = ( 0 != ith ) ? 
melContextsWorkers[ ith - 1 ] : melContext; + + // Figure out the slice of the chunks to generate in this thread + const int nth = this->fftThreads; + const int chunks = this->fftChunks; + const int i0 = ( ith * chunks ) / nth; + const int i1 = ( ( ith + 1 ) * chunks ) / nth; + + // Run these FFTs + const size_t pcmChunks = tempPcm.size() / FFT_STEP; + for( int i = i0; i < i1; i++ ) + { + MelChunk& arr = pendingChunks[ i ]; + const float* sourcePcm = tempPcm.data() + i * FFT_STEP; + size_t availableChunks = pcmChunks - i; + size_t availableFloats = availableChunks * FFT_STEP; + ctx.fft( arr, sourcePcm, availableFloats ); + } + return S_OK; +} + +HRESULT MelStreamerThread::run() noexcept +{ + HRESULT status; + try + { + status = threadMain(); + } + catch( HRESULT hr ) + { + status = hr; + } + catch( const std::bad_alloc& ) + { + status = E_OUTOFMEMORY; + } + catch( const std::exception& ) + { + status = E_FAIL; + } + + { + Lock lk( m_cs ); + threadStatus = SUCCEEDED( status ) ? eThreadStatus::Completed : eThreadStatus::Failed; + } + + // Especially when things fail, we want to wake the main thread up, so it's aware of the situation. 
+ WakeAllConditionVariable( &wakeMain ); + return status; +} + +DWORD __stdcall MelStreamerThread::threadProcStatic( void* lpParameter ) +{ + setCurrentThreadName( "Whisper.dll MEL Streamer Thread" ); + MelStreamerThread* p = (MelStreamerThread*)lpParameter; + return (DWORD)p->run(); +} + +HRESULT MelStreamerThread::makeBuffer( size_t off, size_t len, const float** buffer, size_t& stride ) noexcept +{ + bool wakeThread = false; + + { + Lock lock( m_cs ); + if( off < streamStartOffset ) + { + logError( u8"MelStreamer doesn't support backwards seeks" ); + return E_UNEXPECTED; + } + + if( off > streamStartOffset ) + { + // The model wants to advance forward, drop now irrelevant chunks of data + dropOldChunks( off ); + wakeThread = ( threadStatus == eThreadStatus::Working || threadStatus == eThreadStatus::Idle ); + } + + while( true ) + { + const size_t availableMel = queueMel.size(); + if( availableMel >= len ) + break; + + const eThreadStatus ts = threadStatus; + if( ts == eThreadStatus::Working || ts == eThreadStatus::Idle ) + { + WakeAllConditionVariable( &wakeBackground ); + SleepConditionVariableCS( &wakeMain, &m_cs.m_sec, INFINITE ); + continue; + } + if( ts == eThreadStatus::Failed ) + { + DWORD code; + if( GetExitCodeThread( threadHandle, &code ) ) + return (HRESULT)code; + else + return HRESULT_FROM_WIN32( GetLastError() ); + } + assert( ts == eThreadStatus::Completed ); + break; + } + + if( queueMel.size() < len ) + { + assert( readerEof || threadStatus == eThreadStatus::Failed ); + while( queueMel.size() < len ) + { + auto& arr = queueMel.emplace_back(); + memset( arr.data(), 0, N_MEL * 4 ); + } + } + + // Produce the result + makeTransposedBuffer( off, len ); + + } // Unlock the critical section + + stride = len; + *buffer = outputMel.data(); + if( wakeThread ) + WakeAllConditionVariable( &wakeBackground ); + return S_OK; +} + +MelStreamerThread::~MelStreamerThread() +{ + if( !threadHandle ) + return; + + { + Lock lock( m_cs ); + if( threadStatus != 
eThreadStatus::Working ) + return; + shuttingDown = true; + } + + DWORD res = WaitForSingleObject( threadHandle, 100 ); + if( res == WAIT_OBJECT_0 ) + return; + // TODO: log a warning +}
\ No newline at end of file diff --git a/Whisper/Whisper/MelStreamer.h b/Whisper/Whisper/MelStreamer.h new file mode 100644 index 0000000..152c1b6 --- /dev/null +++ b/Whisper/Whisper/MelStreamer.h @@ -0,0 +1,99 @@ +#pragma once +#include <deque> +#include "../MF/PcmReader.h" +#include "melSpectrogram.h" +#include "iSpectrogram.h" +#include <atlbase.h> +#include "../Utils/parallelFor.h" +#include "../Utils/ProfileCollection.h" + +namespace Whisper +{ + // Base class for both single- and multi-threaded MEL streamers + class MelStreamer : public iSpectrogram + { + protected: + PcmReader reader; + std::deque<PcmMonoChunk> queuePcmMono; + using MelChunk = std::array<float, N_MEL>; + std::deque<MelChunk> queueMel; + size_t streamStartOffset = 0; + std::vector<float> tempPcm; + std::vector<float> outputMel; + SpectrogramContext melContext; + bool readerEof = false; + ProfileCollection& profiler; + std::deque<PcmStereoChunk> queuePcmStereo; + + // If the streamStartOffset value is less than the argument, + // remove ( off - streamStartOffset ) chunks from the start of all 3 queues, and advance streamStartOffset to the `off` argument + void dropOldChunks( size_t off ); + + // Ensure PCM queues have enough chunks to generate specified count of MEL chunks + // At the end of the stream, the method delivers less chunks then requested and returns S_FALSE + HRESULT ensurePcmChunks( size_t len ); + + // Copy mono PCM chunks from the queue (starting at the specified element index) into the continuous tempPcm vector + // Returns count of chunks copied there. 
+ size_t serializePcm( size_t startOffset ); + + size_t lastBufferEnd = ~(size_t)0; + float lastBufferMax = 0.0f; + void makeTransposedBuffer( size_t off, size_t len ); + + size_t getLength() const noexcept override final { return reader.getLength(); } + + public: + MelStreamer( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, bool stereo ); + }; + + // Single-threaded MEL streamer: runs these FFTs on-demand, from within makeBuffer() method + class MelStreamerSimple : public MelStreamer + { + HRESULT makeBuffer( size_t offset, size_t length, const float** buffer, size_t& stride ) noexcept override final; + + public: + MelStreamerSimple( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, bool stereo ) : + MelStreamer( filters, profiler, source, stereo ) { } + }; + + // Multi threaded MEL streamers: runs FFT on a background thread ahead of time + // The background thread tries to keep the queueMel full, this way the makeBuffer() method has very little to do + // makeBuffer() only transposes the data, and does clamping + normalization, both steps are pretty fast + class MelStreamerThread : public MelStreamer, + ThreadPoolWork + { + HRESULT makeBuffer( size_t offset, size_t length, const float** buffer, size_t& stride ) noexcept override final; + + static DWORD __stdcall threadProcStatic( void* lpParameter ); + HRESULT run() noexcept; + HRESULT threadMain(); + + std::vector<MelChunk> pendingChunks; + int fftChunks = 0; + int fftThreads = 0; + std::vector<SpectrogramContext> melContextsWorkers; + CComAutoCriticalSection m_cs; + CONDITION_VARIABLE wakeMain, wakeBackground; + const int workerThreads; + enum struct eThreadStatus : uint8_t + { + NotStarted = 0, + Idle, + Working, + Completed, + Failed + }; + eThreadStatus threadStatus; + bool shuttingDown = false; + CHandle threadHandle; + + HRESULT threadPoolCallback( int ith ) noexcept override final; + + public: + + MelStreamerThread( const Filters& filters, 
ProfileCollection& profiler, IMFSourceReader* source, int countThreads, bool stereo ); + + ~MelStreamerThread(); + }; +}
\ No newline at end of file diff --git a/Whisper/Whisper/ModelBuffers.cpp b/Whisper/Whisper/ModelBuffers.cpp new file mode 100644 index 0000000..7ad8e13 --- /dev/null +++ b/Whisper/Whisper/ModelBuffers.cpp @@ -0,0 +1,115 @@ +#include "stdafx.h" +#include "ModelLoader.h" + + +#if BUILD_BOTH_VERSIONS +namespace DirectCompute +{ + static ModelBuffers s_model; + const ModelBuffers& gpuModel = s_model; +} + +using namespace DirectCompute; + +ModelLoader::ModelLoader( int encoderLayers, int decoderLayers ) : + model( s_model ) +{ + if( encoderLayers <= 0 || decoderLayers <= 0 ) + throw E_INVALIDARG; + model.enc.layers.resize( (uint32_t)encoderLayers ); + model.dec.layers.resize( (uint32_t)decoderLayers ); +} + +void ModelLoader::add( const ggml_tensor* ggml, Tensor& gpu ) +{ + if( nullptr == ggml ) + throw E_POINTER; + + auto res = map.try_emplace( ggml, &gpu ); + if( !res.second ) + throw E_INVALIDARG; +} + +Tensor* ModelLoader::lookup( const ggml_tensor* ggml ) const +{ + auto it = map.find( ggml ); + if( it == map.end() ) + return nullptr; + return it->second; +} + +bool ModelLoader::tryLoad( const ggml_tensor* ggml ) +{ + Tensor* rdi = lookup( ggml ); + if( nullptr == rdi ) + return false; + HRESULT hr = rdi->create( *ggml, eBufferUse::Immutable, true ); + if( SUCCEEDED( hr ) ) + return true; + throw hr; +} +#endif + +__m128i __declspec( noinline ) DirectCompute::TensorPair::getMemoryUse() const +{ + return _mm_add_epi64( w.getMemoryUse(), b.getMemoryUse() ); +} + +__m128i DirectCompute::LayerEncoder::getMemoryUse() const +{ + __m128i v = attnLn0.getMemoryUse(); + v = _mm_add_epi64( v, attnLn1.getMemoryUse() ); + v = _mm_add_epi64( v, attnQuery.getMemoryUse() ); + v = _mm_add_epi64( v, attnKey.getMemoryUse() ); + v = _mm_add_epi64( v, attnValue.getMemoryUse() ); + v = _mm_add_epi64( v, mlpLn.getMemoryUse() ); + v = _mm_add_epi64( v, mlp0.getMemoryUse() ); + v = _mm_add_epi64( v, mlp1.getMemoryUse() ); + return v; +} + +__m128i 
DirectCompute::EncoderBuffers::getMemoryUse() const +{ + __m128i v = _mm_cvtsi64_si128( vectorMemoryUse( layers ) ); + v = _mm_add_epi64( v, positionalEmbedding.getMemoryUse() ); + v = _mm_add_epi64( v, conv1.getMemoryUse() ); + v = _mm_add_epi64( v, conv2.getMemoryUse() ); + v = _mm_add_epi64( v, lnPost.getMemoryUse() ); + for( const auto& layer : layers ) + v = _mm_add_epi64( v, layer.getMemoryUse() ); + return v; +} + +__m128i DirectCompute::LayerDecoder::getMemoryUse() const +{ + __m128i v = attnLn0.getMemoryUse(); + v = _mm_add_epi64( v, attnLn1.getMemoryUse() ); + v = _mm_add_epi64( v, attnQuery.getMemoryUse() ); + v = _mm_add_epi64( v, attnKey.getMemoryUse() ); + v = _mm_add_epi64( v, attnValue.getMemoryUse() ); + v = _mm_add_epi64( v, crossAttnLn0.getMemoryUse() ); + v = _mm_add_epi64( v, crossAttnLn1.getMemoryUse() ); + v = _mm_add_epi64( v, crossAttnQuery.getMemoryUse() ); + v = _mm_add_epi64( v, crossAttnKey.getMemoryUse() ); + v = _mm_add_epi64( v, crossAttnValue.getMemoryUse() ); + v = _mm_add_epi64( v, mlpLn.getMemoryUse() ); + v = _mm_add_epi64( v, mlp0.getMemoryUse() ); + v = _mm_add_epi64( v, mlp1.getMemoryUse() ); + return v; +} + +__m128i DirectCompute::DecoderBuffers::getMemoryUse() const +{ + __m128i v = _mm_cvtsi64_si128( vectorMemoryUse( layers ) ); + v = _mm_add_epi64( v, positionalEmbedding.getMemoryUse() ); + v = _mm_add_epi64( v, tokenEmbedding.getMemoryUse() ); + v = _mm_add_epi64( v, ln.getMemoryUse() ); + for( const auto& layer : layers ) + v = _mm_add_epi64( v, layer.getMemoryUse() ); + return v; +} + +__m128i DirectCompute::ModelBuffers::getMemoryUse() const +{ + return _mm_add_epi64( enc.getMemoryUse(), dec.getMemoryUse() ); +}
\ No newline at end of file diff --git a/Whisper/Whisper/ModelBuffers.h b/Whisper/Whisper/ModelBuffers.h new file mode 100644 index 0000000..403d274 --- /dev/null +++ b/Whisper/Whisper/ModelBuffers.h @@ -0,0 +1,114 @@ +#pragma once +#include "../ML/Tensor.h" +#include <vector> + +namespace DirectCompute +{ + // A pair of tensors containing weights and biases; apparently, both tensors are of the same shape. + struct TensorPair + { + Tensor w, b; + + __m128i getMemoryUse() const; + }; + + // A set of tensors for one encoder's layer + struct LayerEncoder + { + // encoder.blocks.*.attn_ln + TensorPair attnLn0; + // encoder.blocks.*.attn.out + TensorPair attnLn1; + // encoder.blocks.*.attn.query + TensorPair attnQuery; + // encoder.blocks.*.attn.key + Tensor attnKey; + // encoder.blocks.*.attn.value + TensorPair attnValue; + // encoder.blocks.*.mlp_ln + TensorPair mlpLn; + // encoder.blocks.*.mlp.0 + TensorPair mlp0; + // encoder.blocks.*.mlp.2 + TensorPair mlp1; + + __m128i getMemoryUse() const; + }; + + // A set of tensors for the encoder + struct EncoderBuffers + { + // encoder.positional_embedding + Tensor positionalEmbedding; + // encoder.conv1 + TensorPair conv1; + // encoder.conv2 + TensorPair conv2; + // encoder.ln_post + TensorPair lnPost; + // A vector of layers + std::vector<LayerEncoder> layers; + + __m128i getMemoryUse() const; + }; + + // A set of tensors for one decoder's layer + struct LayerDecoder + { + // decoder.blocks.*.attn_ln + TensorPair attnLn0; + // decoder.blocks.*.attn.out + TensorPair attnLn1; + // decoder.blocks.*.attn.query + TensorPair attnQuery; + // decoder.blocks.*.attn.key + Tensor attnKey; + // decoder.blocks.*.attn.value + TensorPair attnValue; + // decoder.blocks.*.cross_attn_ln + TensorPair crossAttnLn0; + // decoder.blocks.*.cross_attn.out + TensorPair crossAttnLn1; + // decoder.blocks.*.cross_attn.query + TensorPair crossAttnQuery; + // decoder.blocks.*.cross_attn.key + Tensor crossAttnKey; + // decoder.blocks.*.cross_attn.value 
+ TensorPair crossAttnValue; + // decoder.blocks.*.mlp_ln + TensorPair mlpLn; + // decoder.blocks.*.mlp.0 + TensorPair mlp0; + // decoder.blocks.*.mlp.2 + TensorPair mlp1; + + __m128i getMemoryUse() const; + }; + + // A set of tensors for the decoder + struct DecoderBuffers + { + // decoder.positional_embedding + Tensor positionalEmbedding; + // decoder.token_embedding + Tensor tokenEmbedding; + // decoder.ln + TensorPair ln; + // A vector of layers + std::vector<LayerDecoder> layers; + + __m128i getMemoryUse() const; + }; + + // A complete set of tensors for a model + struct ModelBuffers + { + EncoderBuffers enc; + DecoderBuffers dec; + __m128i getMemoryUse() const; + }; + +#if BUILD_BOTH_VERSIONS + extern const ModelBuffers& gpuModel; +#endif +}
\ No newline at end of file diff --git a/Whisper/Whisper/ModelImpl.cpp b/Whisper/Whisper/ModelImpl.cpp new file mode 100644 index 0000000..968c3ce --- /dev/null +++ b/Whisper/Whisper/ModelImpl.cpp @@ -0,0 +1,122 @@ +#include "stdafx.h" +#include "ModelImpl.h" +#include "../ML/mlStartup.h" +#include "ContextImpl.h" +#include <intrin.h> +#include "../Utils/ReadStream.h" +#include "../modelFactory.h" +using namespace Whisper; + +namespace +{ + volatile long s_refCounter = 0; +} + +HRESULT ModelImpl::FinalConstruct() +{ + if( 1 != InterlockedIncrement( &s_refCounter ) ) + return S_FALSE; + return DirectCompute::mlStartup(); +} + +void ModelImpl::FinalRelease() +{ + if( 0 == InterlockedDecrement( &s_refCounter ) ) + DirectCompute::mlShutdown(); +} + +HRESULT COMLIGHTCALL ModelImpl::createContext( iContext** pp ) +{ + ComLight::CComPtr<ComLight::Object<ContextImpl>> obj; + + iModel* m = this; + CHECK( ComLight::Object<ContextImpl>::create( obj, model, m ) ); + + obj.detach( pp ); + return S_OK; +} + +HRESULT ModelImpl::load( iReadStream* stm, bool hybrid, const sLoadModelCallbacks* callbacks ) +{ + return model.load( stm, hybrid, callbacks ); +} + +inline bool hasSse41() +{ + int cpu_info[ 4 ]; + __cpuid( cpu_info, 1 ); + return ( cpu_info[ 2 ] & ( 1 << 19 ) ) != 0; +} + +// True when the current CPU is good enough to run the hybrid model +inline bool hasAvxAndFma() +{ + // AVX needs OS support to preserve the 32-bytes registers across context switches, CPU support alone ain't enough + // Calling a kernel API to check that support + // The magic number is from there: https://stackoverflow.com/a/35096938/126995 + if( 0 == ( GetEnabledXStateFeatures() & 4 ) ) + return false; + + // FMA3 and F16C + int cpuInfo[ 4 ]; + __cpuid( cpuInfo, 1 ); + // The magic numbers are from "Feature Information" table on Wikipedia: + // https://en.wikipedia.org/wiki/CPUID#EAX=1:_Processor_Info_and_Feature_Bits + constexpr int requiredBits = ( 1 << 12 ) | ( 1 << 29 ); + if( requiredBits != ( 
cpuInfo[ 2 ] & requiredBits ) ) + return false; + + // BMI1 + // https://en.wikipedia.org/wiki/CPUID#EAX=7,_ECX=0:_Extended_Features + __cpuid( cpuInfo, 7 ); + if( 0 == ( cpuInfo[ 1 ] & ( 1 << 3 ) ) ) + return false; + + return true; +} + +HRESULT __stdcall Whisper::loadGpuModel( const wchar_t* path, bool hybrid, const sLoadModelCallbacks* callbacks, iModel** pp ) +{ + if( nullptr == path || nullptr == pp ) + return E_POINTER; + + if( hybrid ) + { +#if BUILD_HYBRID_VERSION + if( !hasAvxAndFma() ) + { + logError( u8"eModelImplementation.Hybrid model requires a CPU with AVX1, FMA3, F16C and BMI1 support" ); + return ERROR_HV_CPUID_FEATURE_VALIDATION; + } +#else + logError( u8"This build of the DLL doesn’t implement eModelImplementation.Hybrid model" ); + return E_NOTIMPL; +#endif + } + else if( !hasSse41() ) + { + logError( u8"eModelImplementation.GPU model requires a CPU with SSE 4.1 support" ); + return ERROR_HV_CPUID_FEATURE_VALIDATION; + } + + ComLight::Object<ReadStream> stream; + HRESULT hr = stream.open( path ); + if( FAILED( hr ) ) + { + logError16( L"Unable to open model binary file \"%s\"", path ); + return hr; + } + + ComLight::CComPtr<ComLight::Object<ModelImpl>> obj; + CHECK( ComLight::Object<ModelImpl>::create( obj ) ); + hr = obj->load( &stream, hybrid, callbacks ); + if( FAILED( hr ) ) + { + logError16( L"Error loading the model from \"%s\"", path ); + return hr; + } + + obj.detach( pp ); + logInfo16( L"Loaded model from \"%s\" to VRAM", path ); + return S_OK; +}
\ No newline at end of file diff --git a/Whisper/Whisper/ModelImpl.h b/Whisper/Whisper/ModelImpl.h new file mode 100644 index 0000000..4bcea12 --- /dev/null +++ b/Whisper/Whisper/ModelImpl.h @@ -0,0 +1,40 @@ +#pragma once +#include "../API/iContext.cl.h" +#include "../ComLightLib/comLightServer.h" +#include "WhisperModel.h" +#include "../ComLightLib/streams.h" + +namespace Whisper +{ + using ComLight::iReadStream; + + class ModelImpl : public ComLight::ObjectRoot<iModel> + { + WhisperModel model; + + HRESULT COMLIGHTCALL createContext( iContext** pp ) override final; + + HRESULT COMLIGHTCALL getSpecialTokens( SpecialTokens& rdi ) override final + { + model.vocab.getSpecialTokens( rdi ); + return S_OK; + } + + HRESULT COMLIGHTCALL isMultilingual() override final + { + return model.vocab.is_multilingual() ? S_OK : S_FALSE; + } + + const char* COMLIGHTCALL stringFromToken( whisper_token token ) override final + { + return model.vocab.string( token ); + } + + public: + + HRESULT FinalConstruct(); + void FinalRelease(); + + HRESULT load( iReadStream* stm, bool hybrid, const sLoadModelCallbacks* callbacks ); + }; +}
\ No newline at end of file diff --git a/Whisper/Whisper/ModelLoader.h b/Whisper/Whisper/ModelLoader.h new file mode 100644 index 0000000..b136dd2 --- /dev/null +++ b/Whisper/Whisper/ModelLoader.h @@ -0,0 +1,29 @@ +#pragma once +#include "ModelBuffers.h" +#include <map> + +namespace DirectCompute +{ + struct ModelLoader + { + ModelLoader( int encoderLayers, int decoderLayers ); + + void add( const ggml_tensor* ggml, Tensor& gpu ); + + void add( const ggml_tensor* w, const ggml_tensor* b, TensorPair& gpu ) + { + add( w, gpu.w ); + add( b, gpu.b ); + } + + bool tryLoad( const ggml_tensor* ggml ); + + ModelBuffers& model; + + private: + + Tensor* lookup( const ggml_tensor* ggml ) const; + + std::map<const ggml_tensor*, Tensor*> map; + }; +}
\ No newline at end of file diff --git a/Whisper/Whisper/Spectrogram.cpp b/Whisper/Whisper/Spectrogram.cpp new file mode 100644 index 0000000..400f9b1 --- /dev/null +++ b/Whisper/Whisper/Spectrogram.cpp @@ -0,0 +1,124 @@ +#include "stdafx.h" +#include "Spectrogram.h" +#include <memory> +#define _USE_MATH_DEFINES +#include <math.h> +#include "../Utils/parallelFor.h" +#include "../API/iMediaFoundation.cl.h" +#include "../ML/testUtils.h" +#include "melSpectrogram.h" +using namespace Whisper; + +class alignas( 64 ) Spectrogram::MelContext +{ + const float* const samples; + const size_t countSamples; + Spectrogram& result; + const int n_threads; + SpectrogramContext context; + +public: + + MelContext( const float* rsi, size_t len, const Filters& f, Spectrogram& rdi, int countThreads ) : + samples( rsi ), countSamples( len ), result( rdi ), n_threads( countThreads ), + context( f ) + { } + + void run( int ith ); + + static HRESULT workCallback( int ith, void* ctx ) noexcept; +}; + +void Spectrogram::MelContext::run( int ith ) +{ + std::array<float, N_MEL> arr; + for( uint32_t i = ith; i < result.length; i += n_threads ) + { + const int offset = i * FFT_STEP; + const float* rsi = samples + offset; + context.fft( arr, rsi, countSamples - offset ); + + for( size_t j = 0; j < N_MEL; j++ ) + result.data[ j * result.length + i ] = arr[ j ]; + } +} + +HRESULT Spectrogram::MelContext::workCallback( int ith, void* ctx ) noexcept +{ + std::vector<Spectrogram::MelContext>& contexts = *( std::vector<Spectrogram::MelContext>* )ctx; + try + { + contexts.at( ith ).run( ith ); + return S_OK; + } + catch( const std::bad_alloc& ) + { + return E_OUTOFMEMORY; + } + catch( const std::exception& ) + { + return E_FAIL; + } +} + +HRESULT Spectrogram::pcmToMel( const iAudioBuffer* buffer, const Filters& filters, int threads ) +{ + if( nullptr == buffer ) + return E_POINTER; + const uint32_t countSamples = buffer->countSamples(); + if( 0 == countSamples ) + return OLE_E_BLANK; + const float* 
const samples = buffer->getPcmMono(); + + length = ( countSamples ) / FFT_STEP; + data.resize( N_MEL * length ); + + if( threads < 2 ) + { + MelContext ctx{ samples, countSamples, filters, *this, 1 }; + ctx.run( 0 ); + } + else + { + std::vector<MelContext> contexts; + contexts.reserve( threads ); + for( int i = 0; i < threads; i++ ) + contexts.emplace_back( MelContext{ samples, countSamples, filters, *this, (int)threads } ); + CHECK( parallelFor( &MelContext::workCallback, threads, &contexts ) ); + } + + // clamping and normalization + double mmax = -1e20; + for( double f : data ) + mmax = std::max( mmax, f ); + //printf("%s: max = %f\n", __func__, mmax); + + mmax -= 8.0; + + for( float& f : data ) + { + if( f < mmax ) + f = (float)mmax; + f = (float)( ( f + 4.0 ) / 4.0 ); + } + // DirectCompute::dbgWriteBinaryFile( LR"(C:\Temp\2remove\ML\mel-my.bin)", data.data(), data.size() * 4 ); + return S_OK; +} + +void Whisper::computeSignalEnergy( std::vector<float>& result, const iAudioBuffer* buffer, int n_samples_per_half_window ) +{ + const size_t countSamples = buffer->countSamples(); + const float* const samples = buffer->getPcmMono(); + + const int hw = n_samples_per_half_window; + result.resize( countSamples ); + + for( size_t i = 0; i < countSamples; i++ ) + { + float sum = 0; + for( int j = -hw; j <= hw; j++ ) + if( i + j >= 0 && i + j < countSamples ) + sum += fabsf( samples[ i + j ] ); + result[ i ] = sum / ( 2 * hw + 1 ); + } +}
\ No newline at end of file diff --git a/Whisper/Whisper/Spectrogram.h b/Whisper/Whisper/Spectrogram.h new file mode 100644 index 0000000..04e2c06 --- /dev/null +++ b/Whisper/Whisper/Spectrogram.h @@ -0,0 +1,42 @@ +#pragma once +#include "WhisperModel.h" +#include "iSpectrogram.h" +#include "audioConstants.h" + +namespace Whisper +{ + struct iAudioBuffer; + + class Spectrogram: public iSpectrogram + { + uint32_t length = 0; + static constexpr uint32_t mel = N_MEL; + std::vector<float> data; + + HRESULT makeBuffer( size_t off, size_t len, const float** buffer, size_t& stride ) noexcept override final + { + if( off + len > length ) + return E_BOUNDS; + *buffer = &data[ off ]; + stride = length; + return S_OK; + } + + class MelContext; + + public: + size_t getLength() const noexcept override final + { + return length; + } + HRESULT pcmToMel( const iAudioBuffer* buffer, const Filters& filters, int threads = 1 ); + + size_t memoryUsage() const + { + return data.size() * 4; + } + }; + + // average the fabs of the signal + void computeSignalEnergy( std::vector<float>& result, const iAudioBuffer* buffer, int n_samples_per_half_window ); +}
\ No newline at end of file diff --git a/Whisper/Whisper/TranscribeResult.h b/Whisper/Whisper/TranscribeResult.h new file mode 100644 index 0000000..8f8e408 --- /dev/null +++ b/Whisper/Whisper/TranscribeResult.h @@ -0,0 +1,43 @@ +#pragma once +#include "../API/iTranscribeResult.cl.h" +#include "../ComLightLib/comLightServer.h" + +namespace Whisper +{ + class TranscribeResult : public ComLight::ObjectRoot<iTranscribeResult> + { + HRESULT COMLIGHTCALL getSize( sTranscribeLength& rdi ) const noexcept override final + { + rdi.countSegments = (uint32_t)segments.size(); + rdi.countTokens = (uint32_t)tokens.size(); + return S_OK; + } + const sSegment* COMLIGHTCALL getSegments() const noexcept override final + { + if( !segments.empty() ) + return segments.data(); + return nullptr; + } + const sToken* COMLIGHTCALL getTokens() const noexcept override final + { + if( !tokens.empty() ) + return tokens.data(); + return nullptr; + } + + public: + std::vector<sSegment> segments; + std::vector<sToken> tokens; + }; + + class TranscribeResultStatic : public ComLight::Object<TranscribeResult> + { + uint32_t COMLIGHTCALL Release() override final + { + // When the ref.counter reaches zero, Object.Release() method calls `delete this`. + // We don't want that for the aggregated object. + // Instead we only decrement the ref.counter, but do not delete the object even when the counter reaches zero. + return RefCounter::implRelease(); + } + }; +}
\ No newline at end of file diff --git a/Whisper/Whisper/Vocabulary.cpp b/Whisper/Whisper/Vocabulary.cpp new file mode 100644 index 0000000..2e5eaec --- /dev/null +++ b/Whisper/Whisper/Vocabulary.cpp @@ -0,0 +1,129 @@ +#include "stdafx.h" +#include "Vocabulary.h" +#include "loaderUtils.h" +using ComLight::iReadStream; +using namespace Whisper; + +void Vocabulary::addExtra( int index, const char* format, int i ) +{ + const int len = std::snprintf( nullptr, 0, format, i ); + const size_t offset = stringData.size(); + stringData.resize( offset + len + 1 ); + char* const rdi = stringData.data() + offset; + std::snprintf( rdi, len + 1, format, i ); + rdi[ len ] = '\0'; + tokens[ index ] = reinterpret_cast<const char*>( offset ); +} + +void Vocabulary::completeBuild() +{ + stringData.shrink_to_fit(); + + const size_t dataLength = stringData.size(); + for( auto& s : tokens ) + { + // The reason this hack works - on Windows, lower 2GB of address space is reserved to the kernel. + // That's why the strings from the read only section of this DLL like "[_PREV_]" are guaranteed to have their addresses much larger than the size of the data buffer + const size_t ri = reinterpret_cast<size_t>( s ); + if( ri < dataLength ) + s = stringData.data() + ri; + } + + int64_t cb = stringData.size(); + cb += tokens.size() * sizeof( void* ); + constexpr double mulKb = 1.0 / ( 1 << 10 ); + logDebug( u8"Loaded vocabulary, %zu strings, %.1f kb RAM", tokens.size(), mulKb * cb ); +} + +HRESULT Vocabulary::load( ComLight::iReadStream* stm, int lengthInHeader ) +{ + if( lengthInHeader <= 0 ) + return E_INVALIDARG; + + tokens.clear(); + stringData.clear(); + + int countWords = 0; + CHECK( readStruct( stm, countWords ) ); + if( countWords <= 0 ) + return E_INVALIDARG; + + const size_t count = (uint32_t)countWords; + const size_t actualCount = std::max( count, (size_t)lengthInHeader ); + tokens.resize( actualCount ); + + for( int i = 0; i < count; i++ ) + { + int countChars = 0; + CHECK( readStruct( 
stm, countChars ) ); + if( countChars < 0 ) + { + logError( u8"Vocabulary.load failed: string length is negative" ); + return E_INVALIDARG; + } + if( countChars == 0 ) + { + // This happens with `ggml-large.bin` and `ggml-large-v1.bin` models. + // A bug in the model maybe? + tokens[ i ] = ""; + continue; + } + const size_t len = (size_t)countChars; + + const size_t offset = stringData.size(); + stringData.resize( offset + len + 1 ); + + CHECK( readBytes( stm, &stringData[ offset ], len ) ); + *stringData.rbegin() = '\0'; + + tokens[ i ] = reinterpret_cast<const char*>( offset ); + } + + n_vocab = lengthInHeader; + + if( is_multilingual() ) + { + token_eot++; + token_sot++; + token_prev++; + token_solm++; + token_not++; + token_beg++; + }; + + if( countWords < lengthInHeader ) + { + for( int i = countWords; i < lengthInHeader; i++ ) + { + if( i > token_beg ) + addExtra( i, "[_TT_%i]", i - token_beg ); + else if( i == token_eot ) + tokens[ i ] = "[_EOT_]"; + else if( i == token_sot ) + tokens[ i ] = "[_SOT_]"; + else if( i == token_prev ) + tokens[ i ] = "[_PREV_]"; + else if( i == token_not ) + tokens[ i ] = "[_NOT_]"; + else if( i == token_beg ) + tokens[ i ] = "[_BEG_]"; + else + addExtra( i, "[_extra_token_%i]", i ); + } + } + + completeBuild(); + return S_OK; +} + +void Vocabulary::getSpecialTokens( SpecialTokens& rdi ) const +{ + rdi.TranscriptionEnd = token_eot; + rdi.TranscriptionStart = token_sot; + rdi.PreviousWord = token_prev; + rdi.SentenceStart = token_solm; + rdi.Not = token_not; + rdi.TranscriptionBegin = token_beg; + rdi.TaskTranslate = token_translate; + rdi.TaskTranscribe = token_transcribe; +}
\ No newline at end of file diff --git a/Whisper/Whisper/Vocabulary.h b/Whisper/Whisper/Vocabulary.h new file mode 100644 index 0000000..6250494 --- /dev/null +++ b/Whisper/Whisper/Vocabulary.h @@ -0,0 +1,58 @@ +#pragma once +#include "../../ComLightLib/streams.h" +#include "../API/SpecialTokens.h" + +namespace Whisper +{ + class Vocabulary + { + std::vector<const char*> tokens; + std::vector<char> stringData; + + void addExtra( int index, const char* format, int i ); + + void completeBuild(); + public: + + int n_vocab = 51864; + + HRESULT load( ComLight::iReadStream* stm, int lengthInHeader ); + + using id = int; + + id token_eot = 50256; + id token_sot = 50257; + id token_prev = 50360; + id token_solm = 50361; // ?? + id token_not = 50362; // no timestamps + id token_beg = 50363; + + // available tasks + static const id token_translate = 50358; + static const id token_transcribe = 50359; + + bool is_multilingual() const + { + return n_vocab == 51865; + } + + const char* string( int id ) const + { + if( id >= 0 && id < (int)tokens.size() ) + return tokens[ id ]; + return nullptr; + } + + size_t size() const + { + return tokens.size(); + } + + void getSpecialTokens( SpecialTokens& rdi ) const; + + size_t getMemoryUse() const + { + return vectorMemoryUse( tokens ) + vectorMemoryUse( stringData ); + } + }; +}
\ No newline at end of file diff --git a/Whisper/Whisper/WhisperContext.cpp b/Whisper/Whisper/WhisperContext.cpp new file mode 100644 index 0000000..c983039 --- /dev/null +++ b/Whisper/Whisper/WhisperContext.cpp @@ -0,0 +1,673 @@ +#include "stdafx.h" +#include "WhisperContext.h" +#include "ModelBuffers.h" +#include <optional> +#include "../Utils/Trace/tracing.h" +#include "../D3D/RenderDoc/renderDoc.h" +#include "../ML/testUtils.h" +using namespace DirectCompute; + +namespace +{ + // True to measure GPU time of individual shaders which run during the encode step of the algorithm + constexpr bool profileEncodeShaders = true; + // True to measure GPU time of individual shaders which run during the decode step of the algorithm + constexpr bool profileDecodeShaders = true; + + LPCTSTR traceFileNative = LR"(C:\Temp\2remove\Whisper\gpu.bin)"; + LPCTSTR traceFileHybrid = LR"(C:\Temp\2remove\Whisper\hybrid.bin)"; + + TensorsArena::sArenaConfigs defaultArenaConfigs() + { + TensorsArena::sArenaConfigs res = {}; + return res; + } +} + +WhisperContext::Arenas::Arenas() : + enc( defaultArenaConfigs() ), encLayer( defaultArenaConfigs() ), dec( defaultArenaConfigs() ), decLayer( defaultArenaConfigs() ) +{ } + +Tensor WhisperContext::DecoderLayerPool::tensor( eDataType type, const std::array<uint32_t, 4>& ne ) +{ + assert( type == eDataType::FP32 ); + return result.tensor( eDataType::FP32, ne, &DirectCompute::defaultNewCapacity ); +} + +class WhisperContext::ArenaRaii +{ + WhisperContext& context; + iTensorArena* prevCurrent; + +public: + + ArenaRaii( WhisperContext& ctx, iTensorArena& ta ) : + context( ctx ) + { + prevCurrent = ctx.currentArena; + ctx.currentArena = &ta; + } + + ~ArenaRaii() + { + context.currentArena->reset(); + context.currentArena = prevCurrent; + } +}; + +WhisperContext::WhisperContext( const Whisper::WhisperModel& wm, Whisper::ProfileCollection& pc ) : + MlContext( pc ), + gpuModel( wm.tensors ) +{ +#if BUILD_HYBRID_VERSION + if( 
!wm.hybridTensors.layers.empty() ) + { + hybridContext = std::make_unique<HybridContext>( wm ); + check( hybridContext->create() ); +#if SAVE_DEBUG_TRACE + Tracing::traceCreate( traceFileHybrid ); +#endif + } + else +#endif + { +#if SAVE_DEBUG_TRACE + Tracing::traceCreate( traceFileNative ); +#endif + } +} + +#if BUILD_BOTH_VERSIONS +namespace +{ + thread_local WhisperContext* ts_context = nullptr; + const ModelBuffers& getGlobalModel() + { + return gpuModel; + } +} + +/* +WhisperContext::WhisperContext() : + gpuModel( getGlobalModel() ), +{ + if( nullptr != ts_context ) + throw HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); + ts_context = this; +}*/ + +WhisperContext::~WhisperContext() +{ + Tracing::traceClose(); + if( ts_context == nullptr ) + return; + assert( ts_context == this ); + ts_context = nullptr; +} + +WhisperContext& WhisperContext::current() +{ + WhisperContext* c = ts_context; + if( nullptr == c ) + throw OLE_E_BLANK; + return *c; +} +#else +WhisperContext& WhisperContext::current() +{ + throw E_NOTIMPL; +} +#endif + +Tensor WhisperContext::createTensor( eDataType type, const std::array<uint32_t, 4>& ne ) +{ + // return MlContext::createTensor( type, ne ); + + iTensorArena* const ca = currentArena; + if( nullptr != ca ) + return ca->tensor( type, ne ); + else + return MlContext::createTensor( type, ne ); +} + +void WhisperContext::fmaRepeat( Tensor& cur, const TensorPair& that ) +{ + MlContext::fmaRepeat( cur, that.w, that.b ); +} + +Tensor WhisperContext::convolutionAndGelu( const Tensor& mel, uint32_t n_ctx ) +{ + const EncoderBuffers& model = gpuModel.enc; + Tensor cur = conv_1d_1s( model.conv1.w, mel ); + Tracing::tensor( "enc.conv1", cur ); + addRepeatGelu( cur, model.conv1.b ); + Tracing::tensor( "enc.temp1", cur ); + + cur = conv_1d_2s( model.conv2.w, cur ); + addRepeatGelu( cur, model.conv2.b ); + + const Tensor& posEmbed = model.positionalEmbedding; + const uint32_t peStride = posEmbed.ne[ 0 ]; + constexpr uint32_t peOffset = 0; + + Tensor 
e_pe = view2d( posEmbed, posEmbed.ne[ 0 ], n_ctx, peStride, peOffset ); + cur = add( e_pe, transpose( cur ) ); + return cur; +} + +Tensor WhisperContext::encodeLayer( const Tensor& source, size_t index, uint32_t n_state, uint32_t n_head, uint32_t n_ctx ) +{ + auto prof = profiler.block( eProfilerBlock::EncodeLayer ); + ArenaRaii arenaRaii{ *this, arenas.encLayer }; + + const LayerEncoder& layer = gpuModel.enc.layers[ index ]; + // norm + Tensor cur = norm( source ); + if( 0 == index ) + Tracing::tensor( "enc-norm", cur ); + fmaRepeat( cur, layer.attnLn0 ); + + // self-attention + Tensor Qcur; + Tensor reshaped; + if( gpuInfo.useReshapedMatMul() ) + { + const uint16_t tag = profiler.setNextTag( "enc.layer.1" ); + reshaped = reshapePanels( cur ); + profiler.setNextTag( tag ); + Qcur = mulMatTiledEx( layer.attnQuery.w, reshaped ); + } + else + { + profiler.setNextTag( "enc.layer.1" ); + Qcur = mulMat( layer.attnQuery.w, cur ); + } + + if( 0 == index ) + Tracing::tensor( "enc-Qcur", Qcur ); + addRepeat( Qcur, layer.attnQuery.b ); + + // note: no bias for Key + Tensor Kcur; + if( gpuInfo.useReshapedMatMul() ) + { + // Already reshaped by the previous `if` + profiler.setNextTag( "enc.layer.2" ); + Kcur = mulMatTiledEx( layer.attnKey, reshaped ); + } + else + { + profiler.setNextTag( "enc.layer.2" ); + Kcur = mulMat( layer.attnKey, cur ); + } + + if( 0 == index ) + Tracing::tensor( "enc-Kcur", Kcur ); + + Tensor Vcur; + if( gpuInfo.useReshapedMatMul() ) + { + // Already reshaped by the previous `if` + profiler.setNextTag( "enc.layer.3" ); + Vcur = mulMatTiledEx( layer.attnValue.w, reshaped ); + } + else + { + profiler.setNextTag( "enc.layer.3" ); + Vcur = mulMat( layer.attnValue.w, cur ); + } + + if( 0 == index ) + Tracing::tensor( "enc-Vcur", Vcur ); + addRepeat( Vcur, layer.attnValue.b ); + + // ------ + Tensor Q = permute( copy( Qcur, eDataType::FP16, { n_state / n_head, n_head, n_ctx } ), 0, 2, 1, 3 ); + Tensor K = permute( copy( Kcur, eDataType::FP16, { n_state / 
n_head, n_head, n_ctx } ), 0, 2, 1, 3 ); + Tensor V = copy( permute( Vcur.reshape3d( n_state / n_head, n_head, n_ctx ), 1, 2, 0, 3 ), eDataType::FP16, { n_ctx, n_state / n_head, n_head } ); + Tensor KQV = flashAttention( Q, K, V, false ); + if( 0 == index ) + Tracing::tensor( "enc-KQV", KQV ); + Tensor KQV_merged = permute( KQV, 0, 2, 1, 3 ); + copyInPlace( cur, KQV_merged, eDataType::FP32, { n_state, n_ctx } ); + + // projection + if( gpuInfo.useReshapedMatMul() ) + { + const uint16_t tag = profiler.setNextTag( "enc.layer.4" ); + cur = reshapePanels( cur ); + profiler.setNextTag( tag ); + cur = mulMatTiledEx( layer.attnLn1.w, cur ); + } + else + { + profiler.setNextTag( "enc.layer.4" ); + cur = mulMat( layer.attnLn1.w, cur ); + } + addRepeat( cur, layer.attnLn1.b ); + + // add the input + addInPlace( cur, source ); + + // feed-forward network + Tensor inpFF = cur; + + cur = norm( inpFF ); + fmaRepeat( cur, layer.mlpLn ); + + // fully connected + if( gpuInfo.useReshapedMatMul() ) + { + const uint16_t tag = profiler.setNextTag( "enc.layer.5" ); + cur = reshapePanels( cur ); + profiler.setNextTag( tag ); + cur = mulMatTiledEx( layer.mlp0.w, cur ); + } + else + { + profiler.setNextTag( "enc.layer.5" ); + cur = mulMat( layer.mlp0.w, cur ); + } + addRepeatGelu( cur, layer.mlp0.b ); + + // projection + if( gpuInfo.useReshapedMatMul() ) + { + const uint16_t tag = profiler.setNextTag( "enc.layer.6" ); + cur = reshapePanels( cur ); + profiler.setNextTag( tag ); + cur = mulMatTiledEx( layer.mlp1.w, cur ); + } + else + { + profiler.setNextTag( "enc.layer.6" ); + cur = mulMat( layer.mlp1.w, cur ); + } + + addRepeat( cur, layer.mlp1.b ); + + // output from this layer + addInPlace( cur, inpFF ); + return cur; +} + +void WhisperContext::createKeyValueBuffers( const sEncodeParams& encParams ) +{ + { + const uint32_t n_audio_ctx = encParams.n_audio_ctx; + const uint32_t n_mem = encParams.n_text_layer * encParams.n_audio_ctx; + const uint32_t n_elements = encParams.n_text_state * 
n_mem; + kvCross.resize( n_elements ); + } + +#if BUILD_HYBRID_VERSION + if( !hybridContext ) +#endif + { + const uint32_t n_mem = encParams.n_text_layer * encParams.n_text_ctx; + const uint32_t n_elements = encParams.n_text_state * n_mem; + kv.resize( n_elements ); + } +} + +Tensor WhisperContext::encode( Whisper::iSpectrogram& spectrogram, const sEncodeParams& encParams ) +{ + auto prof = profiler.block( eProfilerBlock::Encode ); + CaptureRaii renderdocCapture; + profiler.profileShaders = profileEncodeShaders; + + createKeyValueBuffers( encParams ); + // Upload the source + check( melInput.create( spectrogram, encParams ) ); + Tracing::tensor( "enc.input", melInput ); + + arenas.enc.clear(); + ArenaRaii arenaRaii{ *this, arenas.enc }; + + // Initial few steps + Tensor cur = convolutionAndGelu( melInput, encParams.n_ctx ); + + // Process all these layers + { + const size_t layersCount = encParams.layersCount; + for( size_t i = 0; i < layersCount; i++ ) + { + Tracing::tensor( { "enc.layer[ %i ].in", i }, cur ); + cur = encodeLayer( cur, i, encParams.n_state, encParams.n_head, encParams.n_ctx ); + } + } + Tracing::tensor( "enc.layers", cur ); + + // A few last steps + { + cur = norm( cur ); + // cur = ln_f_g*cur + ln_f_b + fmaRepeat( cur, gpuModel.enc.lnPost ); + } + + // pre-compute cross-attention buffers + { + Tensor reshaped; + if( gpuInfo.useReshapedMatMul() ) + { + if( cur.ne[ 1 ] != 1 ) + { + profiler.setNextTag( "enc.cross" ); + reshaped = reshapePanels( cur ); + } + else + reshaped = cur; + } + + const size_t layersCount = encParams.n_text_layer; + const uint32_t stride = encParams.n_state * encParams.n_ctx; + const float finalScaling = (float)pow( float( encParams.n_state ) / encParams.n_head, -0.25 ); + for( size_t i = 0; i < layersCount; i++ ) + { + const LayerDecoder& layer = gpuModel.dec.layers[ i ]; + Tensor Kcross, Vcross; + if( gpuInfo.useReshapedMatMul() ) + Kcross = mulMatEx( layer.crossAttnKey, reshaped, "enc.cross.1" ); + else + { + 
profiler.setNextTag( "enc.cross.1" ); + Kcross = mulMat( layer.crossAttnKey, cur ); + } + scale( Kcross, finalScaling ); + + if( gpuInfo.useReshapedMatMul() ) + Vcross = mulMatEx( layer.crossAttnValue.w, reshaped, "enc.cross.2" ); + else + { + profiler.setNextTag( "enc.cross.2" ); + Vcross = mulMat( layer.crossAttnValue.w, cur ); + } + addRepeat( Vcross, layer.crossAttnValue.b ); + + Tensor k = kvCross.keys.view( stride, stride * (uint32_t)i ); + copyImpl( Kcross, k, Kcross.getType() == eDataType::FP32 ); + + Tensor v = kvCross.values.view( stride, stride * (uint32_t)i ); + copyImpl( Vcross, v, Vcross.getType() == eDataType::FP32 ); + } + } + +#if BUILD_HYBRID_VERSION + if( hybridContext ) + { + // When running hybrid model, download cross-attention buffers from VRAM to system RAM + check( hybridContext->downloadKeyValues( kvCross ) ); + } +#endif + return cur; +} + +struct WhisperContext::sLayerDecParams +{ + uint32_t n_state, n_head, N; + uint32_t n_ctx, n_past, M; +}; + +Tensor WhisperContext::decodeLayer( const Tensor& inpL, size_t il, const sLayerDecParams& ldp ) +{ + auto prof = profiler.block( eProfilerBlock::DecodeLayer ); + const auto& layer = gpuModel.dec.layers[ il ]; + std::optional<ArenaRaii> arenaRaii{ std::in_place, *this, arenas.decLayer }; + if( 0 == il ) Tracing::tensor( "dec-inpL", inpL ); + + // norm + Tensor cur = norm( inpL ); + fmaRepeat( cur, layer.attnLn0 ); + if( 0 == il ) Tracing::tensor( "dec-norm", cur ); + + // self-attention + { + profiler.setNextTag( "dec.layer.1" ); + Tensor Qcur = mulMat( layer.attnQuery.w, cur ); + if( 0 == il ) Tracing::tensor( "dec-Qcur-0", Qcur ); + const float scaling = (float)pow( float( (int)ldp.n_state ) / (int)ldp.n_head, -0.25 ); + addRepeatScale( Qcur, layer.attnQuery.b, scaling ); + if( 0 == il ) Tracing::tensor( "dec-Qcur-1", Qcur ); + + // note: no bias for Key + profiler.setNextTag( "dec.layer.2" ); + Tensor Kcur = mulMat( layer.attnKey, cur ); + scale( Kcur, scaling ); + if( 0 == il ) 
Tracing::tensor( "dec-Kcur", Kcur ); + + profiler.setNextTag( "dec.layer.3" ); + Tensor Vcur = mulMat( layer.attnValue.w, cur ); + addRepeat( Vcur, layer.attnValue.b ); + if( 0 == il ) Tracing::tensor( "dec-Vcur", Vcur ); + + // store key and value to memory + { + const uint32_t len = ldp.N * ldp.n_state; + const uint32_t off = ldp.n_state * ( (uint32_t)il * ldp.n_ctx + ldp.n_past ); + Tensor k = kv.keys.view( len, off ); + Tensor v = kv.values.view( len, off ); + copyImpl( Kcur, k, true ); + copyImpl( Vcur, v, true ); + } + + // ------ + Tensor Q = permute( copy( Qcur, eDataType::FP32, { ldp.n_state / ldp.n_head, ldp.n_head, ldp.N } ), 0, 2, 1, 3 ); + Tensor K = permute( kv.keys.view( ( ldp.n_past + ldp.N ) * ldp.n_state, (uint32_t)il * ldp.n_ctx * ldp.n_state ) + .reshape3d( ldp.n_state / ldp.n_head, ldp.n_head, ldp.n_past + ldp.N ), + 0, 2, 1, 3 ); + profiler.setNextTag( "dec.layer.4" ); + Tensor KQ = mulMat( K, Q ); + if( 0 == il ) Tracing::tensor( "dec-KQ-0", KQ ); + diagMaskInf( KQ, ldp.n_past ); + if( 0 == il ) Tracing::tensor( "dec-KQ-1", KQ ); + softMax( KQ ); + if( 0 == il ) Tracing::tensor( "dec-KQ-2", KQ ); + + Tensor V_trans = permute( + kv.values + .view( ( ldp.n_past + ldp.N ) * ldp.n_state, (uint32_t)il * ldp.n_ctx * ldp.n_state ) + .reshape3d( ldp.n_state / ldp.n_head, ldp.n_head, ldp.n_past + ldp.N ), + 1, 2, 0, 3 ); + + profiler.setNextTag( "dec.layer.5" ); + Tensor KQV = mulMat( V_trans, KQ ); + if( 0 == il ) Tracing::tensor( "dec-KQV", KQV ); + + Tensor KQV_merged = permute( KQV, 0, 2, 1, 3 ); + copyInPlace( cur, KQV_merged, eDataType::FP32, { ldp.n_state, ldp.N } ); + } + + { + profiler.setNextTag( "dec.layer.6" ); + cur = mulMat( layer.attnLn1.w, cur ); + addRepeat( cur, layer.attnLn1.b ); + } + + // add the input + Tensor inpCA = add( cur, inpL ); + + // norm + { + cur = norm( inpCA ); + fmaRepeat( cur, layer.crossAttnLn0 ); + } + + // cross-attention + { + profiler.setNextTag( "dec.layer.7" ); + Tensor Qcur = mulMat( layer.crossAttnQuery.w, 
cur ); + addRepeatScale( Qcur, layer.crossAttnQuery.b, (float)pow( float( (int)ldp.n_state ) / (int)ldp.n_head, -0.25 ) ); + + // Kcross is already scaled + const uint32_t len = ldp.M * ldp.n_state; + const uint32_t off = (uint32_t)il * len; + Tensor Kcross = kvCross.keys.view( len, off ).reshape3d( ldp.n_state / ldp.n_head, ldp.n_head, ldp.M ); + Tensor Vcross = kvCross.values.view( len, off ).reshape3d( ldp.n_state / ldp.n_head, ldp.n_head, ldp.M ); + + // ------ + Tensor Q = permute( copy( Qcur, eDataType::FP32, { ldp.n_state / ldp.n_head, ldp.n_head, ldp.N } ), 0, 2, 1, 3 ); + Tensor K = permute( Kcross, 0, 2, 1, 3 ); + profiler.setNextTag( "dec.layer.8" ); + Tensor KQ = mulMat( K, Q ); + softMax( KQ ); + Tensor V_trans = permute( Vcross, 1, 2, 0, 3 ); + profiler.setNextTag( "dec.layer.9" ); + Tensor KQV = mulMat( V_trans, KQ ); + if( 0 == il ) Tracing::tensor( "dec-KQV", KQV ); + Tensor KQV_merged = permute( KQV, 0, 2, 1, 3 ); + + copyInPlace( cur, KQV_merged, eDataType::FP32, { ldp.n_state, ldp.N } ); + } + + // projection + { + profiler.setNextTag( "dec.layer.10" ); + cur = mulMat( layer.crossAttnLn1.w, cur ); + addRepeat( cur, layer.crossAttnLn1.b ); + } + // add the input + addInPlace( cur, inpCA ); + Tensor inpFF = cur; + + // feed-forward network + { + // norm + cur = norm( inpFF ); + fmaRepeat( cur, layer.mlpLn ); + + if( gpuInfo.useReshapedMatMul() ) + cur = mulMatEx( layer.mlp0.w, cur, "dec.layer.11" ); + else + { + profiler.setNextTag( "dec.layer.11" ); + cur = mulMat( layer.mlp0.w, cur ); + } + + addRepeatGelu( cur, layer.mlp0.b ); + + // projection + if( gpuInfo.useReshapedMatMul() ) + { + if( cur.ne[ 1 ] != 1 ) + { + const uint16_t tag = profiler.setNextTag( "dec.layer.12" ); + cur = reshapePanels( cur ); + + // The mulMatTiledEx() line creates a layer output tensor. We have a special pool for such tensors so they survive the destruction of the arena. 
+ arenaRaii.emplace( *this, decPool ); + profiler.setNextTag( tag ); + cur = mulMatTiledEx( layer.mlp1.w, cur ); + } + else + { + // The mulMatByRowTiledEx() line creates a layer output tensor. We have a special pool for such tensors so they survive the destruction of the arena. + arenaRaii.emplace( *this, decPool ); + profiler.setNextTag( "dec.layer.12" ); + cur = mulMatByRowTiledEx( layer.mlp1.w, cur ); + } + } + else + { + // The mulMat() line creates a layer output tensor. We have a special pool for such tensors so they survive the destruction of the arena. + arenaRaii.emplace( *this, decPool ); + profiler.setNextTag( "dec.layer.12" ); + cur = mulMat( layer.mlp1.w, cur ); + } + addRepeat( cur, layer.mlp1.b ); + } + + // output from this layer + addInPlace( cur, inpFF ); + return cur; +} + +void WhisperContext::decode( const int* tokens, const int n_tokens, const sDecodeParams& decParams, std::vector<float>& probs, int threads ) +{ + auto cppp = profiler.cpuBlock( Whisper::eCpuBlock::DecodeStep ); + +#if BUILD_HYBRID_VERSION + if( hybridContext ) + { + HybridContext::sDecParams sdp; + sdp.n_threads = threads; + sdp.M = decParams.M; + check( hybridContext->decode( tokens, n_tokens, decParams.n_past, sdp, probs ) ); + return; + } +#endif + + auto prof = profiler.block( eProfilerBlock::DecodeStep ); + CaptureRaii renderdocCapture; + profiler.profileShaders = profileDecodeShaders; + ArenaRaii arenaRaii{ *this, arenas.dec }; + + assert( n_tokens > 0 ); + const uint32_t N = (uint32_t)n_tokens; + decoderInput.resize( N ); + + Tensor embd = decoderInput.embedding( tokens ); + Tensor cur = addRows( gpuModel.dec.tokenEmbedding, gpuModel.dec.positionalEmbedding, embd, decParams.n_past ); + Tracing::tensor( "dec-rows", cur ); + + { + sLayerDecParams ldp; + ldp.n_state = decParams.n_state; + ldp.n_head = decParams.n_head; + ldp.N = N; + ldp.n_ctx = decParams.n_ctx; + ldp.n_past = decParams.n_past; + ldp.M = decParams.M; +#if 1 + for( size_t i = 0; i < decParams.n_text_layer; 
i++ ) + cur = decodeLayer( cur, i, ldp ); +#else + dbgDecodeTest = decodeLayer( cur, 0, ldp ); + return; +#endif + } + + // norm + cur = norm( cur ); + fmaRepeat( cur, gpuModel.dec.ln ); + + profiler.setNextTag( "dec.logits" ); + cur = mulMat( gpuModel.dec.tokenEmbedding, cur ); + + // logits -> probs + softMax( cur ); + + decoderOutput.copyFromVram( cur ); + assert( decoderOutput.size() == N * decParams.n_vocab ); + + decoderOutput.copyToVector( probs ); + Tracing::vector( "probs", probs ); +} + +__m128i WhisperContext::Arenas::getMemoryUse() const +{ + __m128i res = enc.getMemoryUse(); + res = _mm_add_epi64( res, encLayer.getMemoryUse() ); + res = _mm_add_epi64( res, dec.getMemoryUse() ); + res = _mm_add_epi64( res, decLayer.getMemoryUse() ); + return res; +} + +__m128i WhisperContext::DecoderLayerPool::getMemoryUse() const +{ + size_t cb = result.getCapacity() * 4; + __m128i res = _mm_setzero_si128(); + return _mm_insert_epi64( res, (int64_t)cb, 1 ); +} + +__m128i WhisperContext::getMemoryUse() const +{ + __m128i res = MlContext::getMemoryUse(); + res = _mm_add_epi64( res, arenas.getMemoryUse() ); + res = _mm_add_epi64( res, decPool.getMemoryUse() ); + res = _mm_add_epi64( res, melInput.getMemoryUse() ); + res = _mm_add_epi64( res, kv.getMemoryUse() ); + res = _mm_add_epi64( res, kvCross.getMemoryUse() ); + res = _mm_add_epi64( res, decoderInput.getMemoryUse() ); + res = _mm_add_epi64( res, decoderOutput.getMemoryUse() ); + return res; +}
\ No newline at end of file diff --git a/Whisper/Whisper/WhisperContext.h b/Whisper/Whisper/WhisperContext.h new file mode 100644 index 0000000..227e6b6 --- /dev/null +++ b/Whisper/Whisper/WhisperContext.h @@ -0,0 +1,126 @@ +#pragma once +#include "../ML/MlContext.h" +#include "MelInputTensor.h" +#include "KeyValueBuffers.h" +#include "sEncodeParams.h" +#include "DecoderInputBuffers.h" +#include "DecoderResultBuffer.h" +#include "../ML/TensorsArena.h" +#include "iSpectrogram.h" +#include "../Hybrid/HybridContext.h" +#include <memory> +#include "WhisperModel.h" +#include <tuple> +#include <optional> + +namespace DirectCompute +{ + struct TensorPair; + struct ModelBuffers; + + class WhisperContext : public MlContext + { + struct Arenas + { + TensorsArena enc; + TensorsArena encLayer; + TensorsArena dec; + TensorsArena decLayer; + + Arenas(); + __m128i getMemoryUse() const; + }; + + iTensorArena* currentArena = nullptr; + Arenas arenas; + + // Specialized tensor arena for decoder layer outputs, with just a single tensor + class DecoderLayerPool : public iTensorArena + { + PooledTensor result; + public: + Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne ) override final; + void reset() override final { } + __m128i getMemoryUse() const; + }; + + DecoderLayerPool decPool; + + class ArenaRaii; + + MelInputTensor melInput; + KeyValueBuffers kv, kvCross; + DecoderInputBuffers decoderInput; + DecoderResultBuffer decoderOutput; + const ModelBuffers& gpuModel; +#if BUILD_HYBRID_VERSION + std::unique_ptr<HybridContext> hybridContext; +#endif + struct sWhisperMel + { + uint32_t n_len; + uint32_t n_mel; + const std::vector<float>& data; + }; + + void createKeyValueBuffers( const sEncodeParams& encParams ); + // Encoder methods + Tensor convolutionAndGelu( const Tensor& mel, uint32_t n_ctx ); + Tensor encodeLayer( const Tensor& source, size_t index, uint32_t n_state, uint32_t n_head, uint32_t n_ctx ); + + struct sLayerDecParams; + + // Decoder methods + Tensor 
decodeLayer( const Tensor& source, size_t index, const sLayerDecParams& ldp ); + + // cur = add( mul( repeat( that.w, cur ), cur ), repeat( that.b, cur ) ); + void fmaRepeat( Tensor& cur, const TensorPair& that ); + + Tensor createTensor( eDataType type, const std::array<uint32_t, 4>& ne ) override final; + + public: +#if BUILD_BOTH_VERSIONS + WhisperContext(); + ~WhisperContext(); +#else + ~WhisperContext() = default; +#endif + WhisperContext( const Whisper::WhisperModel& wm, Whisper::ProfileCollection& pc ); + WhisperContext( const WhisperContext& ) = delete; + + Tensor encode( Whisper::iSpectrogram& spectrogram, const sEncodeParams& encParams ); + + void decode( const int* tokens, const int n_tokens, const sDecodeParams& decParams, std::vector<float>& probs, int threads ); + + static WhisperContext& current(); + + // Create a RAII object which measures both CPU and GPU time for the complete runFull() method + decltype( auto ) completeProfiler() + { + return std::make_tuple( + profiler.cpuBlock( Whisper::eCpuBlock::Run ), + profiler.block( eProfilerBlock::Run ) ); + } + + // Create a RAII object which measures CPU and optionally GPU time for the loop which calls decode() method + decltype( auto ) decodeProfiler() + { +#if BUILD_HYBRID_VERSION + if( hybridContext ) + return std::make_tuple( + profiler.cpuBlock( Whisper::eCpuBlock::Decode ), + std::optional<GpuProfiler::BlockRaii>{} ); + else + return std::make_tuple( + profiler.cpuBlock( Whisper::eCpuBlock::Decode ), + std::optional<GpuProfiler::BlockRaii>{ std::in_place, profiler.block( eProfilerBlock::Decode ) } ); +#else + return std::make_tuple( + profiler.cpuBlock( Whisper::eCpuBlock::Decode ), + profiler.block( eProfilerBlock::Decode ) ); +#endif + } + + __m128i getMemoryUse() const; + }; +}
\ No newline at end of file diff --git a/Whisper/Whisper/WhisperModel.cpp b/Whisper/Whisper/WhisperModel.cpp new file mode 100644 index 0000000..28e4540 --- /dev/null +++ b/Whisper/Whisper/WhisperModel.cpp @@ -0,0 +1,511 @@ +#include "stdafx.h" +#include "WhisperModel.h" +#include "loaderUtils.h" +#include "../D3D/createBuffer.h" +#include <atlcoll.h> +#include <atlstr.h> +#include "../Utils/GpuProfilerSimple.h" +#include "../Utils/CpuProfiler.h" +#include "../CPU/HybridLoader.h" +#include "../ML/Reshaper.h" +using namespace Whisper; +using namespace DirectCompute; + +namespace +{ + struct ParamsAndMelHeader + { + sModelParams mp; + uint32_t n_mel = 0, n_fft = 0; + }; + + enum struct ePostProcessing : uint8_t + { + None = 0, + MakePanels = 1 + }; + + struct PendingTensor + { + DirectCompute::Tensor* dest = nullptr; + ePostProcessing postProcessing = ePostProcessing::None; + + PendingTensor() = default; + PendingTensor( const PendingTensor& ) = default; + PendingTensor( DirectCompute::Tensor& tensor, ePostProcessing pp = ePostProcessing::None ) : + dest( &tensor ), postProcessing( pp ) { } + +#if RESHAPED_MATRIX_MULTIPLY + // If you wonder why not reshape them after all tensors are loaded, doing that on the fly is faster because CPU and GPU work in parallel + // In the current version, CPU reads data for a next tensor, while in the meantime GPU reshapes a previously loaded tensor. 
+ HRESULT postProcess( Reshaper& rs, eDataType dt ) + { + switch( postProcessing ) + { + case ePostProcessing::None: + return S_OK; + case ePostProcessing::MakePanels: + if( gpuInfo.useReshapedMatMul() ) + { + // GpuInfo structure says we should use that new method + return rs.makePanels( *dest, dt ); + } + else + { + // The feature ain't enabled on the current user's GPU + return S_OK; + } + default: + return E_UNEXPECTED; + } + } +#endif + }; + + void populateEncodeTensorsMap( CAtlMap<CStringA, PendingTensor>& map, int layersEnc, DirectCompute::ModelBuffers& tensors ) + { + tensors.enc.layers.resize( layersEnc ); + + CStringA tempString; + // Encoder tensors + auto& enc = tensors.enc; + + map[ "encoder.positional_embedding" ] = enc.positionalEmbedding; + map[ "encoder.conv1.weight" ] = enc.conv1.w; + map[ "encoder.conv1.bias" ] = enc.conv1.b; + + map[ "encoder.conv2.weight" ] = enc.conv2.w; + map[ "encoder.conv2.bias" ] = enc.conv2.b; + + map[ "encoder.ln_post.weight" ] = enc.lnPost.w; + map[ "encoder.ln_post.bias" ] = enc.lnPost.b; + + auto add = [ & ]( const char* name, int i, DirectCompute::Tensor& t, ePostProcessing pp = ePostProcessing::None ) + { + tempString.Format( "encoder.blocks.%i.%s", i, name ); + map[ tempString ] = PendingTensor{ t, pp }; + }; + auto add2 = [ & ]( const char* name, int i, DirectCompute::TensorPair& t, ePostProcessing ppWeight = ePostProcessing::None, ePostProcessing ppBias = ePostProcessing::None ) + { + tempString.Format( "encoder.blocks.%i.%s.weight", i, name ); + map[ tempString ] = PendingTensor{ t.w, ppWeight }; + tempString.Format( "encoder.blocks.%i.%s.bias", i, name ); + map[ tempString ] = PendingTensor{ t.b, ppBias }; + }; + + for( int i = 0; i < layersEnc; i++ ) + { + auto& gpu = enc.layers[ i ]; + add2( "mlp_ln", i, gpu.mlpLn ); + add2( "mlp.0", i, gpu.mlp0, ePostProcessing::MakePanels ); + add2( "mlp.2", i, gpu.mlp1, ePostProcessing::MakePanels ); + add2( "attn_ln", i, gpu.attnLn0 ); + add2( "attn.query", i, 
gpu.attnQuery, ePostProcessing::MakePanels ); + add( "attn.key.weight", i, gpu.attnKey, ePostProcessing::MakePanels ); + add2( "attn.value", i, gpu.attnValue, ePostProcessing::MakePanels ); + add2( "attn.out", i, gpu.attnLn1, ePostProcessing::MakePanels ); + } + } + + void populateDecodeTensorsMap( CAtlMap<CStringA, PendingTensor>& map, int layersDec, DirectCompute::ModelBuffers& tensors, bool hybrid ) + { + tensors.dec.layers.resize( layersDec ); + CStringA tempString; + // Decoder tensors + + auto& dec = tensors.dec; + if( !hybrid ) + { + map[ "decoder.positional_embedding" ] = dec.positionalEmbedding; + map[ "decoder.token_embedding.weight" ] = dec.tokenEmbedding; + map[ "decoder.ln.weight" ] = dec.ln.w; + map[ "decoder.ln.bias" ] = dec.ln.b; + } + + auto add = [ & ]( const char* name, int i, DirectCompute::Tensor& t, ePostProcessing pp = ePostProcessing::None ) + { + tempString.Format( "decoder.blocks.%i.%s", i, name ); + map[ tempString ] = PendingTensor{ t, pp }; + }; + auto add2 = [ & ]( const char* name, int i, DirectCompute::TensorPair& t, ePostProcessing ppWeight = ePostProcessing::None, ePostProcessing ppBias = ePostProcessing::None ) + { + tempString.Format( "decoder.blocks.%i.%s.weight", i, name ); + map[ tempString ] = PendingTensor{ t.w, ppWeight }; + tempString.Format( "decoder.blocks.%i.%s.bias", i, name ); + map[ tempString ] = PendingTensor{ t.b, ppBias }; + }; + + for( int i = 0; i < layersDec; i++ ) + { + auto& gpu = dec.layers[ i ]; + add( "cross_attn.key.weight", i, gpu.crossAttnKey, ePostProcessing::MakePanels ); + add2( "cross_attn.value", i, gpu.crossAttnValue, ePostProcessing::MakePanels ); + if( hybrid ) + continue; + + add2( "mlp_ln", i, gpu.mlpLn ); + add2( "mlp.0", i, gpu.mlp0, ePostProcessing::MakePanels ); + add2( "mlp.2", i, gpu.mlp1, ePostProcessing::MakePanels ); + add2( "attn_ln", i, gpu.attnLn0 ); + add2( "attn.query", i, gpu.attnQuery ); + add( "attn.key.weight", i, gpu.attnKey ); + add2( "attn.value", i, gpu.attnValue ); + 
add2( "attn.out", i, gpu.attnLn1 ); + add2( "cross_attn_ln", i, gpu.crossAttnLn0 ); + add2( "cross_attn.query", i, gpu.crossAttnQuery ); + add2( "cross_attn.out", i, gpu.crossAttnLn1 ); + } + } + + void populateTensorsMap( CAtlMap<CStringA, PendingTensor>& map, int layersEnc, int layersDec, DirectCompute::ModelBuffers& tensors, bool hybrid ) + { + populateEncodeTensorsMap( map, layersEnc, tensors ); + populateDecodeTensorsMap( map, layersDec, tensors, hybrid ); + } + + struct sTensorHeader + { + int n_dims, length, ftype; + }; + + // compare signed int32 lanes for a <= b + inline __m128i cmple( __m128i a, __m128i b ) + { + __m128i i = _mm_min_epi32( a, b ); + return _mm_cmpeq_epi32( a, i ); + } + + inline bool allPositive( const std::array<int, 4>& ne ) + { + const __m128i v = _mm_loadu_si128( ( const __m128i* )ne.data() ); + const __m128i le = cmple( v, _mm_setzero_si128() ); + return (bool)_mm_testz_si128( le, le ); + } + + inline const char* cstr( const CStringA& s ) { return s; } +} + +class WhisperModel::CallbacksImpl : public CpuCompute::iLoaderProgressSink +{ + sLoadModelCallbacks lmcb; + int64_t fileSize; + + HRESULT gotBytes( int64_t cb ) override final + { + if( nullptr != lmcb.cancel ) + { + HRESULT hr = lmcb.cancel( lmcb.pv ); + CHECK( hr ); + if( S_OK != hr ) + return HRESULT_FROM_WIN32( ERROR_CANCELLED ); + } + + if( nullptr != lmcb.progress ) + { + postponedBytes -= cb; + assert( postponedBytes >= 0 ); + int64_t pos = fileSize - postponedBytes; + const double progressVal = (double)pos / (double)fileSize; + HRESULT hr = lmcb.progress( progressVal, lmcb.pv ); + CHECK( hr ); + } + return S_OK; + } +public: + int64_t postponedBytes; + + CallbacksImpl() + { + lmcb.progress = nullptr; + lmcb.cancel = nullptr; + lmcb.pv = nullptr; + fileSize = 0; + postponedBytes = 0; + } + + HRESULT initialize( ComLight::iReadStream* stm, const sLoadModelCallbacks* rsi ) + { + if( nullptr == rsi ) + return S_OK; + lmcb = *rsi; + if( nullptr != lmcb.progress ) + CHECK( 
stm->getLength( fileSize ) ); + return S_OK; + } + + HRESULT call( ComLight::iReadStream* stm ) + { + if( nullptr != lmcb.cancel ) + { + HRESULT hr = lmcb.cancel( lmcb.pv ); + CHECK( hr ); + if( S_OK != hr ) + return HRESULT_FROM_WIN32( ERROR_CANCELLED ); + } + + if( nullptr != lmcb.progress ) + { + int64_t pos; + CHECK( stm->getPosition( pos ) ); + pos -= postponedBytes; + const double progressVal = (double)pos / (double)fileSize; + HRESULT hr = lmcb.progress( progressVal, lmcb.pv ); + CHECK( hr ); + } + return S_OK; + } +}; + +HRESULT WhisperModel::loadGpu( ComLight::iReadStream* stm, CallbacksImpl& callbacks ) +{ + CAtlMap<CStringA, PendingTensor> map; + populateTensorsMap( map, parameters.n_audio_layer, parameters.n_text_layer, tensors, false ); + +#if RESHAPED_MATRIX_MULTIPLY + DirectCompute::Reshaper reshape; +#endif + + std::vector<uint8_t> bytesVector; + size_t countLoaded = 0; + CStringA name; + int64_t cb = 0; + while( true ) + { + CHECK( callbacks.call( stm ) ); + + sTensorHeader header; + HRESULT hr = readStruct( stm, header ); + if( hr == E_EOF ) + break; + if( FAILED( hr ) ) + return hr; + if( header.n_dims < 1 || header.n_dims>3 ) + return E_INVALIDARG; + + std::array<int, 4> ne = { 1, 1, 1, 1 }; + CHECK( readBytes( stm, ne.data(), header.n_dims * 4 ) ); + if( !allPositive( ne ) ) + return E_INVALIDARG; + + char* nameBuffer = name.GetBufferSetLength( header.length ); + hr = readBytes( stm, nameBuffer, header.length ); + name.ReleaseBuffer(); + if( FAILED( hr ) ) + return hr; + + auto p = map.Lookup( name ); + if( nullptr == p ) + { + logError( u8"%s: unknown tensor '%s' in model file", __func__, cstr( name ) ); + return E_INVALIDARG; + } + + DirectCompute::eDataType dt; + size_t cbElement; + if( header.ftype == 0 ) + { + dt = DirectCompute::eDataType::FP32; + cbElement = 4; + } + else + { + dt = DirectCompute::eDataType::FP16; + cbElement = 2; + } + + const size_t totalElts = (size_t)(uint32_t)ne[ 0 ] * (uint32_t)ne[ 1 ] * (uint32_t)ne[ 2 ]; + if( 
totalElts * cbElement > UINT_MAX ) + return DISP_E_OVERFLOW; + + try + { + bytesVector.resize( cbElement * totalElts ); + } + catch( const std::bad_alloc& ) + { + return E_OUTOFMEMORY; + } + CHECK( readBytes( stm, bytesVector.data(), bytesVector.size() ) ); + cb += bytesVector.size(); + CHECK( p->m_value.dest->createImmutable( dt, ne, bytesVector.data() ) ); +#if RESHAPED_MATRIX_MULTIPLY + CHECK( p->m_value.postProcess( reshape, dt ) ); +#endif + countLoaded++; + } + + if( countLoaded != map.GetCount() ) + { + logError( u8"Not all tensors loaded from model file - expected %zu, got %zu", map.GetCount(), countLoaded ); + return E_INVALIDARG; + } + + constexpr double mulMb = 1.0 / ( 1 << 20 ); + logDebug( u8"Loaded %zu GPU tensors, %g MB VRAM", countLoaded, mulMb * cb ); + return S_OK; +} + +#if BUILD_HYBRID_VERSION +HRESULT WhisperModel::loadHybrid( ComLight::iReadStream* stm, CallbacksImpl& callbacks ) +{ + CAtlMap<CStringA, PendingTensor> map; + populateTensorsMap( map, parameters.n_audio_layer, parameters.n_text_layer, tensors, true ); + +#if RESHAPED_MATRIX_MULTIPLY + DirectCompute::Reshaper reshape; +#endif + + CpuCompute::HybridLoader loader( hybridTensors, parameters.n_text_layer ); + + std::vector<uint8_t> bytesVector; + size_t countLoaded = 0; + CStringA name; + int64_t cb = 0; + while( true ) + { + CHECK( callbacks.call( stm ) ); + + sTensorHeader header; + HRESULT hr = readStruct( stm, header ); + if( hr == E_EOF ) + break; + if( FAILED( hr ) ) + return hr; + if( header.n_dims < 1 || header.n_dims > 3 ) + return E_INVALIDARG; + + std::array<int, 4> ne = { 1, 1, 1, 1 }; + CHECK( readBytes( stm, ne.data(), header.n_dims * 4 ) ); + if( !allPositive( ne ) ) + return E_INVALIDARG; + + char* nameBuffer = name.GetBufferSetLength( header.length ); + hr = readBytes( stm, nameBuffer, header.length ); + name.ReleaseBuffer(); + if( FAILED( hr ) ) + return hr; + + auto p = map.Lookup( name ); + if( nullptr == p ) + { + HRESULT hr = loader.setupTensor( name, 
header.n_dims, header.ftype, ne, stm, callbacks.postponedBytes ); + if( hr == S_OK ) + continue; + logError( u8"%s: unknown tensor '%s' in model file", __func__, cstr( name ) ); + return E_INVALIDARG; + } + + DirectCompute::eDataType dt; + size_t cbElement; + if( header.ftype == 0 ) + { + dt = DirectCompute::eDataType::FP32; + cbElement = 4; + } + else + { + dt = DirectCompute::eDataType::FP16; + cbElement = 2; + } + + const size_t totalElts = (size_t)(uint32_t)ne[ 0 ] * (uint32_t)ne[ 1 ] * (uint32_t)ne[ 2 ]; + if( totalElts * cbElement > UINT_MAX ) + return DISP_E_OVERFLOW; + + try + { + bytesVector.resize( cbElement * totalElts ); + } + catch( const std::bad_alloc& ) + { + return E_OUTOFMEMORY; + } + CHECK( readBytes( stm, bytesVector.data(), bytesVector.size() ) ); + CHECK( p->m_value.dest->createImmutable( dt, ne, bytesVector.data() ) ); +#if RESHAPED_MATRIX_MULTIPLY + CHECK( p->m_value.postProcess( reshape, dt ) ); +#endif + countLoaded++; + cb += bytesVector.size(); + } + + if( countLoaded != map.GetCount() ) + { + logError( u8"Not all tensors loaded from model file - expected %zu, got %zu", map.GetCount(), countLoaded ); + return E_INVALIDARG; + } + + constexpr double mulMb = 1.0 / ( 1 << 20 ); + logDebug( u8"Loaded %zu GPU tensors, %g MB VRAM", countLoaded, mulMb * cb ); + + CHECK( loader.completeLoad( stm, callbacks ) ); + return S_OK; +} +#endif + +HRESULT WhisperModel::load( ComLight::iReadStream* stm, bool hybrid, const sLoadModelCallbacks* callbacks ) +{ + CpuProfiler cpuPerf; + CallbacksImpl cb; + CHECK( cb.initialize( stm, callbacks ) ); + // verify magic + { + uint32_t magic; + CHECK( readStruct( stm, magic ) ); + if( magic != 0x67676d6c ) + { + logError( u8"Invalid model file, bad magic" ); + return E_INVALIDARG; + } + } + + // hparams and MEL filters + { + ParamsAndMelHeader pmh; + CHECK( readStruct( stm, pmh ) ); + parameters = pmh.mp; + assert( parameters.n_text_state == parameters.n_audio_state ); + + filters.n_mel = pmh.n_mel; + filters.n_fft 
= pmh.n_fft; + const size_t len = (size_t)filters.n_mel * filters.n_fft; + filters.data.resize( len ); + CHECK( readBytes( stm, filters.data.data(), len * 4 ) ); + + const int64_t cb = len * 4; + constexpr double mulKb = 1.0 / ( 1 << 10 ); + logDebug( u8"Loaded MEL filters, %.1f kb RAM", mulKb * cb ); + } + CHECK( cb.call( stm ) ); + + // Vocabulary + CHECK( vocab.load( stm, parameters.n_vocab ) ); + CHECK( cb.call( stm ) ); + + DirectCompute::GpuProfilerSimple gpuProfiler; + CHECK( gpuProfiler.create() ); + + if( hybrid ) + { +#if BUILD_HYBRID_VERSION + CHECK( loadHybrid( stm, cb ) ) +#else + return E_NOTIMPL; +#endif + } + else + CHECK( loadGpu( stm, cb ) ); + + CHECK( gpuProfiler.time( loadTimeGpu ) ); + loadTimeCpu = cpuPerf.elapsed(); + return S_OK; +} + +__m128i Whisper::WhisperModel::getMemoryUse() const +{ + size_t cb = vocab.getMemoryUse(); + cb += vectorMemoryUse( filters.data ); + __m128i v = _mm_cvtsi64_si128( (int64_t)cb ); + v = _mm_add_epi64( v, tensors.getMemoryUse() ); + return v; +}
// Whisper/Whisper/WhisperModel.h
#pragma once
#include "Vocabulary.h"
#include "ModelBuffers.h"
#include "../../ComLightLib/streams.h"
#include "../CPU/DecoderTensors.h"
#include "../API/sLoadModelCallbacks.h"
#include "sModelParams.h"

namespace Whisper
{
	// MEL filter bank read from the model file.
	// WhisperModel::load() resizes `data` to n_mel * n_fft floats and fills it from the stream.
	struct Filters
	{
		uint32_t n_mel;
		uint32_t n_fft;
		std::vector<float> data;
	};

	// The complete model, as loaded from a GGML binary file.
	// The entire model is immutable, and can be safely used from multiple threads in parallel.
	// The tensors are uploaded to VRAM and don’t stay in system memory, everything else is in the system RAM.
	struct WhisperModel
	{
		// Hyper-parameters, copied from the header of the model file by load()
		sModelParams parameters;
		Vocabulary vocab;
		Filters filters;
		// Tensors of the model, created by the GPU loader
		DirectCompute::ModelBuffers tensors;

#if BUILD_HYBRID_VERSION
		// Decoder tensors for the hybrid mode; only present in builds with BUILD_HYBRID_VERSION
		CpuCompute::DecoderTensors hybridTensors;
#endif

		// Load the model from the stream: magic number, hyper-parameters, MEL filters, vocabulary, then the tensors.
		// `hybrid` selects the hybrid loader; `callbacks` is optional and may be nullptr.
		HRESULT load( ComLight::iReadStream* stm, bool hybrid, const sLoadModelCallbacks* callbacks );

		// A vector of 2 uint64_t values, both numbers are 100 nanosecond ticks:
		// 0. The time it took to load the model, measured on CPU
		// 1. The time it took to upload all these tensors to VRAM, measured on GPU
		__m128i getLoadTimes() const
		{
			// The single 16-byte load below is only correct when the two fields are adjacent, in this order
			static_assert( offsetof( WhisperModel, loadTimeCpu ) + 8 == offsetof( WhisperModel, loadTimeGpu ) );
			return _mm_loadu_si128( ( const __m128i* )( &loadTimeCpu ) );
		}

		// A vector of 2 64-bit byte counts: the low lane includes the vocabulary and the MEL filters,
		// and the result of tensors.getMemoryUse() is added to both lanes.
		__m128i getMemoryUse() const;

	private:
		// Keep these two fields adjacent and in this order, getLoadTimes() relies on that
		uint64_t loadTimeCpu = 0;
		uint64_t loadTimeGpu = 0;

		// Adapter for the optional progress and cancellation callbacks, defined in the .cpp file
		class CallbacksImpl;

		HRESULT loadGpu( ComLight::iReadStream* stm, CallbacksImpl& callbacks );
		// NOTE(review): declared unconditionally, but the definition is guarded by BUILD_HYBRID_VERSION
		HRESULT loadHybrid( ComLight::iReadStream* stm, CallbacksImpl& callbacks );
	};
}
// Whisper/Whisper/audioConstants.h
#pragma once
#include <stdint.h>

// Compile-time constants of the audio front end
namespace Whisper
{
	// WHISPER_SAMPLE_RATE: the input PCM is sampled at 16 kHz
	constexpr uint32_t SAMPLE_RATE = 16 * 1000;

	// WHISPER_N_FFT: FFT window of 25 milliseconds, in samples
	constexpr uint32_t FFT_SIZE = SAMPLE_RATE / 1000 * 25;

	// WHISPER_HOP_LENGTH: distance between consecutive windows, 10 milliseconds, in samples
	constexpr uint32_t FFT_STEP = SAMPLE_RATE / 1000 * 10;

	// WHISPER_N_MEL: count of MEL bands
	constexpr uint32_t N_MEL = 80;

	// The rest of the code depends on these exact numbers
	static_assert( SAMPLE_RATE == 16000, "unexpected sample rate" );
	static_assert( FFT_SIZE == 400, "unexpected FFT window size" );
	static_assert( FFT_STEP == 160, "unexpected FFT hop length" );
}
\ No newline at end of file diff --git a/Whisper/Whisper/iSpectrogram.h b/Whisper/Whisper/iSpectrogram.h new file mode 100644 index 0000000..2e3199d --- /dev/null +++ b/Whisper/Whisper/iSpectrogram.h @@ -0,0 +1,38 @@ +#pragma once +#include "audioConstants.h" + +namespace Whisper +{ + __interface iSpectrogram + { + // Make a buffer with length * N_MEL floats, starting at the specified offset + // An implementation of this interface may visualize the spectrogram, making pieces on demand + HRESULT makeBuffer( size_t offset, size_t length, const float** buffer, size_t& stride ); + + // Apparently, the length unit is 160 input samples = 10 milliseconds of audio + size_t getLength() const; + }; + + // RAII class to deal with iSpectrogram's makeBuffer method. + // Throws exceptions when things fail. + class MelBufferRaii + { + const float* pointer; + size_t stride; + public: + + HRESULT make( iSpectrogram& mel, size_t off, size_t len ) + { + return mel.makeBuffer( off, len, &pointer, stride ); + } + + const float* operator[]( size_t idx ) const + { + assert( idx < N_MEL ); + return pointer + idx * stride; + } + + const BYTE* bytePtr() const { return (const BYTE*)pointer; } + LONG strideBytes() const { return (LONG)stride * 4; } + }; +}
\ No newline at end of file diff --git a/Whisper/Whisper/languageCodez.inl b/Whisper/Whisper/languageCodez.inl new file mode 100644 index 0000000..c302769 --- /dev/null +++ b/Whisper/Whisper/languageCodez.inl @@ -0,0 +1,100 @@ +// This file is generated by a tool, from the `languageCodez.tsv` file in this repository +Lang{ 0x6661, 68, "afrikaans" }, +Lang{ 0x7173, 58, "albanian" }, +Lang{ 0x6D61, 75, "amharic" }, +Lang{ 0x7261, 13, "arabic" }, +Lang{ 0x7968, 53, "armenian" }, +Lang{ 0x7361, 91, "assamese" }, +Lang{ 0x7A61, 45, "azerbaijani" }, +Lang{ 0x6162, 96, "bashkir" }, +Lang{ 0x7565, 51, "basque" }, +Lang{ 0x6562, 71, "belarusian" }, +Lang{ 0x6E62, 43, "bengali" }, +Lang{ 0x7362, 56, "bosnian" }, +Lang{ 0x7262, 50, "breton" }, +Lang{ 0x6762, 33, "bulgarian" }, +Lang{ 0x6163, 11, "catalan" }, +Lang{ 0x687A, 1, "chinese" }, +Lang{ 0x7268, 32, "croatian" }, +Lang{ 0x7363, 24, "czech" }, +Lang{ 0x6164, 26, "danish" }, +Lang{ 0x6C6E, 12, "dutch" }, +Lang{ 0x6E65, 0, "english" }, +Lang{ 0x7465, 48, "estonian" }, +Lang{ 0x6F66, 79, "faroese" }, +Lang{ 0x6966, 18, "finnish" }, +Lang{ 0x7266, 6, "french" }, +Lang{ 0x6C67, 60, "galician" }, +Lang{ 0x616B, 70, "georgian" }, +Lang{ 0x6564, 2, "german" }, +Lang{ 0x6C65, 22, "greek" }, +Lang{ 0x7567, 74, "gujarati" }, +Lang{ 0x7468, 80, "haitian creole" }, +Lang{ 0x6168, 95, "hausa" }, +Lang{ 0x776168, 93, "hawaiian" }, +Lang{ 0x7769, 20, "hebrew" }, +Lang{ 0x6968, 17, "hindi" }, +Lang{ 0x7568, 27, "hungarian" }, +Lang{ 0x7369, 52, "icelandic" }, +Lang{ 0x6469, 16, "indonesian" }, +Lang{ 0x7469, 15, "italian" }, +Lang{ 0x616A, 7, "japanese" }, +Lang{ 0x776A, 97, "javanese" }, +Lang{ 0x6E6B, 47, "kannada" }, +Lang{ 0x6B6B, 57, "kazakh" }, +Lang{ 0x6D6B, 64, "khmer" }, +Lang{ 0x6F6B, 5, "korean" }, +Lang{ 0x6F6C, 77, "lao" }, +Lang{ 0x616C, 35, "latin" }, +Lang{ 0x766C, 42, "latvian" }, +Lang{ 0x6E6C, 94, "lingala" }, +Lang{ 0x746C, 34, "lithuanian" }, +Lang{ 0x626C, 86, "luxembourgish" }, +Lang{ 0x6B6D, 49, "macedonian" }, 
+Lang{ 0x676D, 90, "malagasy" }, +Lang{ 0x736D, 23, "malay" }, +Lang{ 0x6C6D, 37, "malayalam" }, +Lang{ 0x746D, 84, "maltese" }, +Lang{ 0x696D, 36, "maori" }, +Lang{ 0x726D, 61, "marathi" }, +Lang{ 0x6E6D, 55, "mongolian" }, +Lang{ 0x796D, 87, "myanmar" }, +Lang{ 0x656E, 54, "nepali" }, +Lang{ 0x6F6E, 29, "norwegian" }, +Lang{ 0x6E6E, 83, "nynorsk" }, +Lang{ 0x636F, 69, "occitan" }, +Lang{ 0x7370, 81, "pashto" }, +Lang{ 0x6166, 41, "persian" }, +Lang{ 0x6C70, 10, "polish" }, +Lang{ 0x7470, 8, "portuguese" }, +Lang{ 0x6170, 62, "punjabi" }, +Lang{ 0x6F72, 25, "romanian" }, +Lang{ 0x7572, 4, "russian" }, +Lang{ 0x6173, 85, "sanskrit" }, +Lang{ 0x7273, 44, "serbian" }, +Lang{ 0x6E73, 65, "shona" }, +Lang{ 0x6473, 73, "sindhi" }, +Lang{ 0x6973, 63, "sinhala" }, +Lang{ 0x6B73, 39, "slovak" }, +Lang{ 0x6C73, 46, "slovenian" }, +Lang{ 0x6F73, 67, "somali" }, +Lang{ 0x7365, 3, "spanish" }, +Lang{ 0x7573, 98, "sundanese" }, +Lang{ 0x7773, 59, "swahili" }, +Lang{ 0x7673, 14, "swedish" }, +Lang{ 0x6C74, 89, "tagalog" }, +Lang{ 0x6774, 72, "tajik" }, +Lang{ 0x6174, 28, "tamil" }, +Lang{ 0x7474, 92, "tatar" }, +Lang{ 0x6574, 40, "telugu" }, +Lang{ 0x6874, 30, "thai" }, +Lang{ 0x6F62, 88, "tibetan" }, +Lang{ 0x7274, 9, "turkish" }, +Lang{ 0x6B74, 82, "turkmen" }, +Lang{ 0x6B75, 21, "ukrainian" }, +Lang{ 0x7275, 31, "urdu" }, +Lang{ 0x7A75, 78, "uzbek" }, +Lang{ 0x6976, 19, "vietnamese" }, +Lang{ 0x7963, 38, "welsh" }, +Lang{ 0x6979, 76, "yiddish" }, +Lang{ 0x6F79, 66, "yoruba" }, diff --git a/Whisper/Whisper/languageCodez.tsv b/Whisper/Whisper/languageCodez.tsv new file mode 100644 index 0000000..4048618 --- /dev/null +++ b/Whisper/Whisper/languageCodez.tsv @@ -0,0 +1,99 @@ +en 0 english +zh 1 chinese +de 2 german +es 3 spanish +ru 4 russian +ko 5 korean +fr 6 french +ja 7 japanese +pt 8 portuguese +tr 9 turkish +pl 10 polish +ca 11 catalan +nl 12 dutch +ar 13 arabic +sv 14 swedish +it 15 italian +id 16 indonesian +hi 17 hindi +fi 18 finnish +vi 19 vietnamese +iw 20 hebrew +uk 
21 ukrainian +el 22 greek +ms 23 malay +cs 24 czech +ro 25 romanian +da 26 danish +hu 27 hungarian +ta 28 tamil +no 29 norwegian +th 30 thai +ur 31 urdu +hr 32 croatian +bg 33 bulgarian +lt 34 lithuanian +la 35 latin +mi 36 maori +ml 37 malayalam +cy 38 welsh +sk 39 slovak +te 40 telugu +fa 41 persian +lv 42 latvian +bn 43 bengali +sr 44 serbian +az 45 azerbaijani +sl 46 slovenian +kn 47 kannada +et 48 estonian +mk 49 macedonian +br 50 breton +eu 51 basque +is 52 icelandic +hy 53 armenian +ne 54 nepali +mn 55 mongolian +bs 56 bosnian +kk 57 kazakh +sq 58 albanian +sw 59 swahili +gl 60 galician +mr 61 marathi +pa 62 punjabi +si 63 sinhala +km 64 khmer +sn 65 shona +yo 66 yoruba +so 67 somali +af 68 afrikaans +oc 69 occitan +ka 70 georgian +be 71 belarusian +tg 72 tajik +sd 73 sindhi +gu 74 gujarati +am 75 amharic +yi 76 yiddish +lo 77 lao +uz 78 uzbek +fo 79 faroese +ht 80 haitian creole +ps 81 pashto +tk 82 turkmen +nn 83 nynorsk +mt 84 maltese +sa 85 sanskrit +lb 86 luxembourgish +my 87 myanmar +bo 88 tibetan +tl 89 tagalog +mg 90 malagasy +as 91 assamese +tt 92 tatar +haw 93 hawaiian +ln 94 lingala +ha 95 hausa +ba 96 bashkir +jw 97 javanese +su 98 sundanese
\ No newline at end of file diff --git a/Whisper/Whisper/loaderUtils.h b/Whisper/Whisper/loaderUtils.h new file mode 100644 index 0000000..650556d --- /dev/null +++ b/Whisper/Whisper/loaderUtils.h @@ -0,0 +1,24 @@ +#pragma once +#include "../../ComLightLib/streams.h" + +namespace Whisper +{ + inline HRESULT readBytes( ComLight::iReadStream* stm, void* rdi, size_t cb ) + { + if( cb > INT_MAX ) + return DISP_E_OVERFLOW; + if( cb == 0 ) + return S_FALSE; + int n; + CHECK( stm->read( rdi, (int)cb, n ) ); + if( n != (int)cb ) + return E_EOF; + return S_OK; + } + + template<typename T> + inline HRESULT readStruct( ComLight::iReadStream* stm, T& dest ) + { + return readBytes( stm, &dest, sizeof( T ) ); + } +}
// Whisper/Whisper/melSpectrogram.cpp
#include "stdafx.h"
#include <cmath>
#include "melSpectrogram.h"

namespace Whisper
{
	// Precompute the Hann window: hann[ i ] = 0.5 * ( 1 - cos( 2 * pi * i / FFT_SIZE ) )
	HanningWindow::HanningWindow()
	{
		for( int i = 0; i < FFT_SIZE; i++ )
		{
			// TODO [low]: use XMVectorCos instead
			hann[ i ] = (float)( 0.5 * ( 1.0 - std::cos( ( 2.0 * M_PI * i ) / ( FFT_SIZE ) ) ) );
		}
	}
	// The shared instance of the window, declared `extern` in the header
	const HanningWindow s_hanning;
}

namespace
{
	using namespace Whisper;

	// Count of floats consumed from the temporary buffer by SpectrogramContext::fftRecursion for the input of `len` elements.
	// Mirrors the structure of that function; only used to verify the hardcoded tempBufferSize.
	uint32_t tempVectorSizeRecursion( uint32_t len )
	{
		// out.resize( in.size() * 2 );
		const uint32_t res = len * 2;
		if( len == 1 )
			return res;
		if( len % 2 == 1 )
			return res; // dft

		const uint32_t even = ( len + 1 ) / 2;
		const uint32_t odd = len / 2;
		const uint32_t evenFft = tempVectorSizeRecursion( even );
		const uint32_t oddFft = tempVectorSizeRecursion( odd );
		return res + even + odd + evenFft + oddFft;
	}

	// 6000
	// const uint32_t tempBufferSize = FFT_SIZE + tempVectorSizeRecursion( FFT_SIZE );
	// For FFT_SIZE = 400 the recursion needs 5600 floats, plus 400 floats for the windowed input;
	// the constructor of SpectrogramContext asserts the hardcoded number stays in sync
	constexpr uint32_t tempBufferSize = 6000;

	// naive Discrete Fourier Transform
	// input is real-valued
	// output is complex-valued
	// `rdi` receives len interleaved pairs [ re, im ], i.e. len * 2 floats
	inline void dft( const float* rsi, size_t len, float* rdi )
	{
		for( size_t k = 0; k < len; k++ )
		{
			float re = 0;
			float im = 0;

			for( int n = 0; n < len; n++ )
			{
				float angle = (float)( 2 * M_PI * (int)k * n / len );
				re += (float)( rsi[ n ] * std::cosf( angle ) );
				im -= (float)( rsi[ n ] * std::sinf( angle ) );
			}

			rdi[ k * 2 + 0 ] = re;
			rdi[ k * 2 + 1 ] = im;
		}
	}

	// De-interleave the input: elements at even indices go to rdiEven, elements at odd indices go to rdiOdd
	inline void splitEvenOdd( const float* rsi, size_t len, float* rdiEven, float* rdiOdd )
	{
		// The vectorized loop handles complete batches of 8 floats, the scalar loop below handles the remainder
		const float* const rsiEndAligned = rsi + ( len & ( ~(size_t)7 ) );
		const size_t rem = len % 8;

		for( ; rsi < rsiEndAligned; rsi += 8, rdiEven += 4, rdiOdd += 4 )
		{
			const __m128 v1 = _mm_loadu_ps( rsi );
			const __m128 v2 = _mm_loadu_ps( rsi + 4 );
			// Lanes 0 and 2 of both vectors
			const __m128 e = _mm_shuffle_ps( v1, v2, _MM_SHUFFLE( 2, 0, 2, 0 ) );
			// Lanes 1 and 3 of both vectors
			const __m128 o = _mm_shuffle_ps( v1, v2, _MM_SHUFFLE( 3, 1, 3, 1 ) );
			_mm_storeu_ps( rdiEven, e );
			_mm_storeu_ps( rdiOdd, o );
		}

#pragma loop( no_vector )
		for( size_t i = 0; i < rem; i++, rsi++ )
		{
			// The vectorized part consumed a multiple of 8 elements, so the parity of `i` equals the parity of the source index
			if( i % 2 == 0 )
			{
				*rdiEven = *rsi;
				rdiEven++;
			}
			else
			{
				*rdiOdd = *rsi;
				rdiOdd++;
			}
		}
	}
	// [ f, f, 0, 0 ]
	inline __m128 set2( float f )
	{
		__m128 v = _mm_set_ss( f );
		return _mm_moveldup_ps( v );
	}
	// Load 2 floats into the lower half of the vector, the upper half is zero
	inline __m128 load2( const float* rsi )
	{
		return _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) );
	}
	// Store the lower 2 floats of the vector
	inline void store2( float* rdi, __m128 vec )
	{
		_mm_store_sd( (double*)rdi, _mm_castps_pd( vec ) );
	}
	// [ x, y ] => [ x, y, x, y ]
	inline __m128 dup2( __m128 x )
	{
		__m128d v = _mm_castps_pd( x );
		v = _mm_movedup_pd( v );
		return _mm_castpd_ps( v );
	}
	// Load 2 floats, and duplicate them into both halves of the vector
	inline __m128 load2dup( const float* rsi )
	{
		return _mm_castpd_ps( _mm_loaddup_pd( (const double*)rsi ) );
	}
	// Store the upper 2 floats of the vector
	inline void store2high( float* rdi, __m128 vec )
	{
		_mm_storeh_pd( (double*)rdi, _mm_castps_pd( vec ) );
	}
}

using namespace Whisper;

// Keep the reference to the filters, and allocate the temporary buffer for the FFT
SpectrogramContext::SpectrogramContext( const Filters& flt ) :
	filters( flt )
{
	assert( tempBufferSize == FFT_SIZE + tempVectorSizeRecursion( FFT_SIZE ) );
	tempBuffer = std::make_unique<float[]>( tempBufferSize );
}

// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
// The output is written to the first len * 2 floats at `temp`, as interleaved [ re, im ] pairs.
// The memory after that is used for the intermediate data of the recursion.
// Returns the pointer to the first float of the temporary buffer which wasn't used by the call.
float* SpectrogramContext::fftRecursion( float* temp, const float* const rsi, const size_t len )
{
	float* const out = temp;
	temp += len * 2;
	if( len == 1 )
	{
		out[ 0 ] = rsi[ 0 ];
		out[ 1 ] = 0;
		return temp;
	}

	// Odd lengths can't be split in halves, fall back to the naive DFT
	if( len % 2 == 1 )
	{
		dft( rsi, len, out );
		return temp;
	}

	// Carve the buffers for the even and odd elements of the input from the temporary memory
	const size_t lenEven = ( len + 1 ) / 2;
	const size_t lenOdd = len / 2;
	float* const even = temp;
	temp += lenEven;

	float* const odd = temp;
	temp += lenOdd;
	splitEvenOdd( rsi, len, even, odd );

	const float* const evenFft = temp;
	temp = fftRecursion( temp, even, lenEven );

	const float* const oddFft = temp;
	temp = fftRecursion( temp, odd, lenOdd );

	// Combine the two half-size transforms; each iteration produces the output elements k and k + N / 2
	const size_t N = len;
	const __m128 maskNegateHigh = _mm_setr_ps( 0, 0, -0.0f, -0.0f );
	for( size_t k = 0; k < N / 2; k++ )
	{
		const float theta = (float)( 2 * M_PI * (double)(int)k / N );

		/*
		const float re = std::cosf( theta );
		const float im = -std::sinf( theta );

		float re_odd = oddFft[ 2 * k + 0 ];
		float im_odd = oddFft[ 2 * k + 1 ];

		out[ 2 * k + 0 ] = evenFft[ 2 * k + 0 ] + re * re_odd - im * im_odd;
		out[ 2 * k + 1 ] = evenFft[ 2 * k + 1 ] + re * im_odd + im * re_odd;

		out[ 2 * ( k + N / 2 ) + 0 ] = evenFft[ 2 * k + 0 ] - re * re_odd + im * im_odd;
		out[ 2 * ( k + N / 2 ) + 1 ] = evenFft[ 2 * k + 1 ] - re * im_odd - im * re_odd;
		*/

		// The vectorized code below implements the scalar formulas from the comment above.
		// Here `im` is loaded as +sin( theta ), the XOR with the mask flips the sign to match `im` of the scalar version
		const __m128 re = _mm_set_ss( std::cosf( theta ) );
		const __m128 im = _mm_set_ss( std::sinf( theta ) );
		__m128 reIm = _mm_shuffle_ps( re, im, _MM_SHUFFLE( 0, 0, 0, 0 ) );
		// [ re, re, im, im ]
		reIm = _mm_xor_ps( reIm, maskNegateHigh );

		// [ re_odd, im_odd ]
		__m128 odd = load2( oddFft + 2 * k );
		// [ re_odd, im_odd, im_odd, re_odd ]
		odd = _mm_shuffle_ps( odd, odd, _MM_SHUFFLE( 0, 1, 1, 0 ) );

		// re_odd * re, im_odd * re, im_odd * im, re_odd * im ]
		const __m128 products4 = _mm_mul_ps( reIm, odd );

		// re_odd * re, im_odd * re, re_odd * re, im_odd * re
		__m128 prod1 = dup2( products4 );
		// im_odd * im, re_odd * im, im_odd * im, re_odd * im
		__m128 prod2 = _mm_movehl_ps( products4, products4 );

		// re_odd * re, im_odd * re, -re_odd * re, -im_odd * re
		prod1 = _mm_xor_ps( prod1, maskNegateHigh );
		// im_odd * im, re_odd * im, -im_odd * im, -re_odd * im
		prod2 = _mm_xor_ps( prod2, maskNegateHigh );

		const __m128 even = load2dup( evenFft + 2 * k );
		__m128 res;
		res = _mm_add_ps( even, prod1 );
		// addsub subtracts in the even lanes and adds in the odd lanes
		res = _mm_addsub_ps( res, prod2 );
		// The lower half of the vector is the output element k, the upper half is the element k + N / 2
		store2( out + 2 * k, res );
		store2high( out + 2 * ( k + N / 2 ), res );
	}

	return temp;
}

// Compute one column of the MEL spectrogram: N_MEL values in `rdi`, from up to FFT_SIZE samples at `pcm`.
// Shorter inputs are padded with zeros, longer inputs are truncated to FFT_SIZE samples.
void SpectrogramContext::fft( std::array<float, N_MEL>& rdi, const float* pcm, size_t length )
{
	assert( length > 0 );
	length = std::min( length, (size_t)FFT_SIZE );

	float* const temp = tempBuffer.get();
	// Apply Hanning window
	for( size_t i = 0; i < length; i++ )
		temp[ i ] = pcm[ i ] * s_hanning[ i ];
	if( length < FFT_SIZE )
		memset( temp + length, 0, ( FFT_SIZE - length ) * 4 );

	// The first FFT_SIZE floats of the buffer hold the windowed input, the rest is for the FFT
	float* const fftOut = temp + FFT_SIZE;
	float* bufferEnd = fftRecursion( fftOut, temp, FFT_SIZE );
	assert( bufferEnd == tempBuffer.get() + tempBufferSize );

	// Replace the complex numbers with their squared magnitudes, in place.
	// Element j is computed from the floats [ 2 * j, 2 * j + 1 ], so the output never overtakes the input.
	// The first 4 elements are processed one by one: with wider stores, the output would overwrite input which wasn't consumed yet.
	// for( size_t j = 0; j < FFT_SIZE; j++ )
	//	fft_out[ j ] = ( fft_out[ 2 * j + 0 ] * fft_out[ 2 * j + 0 ] + fft_out[ 2 * j + 1 ] * fft_out[ 2 * j + 1 ] );
	for( size_t j = 0; j < 4; j++ )
	{
		__m128 tmp = load2( fftOut + 2 * j );
		tmp = _mm_mul_ps( tmp, tmp );
		tmp = _mm_add_ss( tmp, _mm_movehdup_ps( tmp ) );
		_mm_store_ss( fftOut + j, tmp );
	}
	for( size_t j = 4; j < FFT_SIZE; j += 4 )
	{
		// 4 complex numbers in, 4 squared magnitudes out
		__m128 low = _mm_loadu_ps( fftOut + 2 * j );
		__m128 high = _mm_loadu_ps( fftOut + 2 * j + 4 );
		low = _mm_mul_ps( low, low );
		high = _mm_mul_ps( high, high );
		__m128 res = _mm_hadd_ps( low, high );
		_mm_storeu_ps( fftOut + j, res );
	}

	// Fold the upper half of the spectrum onto the lower half
	// for( size_t j = 1; j < FFT_SIZE / 2; j++ )
	//	fftOut[ j ] += fftOut[ FFT_SIZE - j ];
	for( size_t j = 1; j < 4; j++ )
		fftOut[ j ] += fftOut[ FFT_SIZE - j ];
	for( size_t j = 4; j < FFT_SIZE / 2; j += 4 )
	{
		__m128 curr = _mm_loadu_ps( fftOut + j );
		// Too bad _mm_loadr_ps requires alignment
		// Load the elements [ FFT_SIZE - j - 3 .. FFT_SIZE - j ], and reverse the order of the lanes
		__m128 high = _mm_loadu_ps( fftOut + ( FFT_SIZE - 3 ) - j );
		high = _mm_shuffle_ps( high, high, _MM_SHUFFLE( 0, 1, 2, 3 ) );
		curr = _mm_add_ps( curr, high );
		_mm_storeu_ps( fftOut + j, curr );
	}

	constexpr size_t n_fft = 1 + ( FFT_SIZE / 2 );

	// mel spectrogram
	// NOTE(review): the indexing uses the compile-time N_MEL and n_fft, not filters.n_mel / filters.n_fft.
	// Presumably the model files always have matching numbers; nothing in this file verifies that.
	for( size_t j = 0; j < N_MEL; j++ )
	{
		double sum = 0.0;
		for( size_t k = 0; k < n_fft; k++ )
			sum += fftOut[ k ] * filters.data[ j * n_fft + k ];
		// Clamp before the logarithm, to avoid log10 of zero
		if( sum < 1e-10 )
			sum = 1e-10;
		sum = log10( sum );
		rdi[ j ] = (float)sum;
	}

	/*
	const float* ptr = rdi.data();
	const float* const ptrEnd = ptr + rdi.size();
	static_assert( 0 == N_MEL % 4 );
	__m128 ax = _mm_loadu_ps( ptr );
	for( ptr += 4; ptr < ptrEnd; ptr += 4 )
		ax = _mm_max_ps( ax, _mm_loadu_ps( ptr ) );
	ax = _mm_max_ps( ax, _mm_movehl_ps( ax, ax ) );
	ax = _mm_max_ss( ax, _mm_movehdup_ps( ax ) );
	return _mm_cvtss_f32( ax );
	*/
}
// Whisper/Whisper/melSpectrogram.h
#pragma once
#include "audioConstants.h"
#include "WhisperModel.h"
#include <memory>

namespace Whisper
{
	// Lookup table with FFT_SIZE coefficients of the Hann window, computed by the constructor
	class HanningWindow
	{
		std::array<float, FFT_SIZE> hann;
	public:
		HanningWindow();

		// Coefficient for the sample `i` of the window; the index is not range-checked
		float operator[]( size_t i ) const
		{
			return hann[ i ];
		}
	};

	// The shared instance of the table, defined in melSpectrogram.cpp
	extern const HanningWindow s_hanning;

	// Computes columns of the MEL spectrogram from PCM samples.
	// The object owns a temporary buffer which is reused by every fft() call.
	class SpectrogramContext
	{
		// Stored by reference: the filters must outlive this object
		const Filters& filters;
		// Recursive FFT; writes the output at `temp`, and returns the pointer past the consumed part of the temporary memory
		static float* fftRecursion( float* temp, const float* const rsi, const size_t len );
		std::unique_ptr<float[]> tempBuffer;

	public:
		SpectrogramContext( const Filters& flt );

		// First step of the MEL algorithm, and recursively compute the FFT
		// Consumes up to FFT_SIZE samples from `pcm`, and writes N_MEL values into `rdi`
		void fft( std::array<float, N_MEL>& rdi, const float* pcm, size_t length );
	};
}
// Whisper/Whisper/sEncodeParams.h
#pragma once
#include <stdint.h>

namespace DirectCompute
{
	// Numeric parameters of the encoder, all of them unsigned 32-bit integers.
	// NOTE(review): the meaning of the individual fields isn't visible in this header;
	// the names suggest they correspond to the hyper-parameters in Whisper::sModelParams - confirm against the code which fills the structure.
	struct sEncodeParams
	{
		uint32_t n_ctx, n_mels, mel_offset;
		uint32_t layersCount, n_state, n_head;
		uint32_t n_audio_ctx, n_text_state, n_text_layer, n_text_ctx;
	};

	// Numeric parameters of the decoder, all of them unsigned 32-bit integers.
	// NOTE(review): same caveat as above, the field semantics need to be confirmed against the users of the structure.
	struct sDecodeParams
	{
		uint32_t n_state, n_head;
		uint32_t n_ctx, n_past, M;
		uint32_t n_text_layer;
		uint32_t n_vocab;
	};
}
// Whisper/Whisper/sModelParams.h
#pragma once
namespace Whisper
{
	// Hyper-parameters of the model; every field is a 32-bit `int`.
	// default hparams (Whisper tiny)
	// WhisperModel::load() replaces these defaults with the numbers stored in the header of the model file.
	struct sModelParams
	{
		int n_vocab = 51864;
		int n_audio_ctx = 1500;
		int n_audio_state = 384;
		int n_audio_head = 6;
		// Count of encoder layers; the loader passes it as `layersEnc` when it builds the map of the expected tensors
		int n_audio_layer = 4;
		int n_text_ctx = 448;
		// WhisperModel::load() asserts this is equal to n_audio_state
		int n_text_state = 384;
		int n_text_head = 6;
		// Count of decoder layers; the loader passes it as `layersDec`
		int n_text_layer = 4;
		int n_mels = 80;
		int f16 = 1;
	};
}
// Whisper/Whisper/sTokenData.h
#pragma once
#include <stdint.h>

namespace Whisper
{
	// Tokens are identified by integers
	using whisper_token = int;

	// Data about a single token, including the probabilities and the optional timestamps
	struct sTokenData
	{
		whisper_token id;	// token id
		whisper_token tid;	// forced timestamp token id
		float p;			// probability of the token
		float pt;			// probability of the timestamp token
		float ptsum;		// sum of probabilities of all timestamp tokens
		float vlen;			// voice length of the token
		// token-level timestamp data
		// do not use if you haven't computed token-level timestamps
		// NOTE(review): the unit of the two timestamps isn't stated in this header - confirm against the code which computes them
		int64_t t0;			// start time of the token
		int64_t t1;			// end time of the token
	};
}
\ No newline at end of file diff --git a/Whisper/Whisper/voiceActivityDetection.cpp b/Whisper/Whisper/voiceActivityDetection.cpp new file mode 100644 index 0000000..31763ec --- /dev/null +++ b/Whisper/Whisper/voiceActivityDetection.cpp @@ -0,0 +1,199 @@ +#include "stdafx.h" +#include "voiceActivityDetection.h" +#include <DirectXMath.h> +using namespace Whisper; + +// Initially ported (poorly) from there https://github.com/panmasuo/voice-activity-detection MIT license +// The code is based on that article: +// https://www.researchgate.net/publication/255667085_A_simple_but_efficient_real-time_voice_activity_detection_algorithm + +inline VAD::Feature VAD::defaultPrimaryThresholds() +{ + Feature f; + f.energy = 40; + f.F = 185; + f.SFM = 5; + return f; +} + +VAD::VAD() : + primThresh( defaultPrimaryThresholds() ) +{ + fft_signal = std::make_unique<cplx[]>( FFT_POINTS ); +} + +inline void VAD::fft( cplx* buf, cplx* out, size_t n, size_t step ) +{ + if( step < n ) + { + fft( out, buf, n, step * 2 ); + fft( out + step, buf + step, n, step * 2 ); + + for( size_t i = 0; i < n; i += 2 * step ) + { + // cplx t = cexp(-I * M_PI * i / n) * out[i + step]; + // using namespace std::complex_literals; + const float mul = (float)M_PI * (float)(int)i / (float)(int)n; + // cplx t0 = std::exp( -1.0if * mul ) * out[ i + step ]; + float sine, cosine; + DirectX::XMScalarSinCos( &sine, &cosine, -mul ); + const cplx exponent{ cosine, sine }; + cplx t = exponent * out[ i + step ]; + + buf[ i / 2 ] = out[ i ] + t; + buf[ ( i + n ) / 2 ] = out[ i ] - t; + } + } +} + +void VAD::fft() const +{ + cplx out[ FFT_POINTS ]; + memcpy( &out[ 0 ], fft_signal.get(), FFT_POINTS * sizeof( cplx ) ); + + fft( fft_signal.get(), out, FFT_POINTS, 1 ); +} + +constexpr float mulInt16FromFloat = 32768.0; + +float VAD::computeEnergy( const float* rsi ) +{ + // calculate_energy + double sum = 0; + for( size_t i = 0; i < FFT_POINTS; i++ ) + { + float f = rsi[ i ]; + f *= mulInt16FromFloat; + f *= f; + sum += f; + } 
+ return std::sqrtf( (float)( sum * ( 1.0 / FFT_POINTS ) ) ); +} + +float VAD::computeDominant( const cplx* spectrum ) +{ + // calculate_dominant, reworked heavily + float maxMagSquared = 0; + int maxFreq = 0; + + for( int i = 0; i < FFT_POINTS / 2; i++ ) + { + const float real = (float)spectrum[ i ].real(); + const float imag = (float)spectrum[ i ].imag(); + float sq = real * real + imag * imag; + if( sq <= maxMagSquared ) + continue; + maxMagSquared = sq; + maxFreq = i; + } + return (float)maxFreq * FFT_STEP; +} + +float VAD::computreSpectralFlatnessMeasure( const cplx* spectrum ) +{ + // calculate_sfm + double sum_ari = 0; + double sum_geo = 0; + for( size_t i = 0; i < FFT_POINTS; i++ ) + { + // sig = cabsf( spectrum[ i ] ); + float sig = std::abs( spectrum[ i ] ); + sum_ari += sig; + sum_geo += std::log( sig ); + } + sum_ari = sum_ari / FFT_POINTS; + sum_geo = std::exp( sum_geo / FFT_POINTS ); + return -10.0f * std::log10f( (float)( sum_geo / sum_ari ) ); +} + +void VAD::clear() +{ + memset( &state, 0, sizeof( State ) ); + state.currThresh = primThresh; +} + +size_t VAD::detect( const float* rsi, size_t length ) +{ + // The cryptic numbers in the comments are from section 3 "Proposed VAD Algorithm" of the article, on page 2550, on the right + const size_t frames = length / FFT_POINTS; + if( frames <= 0 ) + { + clear(); + return 0; + } + + // Load detection state from the field + Feature currThresh = state.currThresh; + Feature minFeature = state.minFeature; + Feature curr = state.curr; + + size_t lastSpeech = state.lastSpeech; + float silenceRun = state.silenceRun; + size_t i = state.i; + + // Run the loop just on the [ state.i .. 
frames ] slice of the input PCM + rsi += i * FFT_POINTS; + for( ; i < frames; i++, rsi += FFT_POINTS ) + { + // 3-2 calculate FFT + for( size_t j = 0; j < FFT_POINTS; j++ ) + { + const float re = rsi[ j ] * mulInt16FromFloat; + fft_signal[ j ] = { re, 0.0f }; + } + fft(); + + // 3-1 + 3-2 calculate features + curr.energy = computeEnergy( rsi ); + curr.F = computeDominant( fft_signal.get() ); + curr.SFM = computreSpectralFlatnessMeasure( fft_signal.get() ); + + // 3-3 calculate minimum value for first 30 frames + if( i == 0 ) + minFeature = curr; + else if( i < 30 ) + { + minFeature.energy = std::min( minFeature.energy, curr.energy ); + minFeature.F = std::min( minFeature.F, curr.F ); + minFeature.SFM = std::min( minFeature.SFM, curr.SFM ); + } + + // 3-4 set thresholds + currThresh.energy = primThresh.energy * std::log10f( minFeature.energy ); + + // 3-5 calculate decision + uint8_t counter = 0; + if( ( curr.energy - minFeature.energy ) >= currThresh.energy ) + counter = 1; + if( ( curr.F - minFeature.F ) >= currThresh.F ) + counter++; + if( ( curr.SFM - minFeature.SFM ) >= currThresh.SFM ) + counter++; + + if( counter > 1 ) + { + // 3-6 If counter > 1 mark the current frame as speech + lastSpeech = ( i + 1 ) * FFT_POINTS; + silenceRun = 0.0f; + } + else + { + silenceRun += 1.0f; + // 3-7 If current frame is marked as silence, update the energy minimum value + minFeature.energy = ( ( silenceRun * minFeature.energy ) + curr.energy ) / ( silenceRun + 1 ); + } + + // 3-8 + currThresh.energy = primThresh.energy * std::log10f( minFeature.energy ); + } + + // Store the updated detection state back into that field + state.currThresh = currThresh; + state.minFeature = minFeature; + state.curr = curr; + state.lastSpeech = (uint32_t)lastSpeech; + state.silenceRun = silenceRun; + state.i = (uint32_t)i; + + return lastSpeech; +}
\ No newline at end of file diff --git a/Whisper/Whisper/voiceActivityDetection.h b/Whisper/Whisper/voiceActivityDetection.h new file mode 100644 index 0000000..80952d2 --- /dev/null +++ b/Whisper/Whisper/voiceActivityDetection.h @@ -0,0 +1,54 @@ +#pragma once +#include <complex> +#include <memory> +#include "audioConstants.h" + +namespace Whisper +{ + class VAD + { + using cplx = std::complex<float>; + std::unique_ptr<cplx[]> fft_signal; + + struct Feature + { + float energy; + float F; + float SFM; + }; + const Feature primThresh; + static Feature defaultPrimaryThresholds(); + + struct State + { + Feature currThresh; + Feature minFeature; + Feature curr; + + uint32_t lastSpeech; + float silenceRun; + uint32_t i; + }; + State state; + + static inline void fft( cplx* buf, cplx* out, size_t n, size_t step ); + void fft() const; + + static float computeEnergy( const float* rsi ); + static float computeDominant( const cplx* spectrum ); + static float computreSpectralFlatnessMeasure( const cplx* spectrum ); + + public: + + VAD(); + + // When no speech is detected, returns 0 + // When speech is detected, returns sample position for the end of the speech + size_t detect( const float* rsi, size_t length ); + + void clear(); + + static constexpr uint32_t FFT_POINTS = 256; + static constexpr float FFT_STEP = (float)SAMPLE_RATE / (float)FFT_POINTS; + }; +}
<?xml version="1.0" encoding="utf-8"?>
<AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010">
	<!-- Whisper/misc.natvis
	This file is only for the Visual Studio debugger, it doesn't affect the binary in any way.
	More info: https://learn.microsoft.com/en-us/visualstudio/debugger/create-custom-views-of-native-objects -->

	<!-- SSE and AVX vectors: display the lanes as a flat list -->
	<Type Name="__m128i">
		<DisplayString>[ { m128i_i32[ 0 ] }, { m128i_i32[ 1 ] }, { m128i_i32[ 2 ] }, { m128i_i32[ 3 ] } ]</DisplayString>
	</Type>
	<Type Name="__m128">
		<DisplayString>[ { m128_f32[ 0 ] }, { m128_f32[ 1 ] }, { m128_f32[ 2 ] }, { m128_f32[ 3 ] } ]</DisplayString>
	</Type>
	<Type Name="__m256">
		<DisplayString>[ { m256_f32[ 0 ] }, { m256_f32[ 1 ] }, { m256_f32[ 2 ] }, { m256_f32[ 3 ] }, { m256_f32[ 4 ] }, { m256_f32[ 5 ] }, { m256_f32[ 6 ] }, { m256_f32[ 7 ] } ]</DisplayString>
	</Type>
	<Type Name="__m256i">
		<DisplayString>[ { m256i_i32[ 0 ] }, { m256i_i32[ 1 ] }, { m256i_i32[ 2 ] }, { m256i_i32[ 3 ] }, { m256i_i32[ 4 ] }, { m256i_i32[ 5 ] }, { m256i_i32[ 6 ] }, { m256i_i32[ 7 ] } ]</DisplayString>
	</Type>

	<!-- Small std::array instances; the angle brackets of the template arguments must be escaped inside the attribute value -->
	<Type Name="std::array&lt;*,4&gt;">
		<DisplayString>[ { _Elems[ 0 ] }, { _Elems[ 1 ] }, { _Elems[ 2 ] }, { _Elems[ 3 ] } ]</DisplayString>
	</Type>
	<Type Name="std::array&lt;*,3&gt;">
		<DisplayString>[ { _Elems[ 0 ] }, { _Elems[ 1 ] }, { _Elems[ 2 ] } ]</DisplayString>
	</Type>
	<Type Name="std::array&lt;*,2&gt;">
		<DisplayString>[ { _Elems[ 0 ] }, { _Elems[ 1 ] } ]</DisplayString>
	</Type>

	<Type Name="DirectCompute::Tensor">
		<DisplayString>Size { ne }, strides { nb }</DisplayString>
	</Type>

	<!-- CPU tensors: the summary and the expanded element list both depend on the element type stored in m_type -->
	<Type Name="CpuCompute::Tensor">
		<DisplayString Condition="m_type==DirectCompute::eDataType::FP16">FP16 { ne }, strides { nb }</DisplayString>
		<DisplayString Condition="m_type==DirectCompute::eDataType::FP32">FP32 { ne }, strides { nb }</DisplayString>
		<DisplayString Condition="m_type==DirectCompute::eDataType::U32">U32 { ne }, strides { nb }</DisplayString>
		<Expand>
			<!-- The element count is the product of the 4 sizes in `ne` -->
			<ArrayItems Condition="m_type==DirectCompute::eDataType::FP16">
				<Size>ne._Elems[ 0 ] * ne._Elems[ 1 ] * ne._Elems[ 2 ] * ne._Elems[ 3 ]</Size>
				<ValuePointer>(uint16_t*)m_data</ValuePointer>
			</ArrayItems>
			<ArrayItems Condition="m_type==DirectCompute::eDataType::FP32">
				<Size>ne._Elems[ 0 ] * ne._Elems[ 1 ] * ne._Elems[ 2 ] * ne._Elems[ 3 ]</Size>
				<ValuePointer>(float*)m_data</ValuePointer>
			</ArrayItems>
			<ArrayItems Condition="m_type==DirectCompute::eDataType::U32">
				<Size>ne._Elems[ 0 ] * ne._Elems[ 1 ] * ne._Elems[ 2 ] * ne._Elems[ 3 ]</Size>
				<ValuePointer>(uint32_t*)m_data</ValuePointer>
			</ArrayItems>
		</Expand>
	</Type>
</AutoVisualizer>
\ No newline at end of file diff --git a/Whisper/modelFactory.cpp b/Whisper/modelFactory.cpp new file mode 100644 index 0000000..a708551 --- /dev/null +++ b/Whisper/modelFactory.cpp @@ -0,0 +1,19 @@ +#include "stdafx.h" +#include "modelFactory.h" +#include "API/iContext.cl.h" + +HRESULT COMLIGHTCALL Whisper::loadModel( const wchar_t* path, eModelImplementation impl, const sLoadModelCallbacks* callbacks, iModel** pp ) +{ + switch( impl ) + { + case eModelImplementation::GPU: + return loadGpuModel( path, false, callbacks, pp ); + case eModelImplementation::Hybrid: + return loadGpuModel( path, true, callbacks, pp ); + case eModelImplementation::Reference: + return loadReferenceCpuModel( path, pp ); + } + + logError( u8"Unknown model implementation 0x%X", (int)impl ); + return E_INVALIDARG; +}
// Whisper/modelFactory.h
// Declarations of the per-implementation model loaders, dispatched by Whisper::loadModel in modelFactory.cpp
#pragma once
#include "API/sLoadModelCallbacks.h"

namespace Whisper
{
	struct iModel;

	// Load a model for the GPU-based implementations.
	// loadModel() passes hybrid=false for eModelImplementation::GPU, and hybrid=true for eModelImplementation::Hybrid.
	// Presumably the created model is returned through `pp` - the implementation is outside this file.
	HRESULT __stdcall loadGpuModel( const wchar_t* path, bool hybrid, const sLoadModelCallbacks* callbacks, iModel** pp );

	// Load a model for eModelImplementation::Reference; this loader does not take the callbacks.
	HRESULT __stdcall loadReferenceCpuModel( const wchar_t* path, iModel** pp );
}
\ No newline at end of file diff --git a/Whisper/resource.h b/Whisper/resource.h new file mode 100644 index 0000000..7ca31da --- /dev/null +++ b/Whisper/resource.h @@ -0,0 +1,14 @@ +//{{NO_DEPENDENCIES}} +// Microsoft Visual C++ generated include file. +// Used by Resource.rc + +// Next default values for new objects +// +#ifdef APSTUDIO_INVOKED +#ifndef APSTUDIO_READONLY_SYMBOLS +#define _APS_NEXT_RESOURCE_VALUE 101 +#define _APS_NEXT_COMMAND_VALUE 40001 +#define _APS_NEXT_CONTROL_VALUE 1001 +#define _APS_NEXT_SYMED_VALUE 101 +#endif +#endif diff --git a/Whisper/source.compat/Readme.txt b/Whisper/source.compat/Readme.txt new file mode 100644 index 0000000..affee91 --- /dev/null +++ b/Whisper/source.compat/Readme.txt @@ -0,0 +1 @@ +The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change BUILD_BOTH_VERSIONS macro in stdafx.h
\ No newline at end of file diff --git a/Whisper/source.compat/convertThings.cpp b/Whisper/source.compat/convertThings.cpp new file mode 100644 index 0000000..0e6e8c2 --- /dev/null +++ b/Whisper/source.compat/convertThings.cpp @@ -0,0 +1,234 @@ +#include "stdafx.h" +#if BUILD_BOTH_VERSIONS +#include "../API/iContext.cl.h" +#include "convertThings.h" +using namespace Whisper; + +sFullParams makeNewParams( const whisper_full_params& wfp ) +{ + assert( nullptr == wfp.encoder_begin_callback ); + assert( nullptr == wfp.new_segment_callback ); + + sFullParams res; + memset( &res, 0, sizeof( res ) ); + + res.strategy = (eSamplingStrategy)wfp.strategy; + res.cpuThreads = wfp.n_threads; + res.n_max_text_ctx = wfp.n_max_text_ctx; + res.offset_ms = wfp.offset_ms; + res.duration_ms = wfp.duration_ms; + + // flags + uint32_t flags = 0; + if( wfp.translate ) flags |= (uint32_t)eFullParamsFlags::Translate; + if( wfp.no_context ) flags |= (uint32_t)eFullParamsFlags::NoContext; + if( wfp.single_segment ) flags |= (uint32_t)eFullParamsFlags::SingleSegment; + if( wfp.print_special ) flags |= (uint32_t)eFullParamsFlags::PrintSpecial; + if( wfp.print_progress ) flags |= (uint32_t)eFullParamsFlags::PrintProgress; + if( wfp.print_realtime ) flags |= (uint32_t)eFullParamsFlags::PrintRealtime; + if( wfp.print_timestamps ) flags |= (uint32_t)eFullParamsFlags::PrintTimestamps; + if( wfp.token_timestamps ) flags |= (uint32_t)eFullParamsFlags::TokenTimestamps; + if( wfp.speed_up ) flags |= (uint32_t)eFullParamsFlags::SpeedupAudio; + res.flags = (eFullParamsFlags)flags; + + res.language = findLanguageKeyA( wfp.language ); + res.thold_pt = wfp.thold_pt; + res.thold_ptsum = wfp.thold_ptsum; + res.max_len = wfp.max_len; + res.greedy.n_past = wfp.greedy.n_past; + res.beam_search.n_past = wfp.beam_search.n_past; + res.beam_search.beam_width = wfp.beam_search.beam_width; + res.beam_search.n_best = wfp.beam_search.n_best; + res.audio_ctx = wfp.audio_ctx; + res.prompt_tokens = wfp.prompt_tokens; + 
res.prompt_n_tokens = wfp.prompt_n_tokens; + + return res; +} + +namespace +{ + class NewParamsTemp + { + char language[ 5 ]; + iContext* newContext; + pfnNewSegment newSegment; + pfnEncoderBegin encoderBegin; + + static bool encBegin( struct whisper_context* ctx, void* user_data ); + static void newSeg( struct whisper_context* ctx, int n_new, void* user_data ); + + public: + + void initialize( whisper_full_params& res, const Whisper::sFullParams& rsi, Whisper::iContext* context ) + { + *(uint32_t*)( &language[ 0 ] ) = rsi.language; + language[ 4 ] = '\0'; + res.language = language; + + newContext = context; + + if( nullptr != rsi.encoder_begin_callback ) + { + encoderBegin = rsi.encoder_begin_callback; + res.encoder_begin_callback = &encBegin; + res.encoder_begin_callback_user_data = rsi.encoder_begin_callback_user_data; + } + else + { + encoderBegin = nullptr; + res.encoder_begin_callback = nullptr; + res.encoder_begin_callback_user_data = nullptr; + } + + if( nullptr != rsi.new_segment_callback ) + { + newSegment = rsi.new_segment_callback; + res.new_segment_callback = &newSeg; + res.new_segment_callback_user_data = rsi.new_segment_callback_user_data; + } + else + { + newSegment = nullptr; + res.new_segment_callback = nullptr; + res.new_segment_callback_user_data = nullptr; + } + } + }; + + static thread_local NewParamsTemp npTemp; + + bool NewParamsTemp::encBegin( struct whisper_context* ctx, void* user_data ) + { + const NewParamsTemp& tmp = npTemp; + HRESULT hr = tmp.encoderBegin( tmp.newContext, user_data ); + if( SUCCEEDED( hr ) ) + return S_OK == hr; + throw hr; + } + + void NewParamsTemp::newSeg( struct whisper_context* ctx, int n_new, void* user_data ) + { + assert( n_new >= 0 ); + const NewParamsTemp& tmp = npTemp; + HRESULT hr = tmp.newSegment( tmp.newContext, (uint32_t)n_new, user_data ); + if( SUCCEEDED( hr ) ) + return; + throw hr; + } +} + +whisper_full_params makeOldParams( const Whisper::sFullParams& rsi, Whisper::iContext* context ) +{ + 
whisper_full_params res; + memset( &res, 0, sizeof( res ) ); + + res.strategy = (whisper_sampling_strategy)rsi.strategy; + res.n_threads = rsi.cpuThreads; + res.n_max_text_ctx = rsi.n_max_text_ctx; + res.offset_ms = rsi.offset_ms; + res.duration_ms = rsi.duration_ms; + + // flags + const uint32_t flags = (uint32_t)rsi.flags; + auto hasFlag = [ = ]( eFullParamsFlags bit ) { return 0 != ( flags & (uint32_t)bit ); }; + + res.translate = hasFlag( eFullParamsFlags::Translate ); + res.no_context = hasFlag( eFullParamsFlags::NoContext ); + res.single_segment = hasFlag( eFullParamsFlags::SingleSegment ); + res.print_special = hasFlag( eFullParamsFlags::PrintSpecial ); + res.print_progress = hasFlag( eFullParamsFlags::PrintProgress ); + res.print_realtime = hasFlag( eFullParamsFlags::PrintRealtime ); + res.print_timestamps = hasFlag( eFullParamsFlags::PrintTimestamps ); + res.token_timestamps = hasFlag( eFullParamsFlags::TokenTimestamps ); + res.speed_up = hasFlag( eFullParamsFlags::SpeedupAudio ); + + res.thold_pt = rsi.thold_pt; + res.thold_ptsum = rsi.thold_ptsum; + res.max_len = rsi.max_len; + res.greedy.n_past = rsi.greedy.n_past; + res.beam_search.n_past = rsi.beam_search.n_past; + res.beam_search.beam_width = rsi.beam_search.beam_width; + res.beam_search.n_best = rsi.beam_search.n_best; + res.audio_ctx = rsi.audio_ctx; + res.prompt_tokens = rsi.prompt_tokens; + res.prompt_n_tokens = rsi.prompt_n_tokens; + + NewParamsTemp& tmp = npTemp; + tmp.initialize( res, rsi, context ); + return res; +} + +#include "../Whisper/TranscribeResult.h" +#include <mfapi.h> + +namespace +{ + inline sTimeSpan time( int64_t wt ) + { + int64_t ticks = MFllMulDiv( wt, 10'000'000, 100, 0 ); + return sTimeSpan{ (uint64_t)ticks }; + } + + void makeNewResults( whisper_context* ctx, Whisper::eResultFlags flags, TranscribeResult& res ) + { + const bool makeTokens = 0 != ( flags & eResultFlags::Tokens ); + res.segments.clear(); + res.tokens.clear(); + + const int countSegments = 
whisper_full_n_segments( ctx ); + res.segments.resize( countSegments ); + const int tokenEot = whisper_token_eot( ctx ); + for( int i = 0; i < countSegments; i++ ) + { + sSegment& seg = res.segments[ i ]; + seg.text = whisper_full_get_segment_text( ctx, i ); + seg.time.begin = time( whisper_full_get_segment_t0( ctx, i ) ); + seg.time.end = time( whisper_full_get_segment_t1( ctx, i ) ); + + seg.firstToken = (uint32_t)res.tokens.size(); + seg.countTokens = 0; + if( !makeTokens ) + continue; + + const int countTokens = whisper_full_n_tokens( ctx, i ); + seg.countTokens = countTokens; + res.tokens.resize( res.tokens.size() + countTokens ); + for( int t = 0; t < countTokens; t++ ) + { + sToken& tok = res.tokens[ seg.firstToken + t ]; + tok.text = whisper_full_get_token_text( ctx, i, t ); + + const whisper_token_data src = whisper_full_get_token_data( ctx, i, t ); + tok.time.begin = time( src.t0 ); + tok.time.end = time( src.t1 ); + tok.probability = src.p; + tok.probabilityTimestamp = src.pt; + tok.ptsum = src.ptsum; + tok.vlen = src.vlen; + tok.id = src.id; + uint32_t flags = 0; + if( src.id >= tokenEot ) + flags |= eTokenFlags::Special; + tok.flags = (eTokenFlags)flags; + } + } + } +} + +HRESULT makeNewResults( whisper_context* ctx, Whisper::eResultFlags flags, Whisper::iTranscribeResult** pp ) +{ + static TranscribeResultStatic trs; + if( flags & eResultFlags::NewObject ) + { + return E_NOTIMPL; + } + else + { + makeNewResults( ctx, flags, trs ); + *pp = &trs; + ( *pp )->AddRef(); + return S_OK; + } +} +#endif
// Whisper/source.compat/convertThings.h
// Conversions between the structures of the original whisper.cpp API, and the equivalents in the Whisper namespace
#pragma once
#include "../source/whisper.h"
#include "../API/sFullParams.h"
#include "../API/iTranscribeResult.cl.h"

// Convert old-style parameters into the new structure
Whisper::sFullParams makeNewParams( const whisper_full_params& rsi );

// Convert new-style parameters into the old structure; `context` is the object passed to the new-style callbacks
whisper_full_params makeOldParams( const Whisper::sFullParams& rsi, Whisper::iContext* context );

// Convert the results stored in the old-style context into a new-style results object, returned in `pp`
HRESULT makeNewResults( whisper_context* ctx, Whisper::eResultFlags flags, Whisper::iTranscribeResult** pp );
\ No newline at end of file diff --git a/Whisper/source.compat/ggmlMsvc.c b/Whisper/source.compat/ggmlMsvc.c new file mode 100644 index 0000000..5a9f340 --- /dev/null +++ b/Whisper/source.compat/ggmlMsvc.c @@ -0,0 +1,37 @@ +#include <stdint.h> +#include <immintrin.h> +#include <stdio.h> +#include <assert.h> +#include "../source/ggml.h" + +__forceinline float _cvtsh_ss( uint16_t f16 ) +{ + __m128i i = _mm_cvtsi32_si128( f16 ); + __m128 f = _mm_cvtph_ps( i ); + return _mm_cvtss_f32( f ); +} + +__forceinline uint16_t _cvtss_sh( float f, int rounding ) +{ + assert( 0 == rounding ); + __m128 v = _mm_set_ss( f ); + __m128i i = _mm_cvtps_ph( v, 0 ); + return (uint16_t)(uint32_t)_mm_cvtsi128_si32( i ); +} + +FILE* fopen_msvc( const char* filename, const char* mode ) +{ + FILE* stream; + errno_t err = fopen_s( &stream, filename, mode ); + if( err == 0 ) + return stream; + return NULL; +} + +#define fopen fopen_msvc + +#include "../ML/testUtilsC.h" + +#define __F16C__ +#define __FMA__ +#include "../source/ggml.c"
\ No newline at end of file diff --git a/Whisper/source/LICENSE b/Whisper/source/LICENSE new file mode 100644 index 0000000..fb7ff0c --- /dev/null +++ b/Whisper/source/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Georgi Gerganov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Whisper/source/Readme.txt b/Whisper/source/Readme.txt new file mode 100644 index 0000000..affee91 --- /dev/null +++ b/Whisper/source/Readme.txt @@ -0,0 +1 @@ +The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change BUILD_BOTH_VERSIONS macro in stdafx.h
\ No newline at end of file diff --git a/Whisper/source/ggml.c b/Whisper/source/ggml.c new file mode 100644 index 0000000..c318b20 --- /dev/null +++ b/Whisper/source/ggml.c @@ -0,0 +1,8336 @@ +#include "ggml.h" + +#if defined(_MSC_VER) || defined(__MINGW32__) +#include <malloc.h> // using malloc.h with MSC/MINGW +#elif !defined(__FreeBSD__) +#include <alloca.h> +#endif + +#include <assert.h> +#include <time.h> +#include <math.h> +#include <stdlib.h> +#include <string.h> +#include <stdint.h> +#include <stdio.h> + +// if C99 - static_assert is noop +// ref: https://stackoverflow.com/a/53923785/4039976 +#ifndef static_assert +#define static_assert(cond, msg) struct global_scope_noop_trick +#endif + +#if defined _MSC_VER || defined(__MINGW32__) + +#if !defined(__MINGW32__) +#include <Windows.h> +#else +// ref: https://github.com/ggerganov/whisper.cpp/issues/168 +#include <windows.h> +#include <errno.h> +#endif + +typedef volatile LONG atomic_int; +typedef atomic_int atomic_bool; + +static void atomic_store(atomic_int* ptr, LONG val) { + InterlockedExchange(ptr, val); +} +static LONG atomic_load(atomic_int* ptr) { + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) { + return InterlockedExchangeAdd(ptr, inc); +} +static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) { + return atomic_fetch_add(ptr, -(dec)); +} + +typedef HANDLE pthread_t; + +typedef DWORD thread_ret_t; +static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { + HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); + if (handle == NULL) + { + return EAGAIN; + } + + *out = handle; + return 0; +} + +static int pthread_join(pthread_t thread, void* unused) { + return (int) WaitForSingleObject(thread, INFINITE); +} + +static int sched_yield (void) { + Sleep (0); + return 0; +} +#else +#include <pthread.h> +#include <stdatomic.h> + +typedef void* thread_ret_t; +#endif + +#ifdef 
__HAIKU__ +#define static_assert(cond, msg) _Static_assert(cond, msg) +#endif + +#define GGML_DEBUG 0 +#define GGML_GELU_FP16 + +#if UINTPTR_MAX == 0xFFFFFFFF + #define GGML_MEM_ALIGN 4 +#else + #define GGML_MEM_ALIGN 16 +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +#define UNUSED(x) (void)(x) +#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) + +#define GGML_ASSERT(x) \ + do { \ + if (!(x)) { \ + logError( u8"GGML_ASSERT: %s:%d: %s", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +#ifdef GGML_USE_ACCELERATE +#include <Accelerate/Accelerate.h> +#elif GGML_USE_OPENBLAS +#include <cblas.h> +#endif + +// floating point type used to accumulate sums +typedef double ggml_float; + +// 16-bit float +// on Arm, we use __fp16 +// on x86, we use uint16_t +#ifdef __ARM_NEON + +// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include <arm_neon.h> + +float ggml_fp16_to_fp32(ggml_fp16_t x) { + return x; +} + +ggml_fp16_t ggml_fp32_to_fp16(float x) { + return x; +} + +#define GGML_FP16_TO_FP32(x) (x) +#define GGML_FP32_TO_FP16(x) (x) + +#else + +#ifdef __wasm_simd128__ +#include <wasm_simd128.h> +#else +#ifdef __POWER9_VECTOR__ +#include <altivec.h> +#undef bool +#define bool _Bool +#else +#include <immintrin.h> +#endif +#endif + +#ifdef __F16C__ +float ggml_fp16_to_fp32(ggml_fp16_t h) { + return _cvtsh_ss(h); +} +ggml_fp16_t ggml_fp32_to_fp16(float f) { + return _cvtss_sh(f, 0); +} + +#define GGML_FP16_TO_FP32(x) _cvtsh_ss(x) +#define GGML_FP32_TO_FP16(x) _cvtss_sh(x, 0) + +#else + +// FP16 <-> FP32 +// ref: https://github.com/Maratyszcza/FP16 + +static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32; + fp32.as_bits = w; + return fp32.as_value; +} + +static inline uint32_t fp32_to_bits(float f) { + union { + 
float as_value; + uint32_t as_bits; + } fp32; + fp32.as_value = f; + return fp32.as_bits; +} + +float ggml_fp16_to_fp32(ggml_fp16_t h) { + const uint32_t w = (uint32_t) h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +ggml_fp16_t ggml_fp32_to_fp16(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t 
nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +#define GGML_FP16_TO_FP32(x) ggml_fp16_to_fp32(x) +#define GGML_FP32_TO_FP16(x) ggml_fp32_to_fp16(x) + +#endif // __F16C__ + +#endif // __ARM_NEON + +// +// global data +// + +// precomputed gelu table for f16 (128 KB) +static ggml_fp16_t table_gelu_f16[1 << 16]; + +// precomputed exp table for f16 (128 KB) +static ggml_fp16_t table_exp_f16[1 << 16]; + +// +// timing +// + +#if defined(_MSC_VER) || defined(__MINGW32__) +static int64_t timer_freq; +void ggml_time_init(void) { + LARGE_INTEGER frequency; + QueryPerformanceFrequency(&frequency); + timer_freq = frequency.QuadPart; +} +int64_t ggml_time_ms(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return (t.QuadPart * 1000) / timer_freq; +} +int64_t ggml_time_us(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return (t.QuadPart * 1000000) / timer_freq; +} +#else +void ggml_time_init(void) {} +int64_t ggml_time_ms(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; +} + +int64_t ggml_time_us(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; +} +#endif + +int64_t ggml_cycles(void) { + return clock(); +} + +int64_t ggml_cycles_per_ms(void) { + return CLOCKS_PER_SEC/1000; +} + +#ifdef GGML_PERF +#define ggml_perf_time_ms() ggml_time_ms() +#define ggml_perf_time_us() ggml_time_us() +#define ggml_perf_cycles() ggml_cycles() +#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms() +#else +#define ggml_perf_time_ms() 0 +#define ggml_perf_time_us() 0 +#define ggml_perf_cycles() 0 +#define ggml_perf_cycles_per_ms() 0 +#endif + +// +// cache line +// + +#if defined(__cpp_lib_hardware_interference_size) +#define CACHE_LINE_SIZE hardware_destructive_interference_size +#else +#define CACHE_LINE_SIZE 64 +#endif + +static 
const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); + +// +// simd mappings +// + +// we define a common set of C macros which map to specific intrinsics based on the current architecture +// we then implement the fundamental computation operations below using only these macros +// adding support for new architectures requires to define the corresponding SIMD macros +// +// GGML_F32_STEP / GGML_F16_STEP +// number of elements to process in a single step +// +// GGML_F32_EPR / GGML_F16_EPR +// number of elements to fit in a single register +// + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) + +#define GGML_SIMD + +// F32 NEON + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 float32x4_t +#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) +#define GGML_F32x4_SET1(x) vdupq_n_f32(x) +#define GGML_F32x4_LOAD vld1q_f32 +#define GGML_F32x4_STORE vst1q_f32 +#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) +#define GGML_F32x4_ADD vaddq_f32 +#define GGML_F32x4_MUL vmulq_f32 +#if defined(__ARM_FEATURE_QRDMX) + #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) +#else + #define GGML_F32x4_REDUCE_ONE(x) \ + (vgetq_lane_f32(x, 0) + \ + vgetq_lane_f32(x, 1) + \ + vgetq_lane_f32(x, 2) + \ + vgetq_lane_f32(x, 3)) +#endif +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = vaddq_f32(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \ + } \ + res = GGML_F32x4_REDUCE_ONE(x[0]); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE 
GGML_F32x4_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD vld1q_f16 + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + { \ + for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ + x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ + x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ + x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ + res = vaddvq_f32(vaddq_f32(t0, t1)); \ + } + + #define GGML_F16_VEC GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD GGML_F16x8_LOAD + #define GGML_F16_VEC_STORE GGML_F16x8_STORE + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD GGML_F32Cx4_LOAD + #define GGML_F16_VEC_STORE GGML_F32Cx4_STORE + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif + +#elif defined(__AVX__) + +#define GGML_SIMD + +// F32 AVX + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +#define GGML_F32x8 __m256 +#define GGML_F32x8_ZERO _mm256_setzero_ps() +#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32x8_LOAD _mm256_loadu_ps +#define GGML_F32x8_STORE _mm256_storeu_ps +#if defined(__FMA__) + #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) +#else + #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) +#endif +#define GGML_F32x8_ADD _mm256_add_ps +#define GGML_F32x8_MUL _mm256_mul_ps +#define GGML_F32x8_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \ + } \ + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ + _mm256_extractf128_ps(x[0], 1)); \ + const __m128 t1 = _mm_hadd_ps(t0, t0); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ +} +// TODO: is this optimal ? 
+ +#define GGML_F32_VEC GGML_F32x8 +#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD +#define GGML_F32_VEC_STORE GGML_F32x8_STORE +#define GGML_F32_VEC_FMA GGML_F32x8_FMA +#define GGML_F32_VEC_ADD GGML_F32x8_ADD +#define GGML_F32_VEC_MUL GGML_F32x8_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE + +// F16 AVX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead +// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32 + +#define GGML_F32Cx8 __m256 +#define GGML_F32Cx8_ZERO _mm256_setzero_ps() +#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#define GGML_F32Cx8_FMA GGML_F32x8_FMA +#define GGML_F32Cx8_ADD _mm256_add_ps +#define GGML_F32Cx8_MUL _mm256_mul_ps +#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE + +#define GGML_F16_VEC GGML_F32Cx8 +#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 +#define GGML_F16_VEC_LOAD GGML_F32Cx8_LOAD +#define GGML_F16_VEC_STORE GGML_F32Cx8_STORE +#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE + +#elif defined(__POWER9_VECTOR__) + +// TODO: uncomment this when it works +//#define GGML_SIMD + +// F32 POWER9 + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +// TODO: not tested !! 
// 4-lane single-precision vector type and primitives for POWER9 (VSX).
// NOTE: SET1 expands its argument four times - pass a side-effect-free expression.
#define GGML_F32x4              __vector float
#define GGML_F32x4_ZERO         (__vector float){0.0f, 0.0f, 0.0f, 0.0f}
#define GGML_F32x4_SET1(x)      (__vector float){x, x, x, x}
// The generic kernels invoke LOAD(ptr) and STORE(ptr, value), whereas the VSX
// intrinsics take an explicit byte offset: vec_vsx_ld(offset, ptr) and
// vec_vsx_st(value, offset, ptr). Aliasing the bare intrinsic names therefore
// could not compile at those call sites; adapt the argument lists here.
#define GGML_F32x4_LOAD(p)      vec_vsx_ld(0, p)
#define GGML_F32x4_STORE(p, r)  vec_vsx_st(r, 0, p)
// FMA(a, b, c) computes b*c + a, matching the accumulator-first convention of
// the other architectures in this file.
#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
#define GGML_F32x4_ADD          vec_add
#define GGML_F32x4_MUL          vec_mul
// Pairwise-add the GGML_F32_ARR accumulators in x[] down into x[0] (x[] is
// clobbered), then sum the four lanes of x[0] into the scalar `res`.
#define GGML_F32x4_REDUCE(res, x)              \
{                                              \
    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
        x[2*i] = vec_add(x[2*i], x[2*i+1]);    \
    }                                          \
    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
        x[4*i] = vec_add(x[4*i], x[4*i+2]);    \
    }                                          \
    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
        x[8*i] = vec_add(x[8*i], x[8*i+4]);    \
    }                                          \
    res = vec_extract(x[0], 0) +               \
          vec_extract(x[0], 1) +               \
          vec_extract(x[0], 2) +               \
          vec_extract(x[0], 3);                \
}

// Generic F32 vector names used by the kernels, bound to the POWER9 macros.
#define GGML_F32_VEC        GGML_F32x4
#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE

// F16 POWER9
// TODO: implement here
// ...
+ +#elif defined(__wasm_simd128__) + +#define GGML_SIMD + +// F32 WASM + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 v128_t +#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F32x4_LOAD wasm_v128_load +#define GGML_F32x4_STORE wasm_v128_store +#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) +#define GGML_F32x4_ADD wasm_f32x4_add +#define GGML_F32x4_MUL wasm_f32x4_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 WASM + +#define GGML_F16_STEP 16 +#define GGML_F16_EPR 4 + +inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(p[0]); + tmp[1] = GGML_FP16_TO_FP32(p[1]); + tmp[2] = GGML_FP16_TO_FP32(p[2]); + tmp[3] = GGML_FP16_TO_FP32(p[3]); + + return wasm_v128_load(tmp); +} + +inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { + float tmp[4]; + + wasm_v128_store(tmp, x); + + p[0] = GGML_FP32_TO_FP16(tmp[0]); + p[1] = GGML_FP32_TO_FP16(tmp[1]); + p[2] = GGML_FP32_TO_FP16(tmp[2]); + p[3] = GGML_FP32_TO_FP16(tmp[3]); +} + +#define GGML_F16x4 v128_t +#define 
GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) +#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) +#define GGML_F16x4_FMA GGML_F32x4_FMA +#define GGML_F16x4_ADD wasm_f32x4_add +#define GGML_F16x4_MUL wasm_f32x4_mul +#define GGML_F16x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ + x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ + x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ + x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F16_VEC GGML_F16x4 +#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 +#define GGML_F16_VEC_LOAD GGML_F16x4_LOAD +#define GGML_F16_VEC_STORE GGML_F16x4_STORE +#define GGML_F16_VEC_FMA GGML_F16x4_FMA +#define GGML_F16_VEC_ADD GGML_F16x4_ADD +#define GGML_F16_VEC_MUL GGML_F16x4_MUL +#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE + +#endif + +// GGML_F32_ARR / GGML_F16_ARR +// number of registers to use per step +#ifdef GGML_SIMD +#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) +#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) +#endif + +// +// fundamental operations +// + +inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_add_f32 (const int n, float 
* z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } + +inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { + ggml_float sumf = 0.0; + +#ifdef GGML_SIMD + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + sumf += x[i]*y[i]; + } +#endif + + *s = sumf; +} + 
+inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { + ggml_float sumf = 0.0; + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR); + + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F16_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + } +#elif defined(__POWER9_VECTOR__) + // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without + // being able to test it. hoping someone with access to a POWER9 machine can help out here. + const int n32 = (n & ~31); + + vector float sum0 = vec_splats (0.0f); + + for (int i = 0; i < n32; i += 32) { + // Use vec_xl, not vec_ld, because x is sometimes unaligned. 
+ vector unsigned short x0 = vec_xl(i * 2 + 0, x); + vector unsigned short x1 = vec_xl(i * 2 + 16, x); + vector unsigned short x2 = vec_xl(i * 2 + 32, x); + vector unsigned short x3 = vec_xl(i * 2 + 48, x); + + vector unsigned short y0 = vec_xl(i * 2 + 0, y); + vector unsigned short y1 = vec_xl(i * 2 + 16, y); + vector unsigned short y2 = vec_xl(i * 2 + 32, y); + vector unsigned short y3 = vec_xl(i * 2 + 48, y); + + vector float fx0l = vec_extract_fp32_from_shortl(x0); + vector float fx0h = vec_extract_fp32_from_shorth(x0); + vector float fx1l = vec_extract_fp32_from_shortl(x1); + vector float fx1h = vec_extract_fp32_from_shorth(x1); + vector float fx2l = vec_extract_fp32_from_shortl(x2); + vector float fx2h = vec_extract_fp32_from_shorth(x2); + vector float fx3l = vec_extract_fp32_from_shortl(x3); + vector float fx3h = vec_extract_fp32_from_shorth(x3); + + vector float fy0l = vec_extract_fp32_from_shortl(y0); + vector float fy0h = vec_extract_fp32_from_shorth(y0); + vector float fy1l = vec_extract_fp32_from_shortl(y1); + vector float fy1h = vec_extract_fp32_from_shorth(y1); + vector float fy2l = vec_extract_fp32_from_shortl(y2); + vector float fy2h = vec_extract_fp32_from_shorth(y2); + vector float fy3l = vec_extract_fp32_from_shortl(y3); + vector float fy3h = vec_extract_fp32_from_shorth(y3); + + sum0 = vec_add(sum0, vec_mul(fx0l, fy0l)); + sum0 = vec_add(sum0, vec_mul(fx0h, fy0h)); + sum0 = vec_add(sum0, vec_mul(fx1l, fy1l)); + sum0 = vec_add(sum0, vec_mul(fx1h, fy1h)); + sum0 = vec_add(sum0, vec_mul(fx2l, fy2l)); + sum0 = vec_add(sum0, vec_mul(fx2h, fy2h)); + sum0 = vec_add(sum0, vec_mul(fx3l, fy3l)); + sum0 = vec_add(sum0, vec_mul(fx3h, fy3h)); + } + + sumf = vec_extract(sum0, 0) + vec_extract(sum0, 1) + + vec_extract(sum0, 2) + vec_extract(sum0, 3); + + for (int i = n32; i < n; ++i) { + sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + } +#else + for (int i = 0; i < n; ++i) { + sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + } +#endif + + 
*s = sumf; +} + +inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += x[i]*v; + } +#endif +} + +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + GGML_ASSERT(false); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#elif defined(__POWER9_VECTOR__) + // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without + // being able to test it. hoping someone with access to a POWER9 machine can help out here. + const int n32 = (n & ~31); + for (int i = 0; i < n32; i += 32) { + // Use vec_xl, not vec_ld, because x is sometimes unaligned! 
+ vector unsigned short x0 = vec_xl(i * 2 + 0, x); + vector unsigned short x1 = vec_xl(i * 2 + 16, x); + vector unsigned short x2 = vec_xl(i * 2 + 32, x); + vector unsigned short x3 = vec_xl(i * 2 + 48, x); + + vector unsigned short y0 = vec_xl(i * 2 + 0, y); + vector unsigned short y1 = vec_xl(i * 2 + 16, y); + vector unsigned short y2 = vec_xl(i * 2 + 32, y); + vector unsigned short y3 = vec_xl(i * 2 + 48, y); + + vector float v4 = vec_splats(v); + + vector float fx0l = vec_extract_fp32_from_shortl(x0); + vector float fx0h = vec_extract_fp32_from_shorth(x0); + vector float fx1l = vec_extract_fp32_from_shortl(x1); + vector float fx1h = vec_extract_fp32_from_shorth(x1); + vector float fx2l = vec_extract_fp32_from_shortl(x2); + vector float fx2h = vec_extract_fp32_from_shorth(x2); + vector float fx3l = vec_extract_fp32_from_shortl(x3); + vector float fx3h = vec_extract_fp32_from_shorth(x3); + + vector float fy0l = vec_extract_fp32_from_shortl(y0); + vector float fy0h = vec_extract_fp32_from_shorth(y0); + vector float fy1l = vec_extract_fp32_from_shortl(y1); + vector float fy1h = vec_extract_fp32_from_shorth(y1); + vector float fy2l = vec_extract_fp32_from_shortl(y2); + vector float fy2h = vec_extract_fp32_from_shorth(y2); + vector float fy3l = vec_extract_fp32_from_shortl(y3); + vector float fy3h = vec_extract_fp32_from_shorth(y3); + + fy0l = vec_madd(fx0l, v4, fy0l); + fy0h = vec_madd(fx0h, v4, fy0h); + fy1l = vec_madd(fx1l, v4, fy1l); + fy1h = vec_madd(fx1h, v4, fy1h); + fy2l = vec_madd(fx2l, v4, fy2l); + fy2h = vec_madd(fx2h, v4, fy2h); + fy3l = vec_madd(fx3l, v4, fy3l); + fy3h = vec_madd(fx3h, v4, fy3h); + + y0 = vec_pack_to_short_fp32(fy0h, fy0l); + y1 = vec_pack_to_short_fp32(fy1h, fy1l); + y2 = vec_pack_to_short_fp32(fy2h, fy2l); + y3 = vec_pack_to_short_fp32(fy3h, fy3l); + + vec_xst(y0, i * 2 + 0, y); + vec_xst(y1, i * 2 + 16, y); + vec_xst(y2, i * 2 + 32, y); + vec_xst(y3, i * 2 + 48, y); + } + + for (int i = n32; i < n; ++i) { + y[i] = 
GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#else + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + +//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } +inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] *= v; + } +#endif +} + +inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); } +inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } +inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); } +inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } +inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } +inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } +inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 
x[i] : 0.f; } + +static const ggml_float GELU_COEF_A = 0.044715; +static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876; + +inline static float ggml_gelu_f32(float x) { + return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); +} + +inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = table_gelu_f16[i16[i]]; + } +} + +#ifdef GGML_GELU_FP16 +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]); + } +} +#else +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]); + } +} +#endif + +inline static void ggml_vec_sum_f32 (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; } +inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); } + +// +// logging +// + +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#if (GGML_DEBUG >= 5) +#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_5(...) +#endif + +#if (GGML_DEBUG >= 10) +#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_10(...) +#endif + +#define GGML_PRINT(...) 
logDebug( __VA_ARGS__ ) + +// +// data types +// + +static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { + sizeof(int8_t ), + sizeof(int16_t), + sizeof(int32_t), + sizeof(ggml_fp16_t), + sizeof(float ), +}; + +static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { + "NONE", + + "DUP", + "ADD", + "SUB", + "MUL", + "DIV", + "SQR", + "SQRT", + "SUM", + "MEAN", + "REPEAT", + "ABS", + "SGN", + "NEG", + "STEP", + "RELU", + "GELU", + "NORM", + + "MUL_MAT", + + "SCALE", + "CPY", + "RESHAPE", + "VIEW", + "PERMUTE", + "TRANSPOSE", + "GET_ROWS", + "DIAG_MASK_INF", + "SOFT_MAX", + "ROPE", + "CONV_1D_1S", + "CONV_1D_2S", + + "FLASH_ATTN", + "FLASH_FF", +}; + +static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { + "none", + + "x", + "x+y", + "x-y", + "x*y", + "x/y", + "x^2", + "√x", + "Σx", + "Σx/n", + "repeat(x)", + "abs(x)", + "sgn(x)", + "-x", + "step(x)", + "relu(x)", + "gelu(x)", + "norm(x)", + + "X*Y", + + "x*v", + "x-\\>y", + "reshape(x)", + "view(x)", + "permute(x)", + "transpose(x)", + "get_rows(x)", + "diag_mask_inf(x)", + "soft_max(x)", + "rope(x)", + "conv_1d_1s(x)", + "conv_1d_2s(x)", + + "flash_attn(x)", + "flash_ff(x)", +}; + +// +// ggml object +// + +struct ggml_object { + size_t offset; + size_t size; + + struct ggml_object * next; + + char padding[8]; +}; + +static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); + +static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); +static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); + +// +// ggml context +// + +struct ggml_context { + size_t mem_size; + void * mem_buffer; + bool mem_buffer_owned; + + int n_objects; + + struct ggml_object * objects_begin; + struct ggml_object * objects_end; +}; + +struct ggml_context_container { + bool used; + + struct ggml_context context; +}; + +// +// compute types +// + +enum ggml_task_type { + GGML_TASK_INIT = 0, + GGML_TASK_COMPUTE, + 
GGML_TASK_FINALIZE, +}; + +struct ggml_compute_params { + enum ggml_task_type type; + + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; +}; + +// +// ggml state +// + +struct ggml_state { + struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; +}; + +// global state +static struct ggml_state g_state; +static atomic_int g_state_barrier = 0; + +// barrier via spin lock +inline static void ggml_critical_section_start() { + int processing = atomic_fetch_add(&g_state_barrier, 1); + + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); // TODO: reconsider this + processing = atomic_fetch_add(&g_state_barrier, 1); + } +} + +// TODO: make this somehow automatically executed +// some sort of "sentry" mechanism +inline static void ggml_critical_section_end() { + atomic_fetch_sub(&g_state_barrier, 1); +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_print_object(const struct ggml_object * obj) { + GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n", + obj->offset, obj->size, (const void *) obj->next); +} + +void ggml_print_objects(const struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); + + while (obj != NULL) { + ggml_print_object(obj); + obj = obj->next; + } + + GGML_PRINT("%s: --- end ---\n", __func__); +} + +int ggml_nelements(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +int ggml_nrows(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +size_t ggml_nbytes(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS 
== 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type]; +} + +size_t ggml_type_size(enum ggml_type type) { + return GGML_TYPE_SIZE[type]; +} + +size_t ggml_element_size(const struct ggml_tensor * tensor) { + return GGML_TYPE_SIZE[tensor->type]; +} + +bool ggml_is_scalar(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +bool ggml_is_vector(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +bool ggml_is_matrix(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + +bool ggml_is_contiguous(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && + tensor->nb[1] == tensor->nb[0]*tensor->ne[0] && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +bool ggml_are_same_shape(const struct ggml_tensor * t0, 
const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0] ) && + (t0->ne[1] == t1->ne[1] ) && + (t0->ne[2] == t1->ne[2] ) && + (t0->ne[3] == t1->ne[3] ); +} + +// check if t1 can be represented as a repeatition of t0 +bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t1->ne[0]%t0->ne[0] == 0) && + (t1->ne[1]%t0->ne[1] == 0) && + (t1->ne[2]%t0->ne[2] == 0) && + (t1->ne[3]%t0->ne[3] == 0); +} + +int ggml_up32(int n) { + return (n + 31) & ~31; +} + +int ggml_up64(int n) { + return (n + 63) & ~63; +} + +// assert that pointer is aligned to GGML_MEM_ALIGN +#define ggml_assert_aligned(ptr) \ + assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) + +//////////////////////////////////////////////////////////////////////////////// + +struct ggml_context * ggml_init(struct ggml_init_params params) { + // make this function thread safe + ggml_critical_section_start(); + + static bool is_first_call = true; + + if (is_first_call) { + // initialize GELU and EXP tables + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + ggml_fp16_t ii; + for (int i = 0; i < (1 << 16); ++i) { + uint16_t ui = i; + memcpy(&ii, &ui, sizeof(ii)); + const float f = GGML_FP16_TO_FP32(ii); + table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); + table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f)); + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + + // initialize g_state + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + g_state = (struct ggml_state) { + /*.contexts =*/ { 0 }, + }; + + for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { + g_state.contexts[i].used = false; + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); 
+ + GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + + is_first_call = false; + } + + // find non-used context in g_state + struct ggml_context * ctx = NULL; + + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { + if (!g_state.contexts[i].used) { + g_state.contexts[i].used = true; + ctx = &g_state.contexts[i].context; + + GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); + break; + } + } + + if (ctx == NULL) { + GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); + + ggml_critical_section_end(); + + return NULL; + } + + *ctx = (struct ggml_context) { + .mem_size = params.mem_size, + .mem_buffer = params.mem_buffer ? params.mem_buffer : malloc(params.mem_size), + .mem_buffer_owned = params.mem_buffer ? false : true, + .n_objects = 0, + .objects_begin = NULL, + .objects_end = NULL, + }; + + ggml_assert_aligned(ctx->mem_buffer); + + GGML_PRINT_DEBUG("%s: context initialized\n", __func__); + + ggml_critical_section_end(); + + return ctx; +} + +void ggml_free(struct ggml_context * ctx) { + // make this function thread safe + ggml_critical_section_start(); + + bool found = false; + + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { + if (&g_state.contexts[i].context == ctx) { + g_state.contexts[i].used = false; + + GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. 
memory used = %zu\n", + __func__, i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size); + + if (ctx->mem_buffer_owned) { + free(ctx->mem_buffer); + } + + found = true; + break; + } + } + + if (!found) { + GGML_PRINT_DEBUG("%s: context not found\n", __func__); + } + + ggml_critical_section_end(); +} + +size_t ggml_used_mem(const struct ggml_context * ctx) { + return ctx->objects_end->offset + ctx->objects_end->size; +} + +//////////////////////////////////////////////////////////////////////////////// + +struct ggml_tensor * ggml_new_tensor_impl( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int* ne, + void* data) { + // always insert objects at the end of the context's memory pool + struct ggml_object * obj_cur = ctx->objects_end; + + const size_t cur_offset = obj_cur == NULL ? 0 : obj_cur->offset; + const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; + const size_t cur_end = cur_offset + cur_size; + + size_t size_needed = 0; + + if (data == NULL) { + size_needed += GGML_TYPE_SIZE[type]; + for (int i = 0; i < n_dims; i++) { + size_needed *= ne[i]; + } + // align to GGML_MEM_ALIGN + size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN; + + } + size_needed += sizeof(struct ggml_tensor); + + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__); + assert(false); + return NULL; + } + + char * const mem_buffer = ctx->mem_buffer; + + struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); + + *obj_new = (struct ggml_object) { + .offset = cur_end + GGML_OBJECT_SIZE, + .size = size_needed, + .next = NULL, + }; + + if (obj_cur != NULL) { + obj_cur->next = obj_new; + } else { + // this is the first object in this context + ctx->objects_begin = obj_new; + } + + ctx->objects_end = obj_new; + + //GGML_PRINT_DEBUG("%s: inserted new object at %zu\n", __func__, cur_end); + + struct 
ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offset); + + ggml_assert_aligned(result); + + *result = (struct ggml_tensor) { + /*.type =*/ type, + /*.n_dims =*/ n_dims, + /*.ne =*/ { 1, 1, 1, 1 }, + /*.nb =*/ { 0, 0, 0, 0 }, + /*.op =*/ GGML_OP_NONE, + /*.is_param =*/ false, + /*.grad =*/ NULL, + /*.src0 =*/ NULL, + /*.src1 =*/ NULL, + /*.opt =*/ { NULL }, + /*.n_tasks =*/ 0, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + /*.data =*/ data == NULL ? (void *)(result + 1) : data, + /*.pad =*/ { 0 }, + }; + + ggml_assert_aligned(result->data); + + for (int i = 0; i < n_dims; i++) { + result->ne[i] = ne[i]; + } + + result->nb[0] = GGML_TYPE_SIZE[type]; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; + } + + ctx->n_objects++; + + return result; +} + +struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int* ne) { + return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL); +} + +struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0) { + return ggml_new_tensor(ctx, type, 1, &ne0); +} + +struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1) { + const int ne[2] = { ne0, ne1 }; + return ggml_new_tensor(ctx, type, 2, ne); +} + +struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2) { + const int ne[3] = { ne0, ne1, ne2 }; + return ggml_new_tensor(ctx, type, 3, ne); +} + +struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2, + int ne3) { + const int ne[4] = { ne0, ne1, ne2, ne3 }; + return ggml_new_tensor(ctx, type, 4, ne); +} + +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, 
GGML_TYPE_I32, 1); + + ggml_set_i32(result, value); + + return result; +} + +struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + + ggml_set_f32(result, value); + + return result; +} + +struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) { + return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL); +} + +struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { + memset(tensor->data, 0, ggml_nbytes(tensor)); + return tensor; +} + +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } + + return tensor; +} + +struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * 
const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } + + return tensor; +}
// NOTE: in the ggml_set_f32 tail above, the F16 branch now converts the fill value with
// GGML_FP32_TO_FP16 (as ggml_set_f32_1d does) instead of passing the raw float through.

// Reads element i of tensor as a 32-bit integer.
//
// tensor->data is indexed as a flat array of the tensor's element type; the elements
// must be densely packed (nb[0] == element size), which is checked with GGML_ASSERT.
// F16 and F32 elements are converted to int32_t by the usual C float-to-integer
// conversion. Falls through to 0 only for GGML_TYPE_COUNT (after the failed assertion)
// or a type value not listed in the switch.
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
    switch (tensor->type) {
        case GGML_TYPE_I8:
            GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
            return ((int8_t *)(tensor->data))[i];

        case GGML_TYPE_I16:
            GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
            return ((int16_t *)(tensor->data))[i];

        case GGML_TYPE_I32:
            GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
            return ((int32_t *)(tensor->data))[i];

        case GGML_TYPE_F16:
            GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
            return (int32_t) GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);

        case GGML_TYPE_F32:
            GGML_ASSERT(tensor->nb[0] == sizeof(float));
            return (int32_t) ((float *)(tensor->data))[i];

        case GGML_TYPE_COUNT:
            GGML_ASSERT(false);
            break;
    }

    return 0;
}

void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + 
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +}

// Reads element i of tensor as a float.
//
// tensor->data is indexed as a flat array of the tensor's element type; the elements
// must be densely packed (nb[0] == element size), which is checked with GGML_ASSERT.
// Integer elements are converted to float, F16 elements go through GGML_FP16_TO_FP32.
// Falls through to 0.0f only for GGML_TYPE_COUNT (after the failed assertion) or a
// type value not listed in the switch.
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
    switch (tensor->type) {
        case GGML_TYPE_I8:
            GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
            return (float) ((int8_t *)(tensor->data))[i];

        case GGML_TYPE_I16:
            GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
            return (float) ((int16_t *)(tensor->data))[i];

        case GGML_TYPE_I32:
            GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
            return (float) ((int32_t *)(tensor->data))[i];

        case GGML_TYPE_F16:
            GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
            return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);

        case GGML_TYPE_F32:
            GGML_ASSERT(tensor->nb[0] == sizeof(float));
            return ((float *)(tensor->data))[i];

        case GGML_TYPE_COUNT:
            GGML_ASSERT(false);
            break;
    }

    return 0.0f;
}

void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t 
*)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +void * ggml_get_data(const struct ggml_tensor * tensor) { + return tensor->data; +} + +float * ggml_get_data_f32(const struct ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + return (float *)(tensor->data); +} + +struct ggml_tensor * ggml_view_tensor( + struct ggml_context * ctx, + const struct ggml_tensor * src) { + return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data); +} + +//////////////////////////////////////////////////////////////////////////////// + +// ggml_dup + +struct ggml_tensor * ggml_dup_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DUP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, true); +} + +// ggml_add + +struct ggml_tensor * ggml_add_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, true); +} + +// ggml_sub + +struct ggml_tensor * ggml_sub_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SUB; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, true); +} + +// ggml_mul + +struct ggml_tensor * ggml_mul_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + assert(is_node == false); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_MUL; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, true); +} + +// ggml_div + +struct ggml_tensor * ggml_div_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + assert(is_node == false); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DIV; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, true); +} + +// ggml_sqr + +struct ggml_tensor * ggml_sqr_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQR; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, true); +} + +// ggml_sqrt + +struct ggml_tensor * ggml_sqrt_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQRT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, true); +} + +// ggml_sum + +struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_SUM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_mean + +struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement + is_node = true; + } + + int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne); + + result->op = GGML_OP_MEAN; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_repeat + +struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_can_repeat(a, b)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (ggml_are_same_shape(a, b) && !is_node) { + return a; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + + result->op = GGML_OP_REPEAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_abs + +struct ggml_tensor * ggml_abs_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ABS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_abs_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_abs_impl(ctx, a, true); +} + + +// ggml_sgn + +struct ggml_tensor * ggml_sgn_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SGN; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sgn_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sgn_impl(ctx, a, true); +} + +// ggml_neg + +struct ggml_tensor * ggml_neg_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_NEG; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_neg_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_neg_impl(ctx, a, true); +} + +// ggml_step + +struct ggml_tensor * ggml_step_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_STEP; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_step_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_step_impl(ctx, a, true); +} + +// ggml_relu + +struct ggml_tensor * ggml_relu_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_RELU; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_relu_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_relu_impl(ctx, a, true); +} + +// ggml_gelu + +struct ggml_tensor * ggml_gelu_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_GELU; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_gelu_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_gelu_impl(ctx, a, true); +} + +// ggml_norm + +struct ggml_tensor * ggml_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_norm_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_norm_impl(ctx, a, true); +} + +// ggml_mul_mat + +struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_can_mul_mat(a, b)); + + // printUniqueTensorSize( "ggml_mul_mat", a->ne, b->ne ); + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + + result->op = GGML_OP_MUL_MAT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_scale + +struct ggml_tensor * ggml_scale_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_is_scalar(b)); + assert(ggml_is_padded_1d(a)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_SCALE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_scale_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_scale_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_scale_impl(ctx, a, b, true); +} + +// ggml_cpy + +struct ggml_tensor * ggml_cpy_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_nelements(a) == ggml_nelements(b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + assert(false); // TODO: implement backward + is_node = true; + } + + // make a view of the destination + struct ggml_tensor * result = ggml_view_tensor(ctx, b); + + result->op = GGML_OP_CPY; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_cpy_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b, true); +} + +// ggml_reshape + +struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_contiguous(a)); + assert(ggml_is_contiguous(b)); + assert(ggml_nelements(a) == ggml_nelements(b)); + + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1) { + assert(ggml_is_contiguous(a)); + assert(ggml_nelements(a) == ne0*ne1); + + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[2] = { ne0, ne1 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2) { + assert(ggml_is_contiguous(a)); + assert(ggml_nelements(a) == ne0*ne1*ne2); + + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[3] = { ne0, ne1, ne2 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_view_1d + +struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + size_t offset) { + if (a->grad) { + assert(false); // gradient propagation is not supported + } + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); + + result->op = GGML_OP_VIEW; + result->grad = NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store the offset here? + + return result; +} + +// ggml_view_2d + +struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + size_t nb1, + size_t offset) { + if (a->grad) { + assert(false); // gradient propagation is not supported + } + + const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 }; + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset); + + result->nb[1] = nb1; + result->nb[2] = result->nb[1]*ne1; + result->nb[3] = result->nb[2]; + + result->op = GGML_OP_VIEW; + result->grad = NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store the offset here? 
+ + return result; +} + +// ggml_permute + +struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3) { + assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS); + assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS); + assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS); + assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS); + + assert(axis0 != axis1); + assert(axis0 != axis2); + assert(axis0 != axis3); + assert(axis1 != axis2); + assert(axis1 != axis3); + assert(axis2 != axis3); + + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + int ne[GGML_MAX_DIMS]; + int nb[GGML_MAX_DIMS]; + + ne[axis0] = a->ne[0]; + ne[axis1] = a->ne[1]; + ne[axis2] = a->ne[2]; + ne[axis3] = a->ne[3]; + + nb[axis0] = a->nb[0]; + nb[axis1] = a->nb[1]; + nb[axis2] = a->nb[2]; + nb[axis3] = a->nb[3]; + + result->ne[0] = ne[0]; + result->ne[1] = ne[1]; + result->ne[2] = ne[2]; + result->ne[3] = ne[3]; + + result->nb[0] = nb[0]; + result->nb[1] = nb[1]; + result->nb[2] = nb[2]; + result->nb[3] = nb[3]; + + result->op = GGML_OP_PERMUTE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store the permutation here? + + return result; +} + +// ggml_transpose + +struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->ne[0] = a->ne[1]; + result->ne[1] = a->ne[0]; + + result->nb[0] = a->nb[1]; + result->nb[1] = a->nb[0]; + + result->op = GGML_OP_TRANSPOSE; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_get_rows + +struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); + + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: implement non F32 return + //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]); + + result->op = GGML_OP_GET_ROWS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_diag_mask_inf + +struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + ((int32_t *) b->data)[0] = n_past; + + result->op = GGML_OP_DIAG_MASK_INF; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_soft_max + +struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_SOFT_MAX; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_rope + +struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode) { + assert(n_past >= 0); + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + ((int32_t *) b->data)[0] = n_past; + ((int32_t *) b->data)[1] = n_dims; + ((int32_t *) b->data)[2] = mode; + + result->op = GGML_OP_ROPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_conv_1d_1s + +struct ggml_tensor * ggml_conv_1d_1s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_matrix(b)); + assert(a->ne[1] == b->ne[1]); + assert(a->ne[3] == 1); + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[4] = { b->ne[0], a->ne[2], 1, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + + result->op = GGML_OP_CONV_1D_1S; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_conv_1d_2s + +struct ggml_tensor * ggml_conv_1d_2s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_matrix(b)); + assert(a->ne[1] == b->ne[1]); + assert(a->ne[3] == 1); + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + + result->op = GGML_OP_CONV_1D_2S; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_flash_attn + +struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked) { + assert(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne); + + result->op = GGML_OP_FLASH_ATTN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = q; + result->src1 = k; + result->opt[0] = v; + result->opt[1] = ggml_new_i32(ctx, masked ? 
1 : 0); + + return result; +} + +// ggml_flash_ff + +struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1) { + assert(ggml_can_mul_mat(b0, a)); + // TODO: more checks + + bool is_node = false; + + if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne); + + result->op = GGML_OP_FLASH_FF; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b0; + result->opt[0] = b1; + result->opt[1] = c0; + result->opt[2] = c1; + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor) { + tensor->is_param = true; + + assert(tensor->grad == NULL); + tensor->grad = ggml_dup_tensor(ctx, tensor); +} + +// ggml_compute_forward_dup + +static void ggml_compute_forward_dup_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_contiguous(dst)); + assert(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + if (ggml_is_contiguous(src0) && src0->type == dst->type) { + memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); + return; + } + + if (src0->nb[0] == 
sizeof(ggml_fp16_t)) { + if (dst->type == GGML_TYPE_F16) { + int id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + int id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } +} + +static void ggml_compute_forward_dup_f32( + const 
struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + if (ggml_is_contiguous(src0) && src0->type == dst->type) { + memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); + return; + } + + if (src0->nb[0] == sizeof(float)) { + if (dst->type == GGML_TYPE_F32) { + int id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + int id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int 
i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + int id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } +} + +static void ggml_compute_forward_dup( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_dup_f16(params, src0, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_dup_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_add + +static void ggml_compute_forward_add_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb10 = src1->nb[0]; + const size_t nb11 = src1->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == 
sizeof(float)); + + if (nb10 == sizeof(float)) { + const int j0 = (n/nth)*ith; + const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1); + + for (int j = j0; j < j1; j++) { + ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + j*nb1), + (float *) ((char *) src0->data + j*nb01), + (float *) ((char *) src1->data + j*nb11)); + } + } else { + // src1 is not contiguous + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + for (int i = 0; i < nc; i++) { + float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10); + + dst_ptr[i] = src0_ptr[i] + *src1_ptr; + } + } + } +} + +static void ggml_compute_forward_add( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sub + +static void ggml_compute_forward_sub_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sub_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void 
ggml_compute_forward_sub( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sub_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_mul + +static void ggml_compute_forward_mul_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_mul_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void ggml_compute_forward_mul( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mul_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_div + +static void ggml_compute_forward_div_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + 
assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_div_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void ggml_compute_forward_div( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_div_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sqr + +static void ggml_compute_forward_sqr_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqr_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqr( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqr_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case 
GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sqrt + +static void ggml_compute_forward_sqrt_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqrt_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqrt( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqrt_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sum + +static void ggml_compute_forward_sum_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_scalar(dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(float)); + + *(float *) (dst->data) = 0.0f; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 
< ne01; i01++) { + ggml_vec_sum_f32(ne00, + (float *) (dst->data), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + } + } + } +} + +static void ggml_compute_forward_sum( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sum_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_mean + +static void ggml_compute_forward_mean_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + + assert(ne0 == 1); + assert(ne1 == ne01); + assert(ne2 == ne02); + assert(ne3 == ne03); + + UNUSED(ne0); + UNUSED(ne1); + UNUSED(ne2); + UNUSED(ne3); + + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) = 0.0f; + + ggml_vec_sum_f32(ne00, + (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + + *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; + } + } + } +} + +static void 
ggml_compute_forward_mean( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mean_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_repeat + +static void ggml_compute_forward_repeat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_can_repeat(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: implement support for rank > 2 tensors + assert(src0->ne[2] == 1); + assert(src0->ne[3] == 1); + assert( dst->ne[2] == 1); + assert( dst->ne[3] == 1); + + const int nc = dst->ne[0]; + const int nr = dst->ne[1]; + const int nc0 = src0->ne[0]; + const int nr0 = src0->ne[1]; + const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat + const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat + + // TODO: support for transposed / permuted tensors + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + // TODO: maybe this is not optimal? 
+ for (int i = 0; i < nrr; i++) { + for (int j = 0; j < ncr; j++) { + for (int k = 0; k < nr0; k++) { + ggml_vec_cpy_f32(nc0, + (float *) ((char *) dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])), + (float *) ((char *) src0->data + ( k)*(src0->nb[1]))); + } + } + } +} + +static void ggml_compute_forward_repeat( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_repeat_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_abs + +static void ggml_compute_forward_abs_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_abs_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_abs( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_abs_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sgn + +static void ggml_compute_forward_sgn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + 
assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sgn_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sgn( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sgn_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_neg + +static void ggml_compute_forward_neg_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_neg_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_neg( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_neg_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_step + +static void 
ggml_compute_forward_step_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_step_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_step( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_step_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_relu + +static void ggml_compute_forward_relu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_relu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_relu_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + 
case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_gelu + +static void ggml_compute_forward_gelu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_norm + +static void ggml_compute_forward_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const 
int ith = params->ith; + const int nth = params->nth; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const ggml_float eps = 1e-5f; // TODO: make this a parameter + + // TODO: optimize + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float mean = 0.0; + for (int i00 = 0; i00 < ne00; i00++) { + mean += x[i00]; + } + + mean /= ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + ggml_float sum2 = 0.0; + for (int i00 = 0; i00 < ne00; i00++) { + ggml_float v = x[i00] - mean; + y[i00] = v; + sum2 += v*v; + } + + const float scale = 1.0/sqrt(sum2/ne00 + eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_norm_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_mul_mat + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) +// helper function to determine if it is better to use BLAS or not +// for large matrices, BLAS is faster +static bool ggml_compute_forward_mul_mat_use_blas( + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + UNUSED(src0); + + const int ne10 = src1->ne[0]; + + const int ne0 = dst->ne[0]; + const int ne1 = 
dst->ne[1]; + + // TODO: find the optimal values for these + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) { + //printf("BLAS: %d %d %d\n", ne0, ne1, ne10); + return true; + } + + return false; +} +#endif + +static void ggml_compute_forward_mul_mat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + assert(ne02 == ne12); + assert(ne03 == ne13); + assert(ne2 == ne12); + assert(ne3 == ne13); + + // TODO: we don't support permuted src0 + assert(nb00 == sizeof(float) || nb01 == sizeof(float)); + + // dst cannot be transposed or permuted + assert(nb0 == sizeof(float)); + assert(nb0 <= nb1); + assert(nb1 <= nb2); + assert(nb2 <= nb3); + + assert(ne0 == ne01); + assert(ne1 == ne11); + assert(ne2 == ne02); + assert(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + // + // nb00 < nb01 - src0 is transposed + // compute by src0 columns + +#if defined(GGML_USE_ACCELERATE) || 
defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) return; + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + const float * x = (float *) (src0->data); + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // zT = y * xT + { + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); + } + } + } + + //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == GGML_TASK_INIT) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); + + float * const wdata = params->wdata; + + // cols per thread + const int dc = (ne + nth - 1)/nth; + + // col range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, ne); + + ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); + + for (int k = 1; k < nth; k++) { + ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); + } + + return; + } + + if (nb01 >= nb00) { + // TODO: do not support transposed src1 + assert(nb10 == sizeof(float)); + + // parallelize by src0 rows using ggml_vec_dot_f32 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + 
const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + for (int ic = 0; ic < ne11; ++ic) { + // src1 indices + const int i13 = i03; + const int i12 = i02; + const int i11 = ic; + + // dst indices + const int i0 = i01; + const int i1 = i11; + const int i2 = i02; + const int i3 = i03; + + ggml_vec_dot_f32(ne00, + (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), + (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); + } + } + } else { + // parallelize by src1 columns using ggml_vec_mad_f32 + // each thread has its own work data + // during FINALIZE we accumulate all work data into dst + + // total columns in src1 + const int nc = ne10; + + // columns per thread + const int dc = (nc + nth - 1)/nth; + + // column range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, nc); + + // work data for thread + const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; + float * const wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + for (int ic = ic0; ic < ic1; ++ic) { + // src1 indices + const int i10 = ic; + + // src0 indices + const int i03 = i13; + const int i02 = i12; + const int i00 = ic; + + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + + ggml_vec_mad_f32(ne01, + (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0), + (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)), + *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13))); + } + } + } + } + } + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc 
= 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_mul_mat_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // TODO: we don't support permuted src0 + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + 
GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + // + // nb00 < nb01 - src0 is transposed + // compute by src0 columns + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) return; + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + float * const wdata = params->wdata; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + { + int id = 0; + for (int i01 = 0; i01 < ne01; ++i01) { + for (int i00 = 0; i00 < ne00; ++i00) { + wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + } + } + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + // float * z = wdata + ne00*ne01; + + // z = x * yT + //{ + // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + // ne01, ne11, ne00, + // 1.0f, x, ne00, + // y, ne00, + // 0.0f, z, ne11); + //} + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // transpose z + //for (int j = 0; j < ne11; ++j) { + // for (int i = 0; i < ne01; ++i) { + // d[j*ne01 + i] = z[i*ne11 + j]; + // } + //} + + { +#if 1 + // zT = y * xT + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne00, + x, ne00, + 0.0f, d, ne01); +#else + // zT = (xT * y)T + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, + ne01, ne11, ne10, + 1.0f, x, ne00, + y, ne00, + 0.0f, d, ne01); +#endif + } + } + } + + //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type 
== GGML_TASK_INIT) { + if (nb01 >= nb00) { + ggml_fp16_t * const wdata = params->wdata; + + int id = 0; + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + for (int i10 = 0; i10 < ne10; ++i10) { + wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + } + } + } + } + + GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); + + return; + } + + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); + + ggml_fp16_t * const wdata = params->wdata; + + // cols per thread + const int dc = (ne + nth - 1)/nth; + + // col range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, ne); + + for (int i = ic0; i < ic1; ++i) { + ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]); + } + + for (int k = 1; k < nth; k++) { + for (int i = ic0; i < ic1; ++i) { + ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]); + } + } + + return; + } + + if (nb01 >= nb00) { + // fp16 -> half the size, so divide by 2 + // TODO: do not support transposed src1 + assert(nb10/2 == sizeof(ggml_fp16_t)); + + // parallelize by src0 rows using ggml_vec_dot_f32 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + ggml_fp16_t * wdata = params->wdata; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int i13 = i03; + const int i12 = i02; + + const int i0 = i01; 
+ const int i2 = i02; + const int i3 = i03; + + ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + ggml_fp16_t * src1_col = wdata + (i13*ne12*ne11 + i12*ne11 + 0)*ne00; + + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + + for (int ic = 0; ic < ne11; ++ic) { + assert(ne00 % 32 == 0); + + ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); + } + } + } else { + // parallelize by src1 columns using ggml_vec_mad_f32 + // each thread has its own work data + // during FINALIZE we accumulate all work data into dst + + // total columns in src1 + const int nc = ne10; + + // columns per thread + const int dc = (nc + nth - 1)/nth; + + // column range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, nc); + + // work data for thread + const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; + ggml_fp16_t * const wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + + for (int ic = ic0; ic < ic1; ++ic) { + // src1 indices + const int i10 = ic; + + // src0 indices + const int i03 = i13; + const int i02 = i12; + const int i00 = ic; + + assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + + ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); + float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + + ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val); + } + } + } + } + } + + //int64_t t1 = ggml_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, 
ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_mul_mat( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_mul_mat_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_scale + +static void ggml_compute_forward_scale_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scale factor + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), v); + } +} + +static void ggml_compute_forward_scale( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + 
struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_scale_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_cpy + +static void ggml_compute_forward_cpy( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, src0, dst); +} + +// ggml_compute_forward_reshape + +static void ggml_compute_forward_reshape( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(src0); + UNUSED(dst); +} + +// ggml_compute_forward_view + +static void ggml_compute_forward_view( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_permute + +static void ggml_compute_forward_permute( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_transpose + +static void ggml_compute_forward_transpose( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_get_rows + +static void ggml_compute_forward_get_rows_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) 
src1->data)[i]; + + for (int j = 0; j < nc; ++j) { + ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; + ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); + } + } +} + +static void ggml_compute_forward_get_rows_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i*dst->nb[1]), + (float *) ((char *) src0->data + r*src0->nb[1])); + } +} + +static void ggml_compute_forward_get_rows( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rows_f16(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_get_rows_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_diag_mask_inf + +static void ggml_compute_forward_diag_mask_inf_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 1); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + + // TODO: handle transposed/permuted matrices 
+ + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + const int nr = src0->ne[1]; + const int nz = n/nr; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int k = 0; k < nz; k++) { + for (int j = 0; j < nr; j++) { + for (int i = n_past; i < nc; i++) { + if (i > n_past + j) { + *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = -INFINITY; + } + } + } + } +} + +static void ggml_compute_forward_diag_mask_inf( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_soft_max + +static void ggml_compute_forward_soft_max_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *p = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(p[i])); + } +#endif + + float max = -INFINITY; + for (int i = 0; i < nc; i++) { + max = MAX(max, p[i]); + } + + ggml_float sum = 0.0; + + uint16_t ss; + for (int 
i = 0; i < nc; i++) { + if (p[i] == -INFINITY) { + p[i] = 0.0; + } else { + //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max); + memcpy(&ss, &s, sizeof(ss)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]); + sum += val; + p[i] = val; + } + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(nc, p, sum); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(p[i])); + assert(!isinf(p[i])); + } +#endif + } +} + +static void ggml_compute_forward_soft_max( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_soft_max_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_rope + +static void ggml_compute_forward_rope_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 3); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + //const int ne0 = src0->ne[0]; + const int ne1 = src0->ne[1]; + const int ne2 = src0->ne[2]; + const int ne3 = src0->ne[3]; + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + const int nb3 = src0->nb[3]; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + assert(nb0 == sizeof(float)); + + // TODO: optimize + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = (mode 
== 0 ? 0 : n_past); i2 < ne2; i2++) { + const int p = (mode == 0 ? n_past + i2 : i2); + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < n_dims; i0 += 2) { + const double theta = pow(10000.0, ((double)-i0)/n_dims); + + const double cos_theta = cos(p*theta); + const double sin_theta = sin(p*theta); + + const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + double x0 = src[0]; + double x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } + } + } + } +} + +static void ggml_compute_forward_rope( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rope_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_conv_1d_1s + +static void ggml_compute_forward_conv_1d_1s_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + 
const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + ggml_fp16_t * dst_data = wdata; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int i0 = 0; 
i0 < ne10; ++i0) { + dst_data[i0] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f16(ew0, &v, + (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_1s_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + float * const wdata = (float *) 
params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int i0 = 0; i0 < ne10; ++i0) { + dst_data[i0] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f32(ew0, &v, + (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_1s( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_conv_1d_2s + +static void ggml_compute_forward_conv_1d_2s_f16_f32( + 
const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + 
} + } + } + } + + // prepare source data (src1) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + ggml_fp16_t * dst_data = wdata; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int i0 = 0; i0 < ne10; i0 += 2) { + dst_data[i0/2] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f16(ew0, &v, + (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0/2] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_2s_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = 
src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + float * const wdata = (float *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int i0 = 0; i0 < ne10; i0 += 2) { + dst_data[i0/2] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f32(ew0, 
&v, + (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0/2] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_2s( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_flash_attn + +static void ggml_compute_forward_flash_attn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int neq0 = q->ne[0]; + const int neq1 = q->ne[1]; + const int neq2 = q->ne[2]; + const int neq3 = q->ne[3]; + + const int nek0 = k->ne[0]; + const int nek1 = k->ne[1]; + //const int nek2 = k->ne[2]; + //const int nek3 = k->ne[3]; + + //const int nev0 = v->ne[0]; + const int nev1 = v->ne[1]; + //const int nev2 = v->ne[2]; + //const int nev3 = v->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nbk0 = k->nb[0]; + const int nbk1 = k->nb[1]; + const int nbk2 = k->nb[2]; + const int nbk3 = k->nb[3]; + + const int nbq0 = q->nb[0]; + const int nbq1 = q->nb[1]; + const int nbq2 = q->nb[2]; + const int nbq3 = q->nb[3]; + + const int nbv0 = v->nb[0]; + const int nbv1 = v->nb[1]; + const int nbv2 = v->nb[2]; + const int nbv3 = v->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int 
nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = neq0; + const int N = neq1; + const int P = nek1 - N; + const int M = P + N; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(float)); + GGML_ASSERT(nbv0 == sizeof(float)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0/sqrt((double) D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(M + CACHE_LINE_SIZE_F32); + + for (int ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; 
+ for (int i = 0; i < M; i++) { + max = MAX(max, S[i]); + } + + ggml_float sum = 0.0; + + uint16_t ss; + for (int i = 0; i < M; i++) { + if (S[i] == -INFINITY) { + S[i] = 0.0; + } else { + //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max); + memcpy(&ss, &s, sizeof(ss)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]); + sum += val; + S[i] = val; + } + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + } + + for (int ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f32(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S); + } + } +} + +static void ggml_compute_forward_flash_attn_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int neq0 = q->ne[0]; + const int neq1 = q->ne[1]; + const int neq2 = q->ne[2]; + const int neq3 = q->ne[3]; + + const int nek0 = k->ne[0]; + const int nek1 = k->ne[1]; + //const int nek2 = k->ne[2]; + //const int nek3 = k->ne[3]; + + //const int nev0 = v->ne[0]; + const int nev1 = v->ne[1]; + //const int nev2 = v->ne[2]; + //const int nev3 = v->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nbk0 = k->nb[0]; + const int nbk1 = k->nb[1]; + const int nbk2 = k->nb[2]; + const int nbk3 = k->nb[3]; + + const int nbq0 = q->nb[0]; + const int nbq1 = q->nb[1]; + const int nbq2 = q->nb[2]; + const int nbq3 = q->nb[3]; + + const int nbv0 = v->nb[0]; + const int nbv1 = v->nb[1]; + const int nbv2 = v->nb[2]; + const int nbv3 = v->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = 
dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = neq0; + const int N = neq1; + const int P = nek1 - N; + const int M = P + N; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0/sqrt((double) D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); + + for (int ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int i = P; i < M; i++) { + if (i > 
P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + for (int i = 0; i < M; i++) { + max = MAX(max, S[i]); + } + + ggml_float sum = 0.0; + + uint16_t ss; + for (int i = 0; i < M; i++) { + if (S[i] == -INFINITY) { + S[i] = 0.0; + } else { + //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max); + memcpy(&ss, &s, sizeof(ss)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]); + sum += val; + S[i] = val; + } + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); + + for (int i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + for (int ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f16(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } +} + +static void ggml_compute_forward_flash_attn( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_flash_ff + +static void ggml_compute_forward_flash_ff_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, // F16 + const struct ggml_tensor * b0, // F16 fc_w + const struct ggml_tensor * b1, // F32 fc_b + const struct ggml_tensor * c0, // F16 proj_w + 
const struct ggml_tensor * c1, // F32 proj_b + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int nea0 = a->ne[0]; + const int nea1 = a->ne[1]; + const int nea2 = a->ne[2]; + const int nea3 = a->ne[3]; + + const int neb00 = b0->ne[0]; + const int neb01 = b0->ne[1]; + //const int neb02 = b0->ne[2]; + //const int neb03 = b0->ne[3]; + + const int neb10 = b1->ne[0]; + const int neb11 = b1->ne[1]; + //const int neb12 = b1->ne[2]; + //const int neb13 = b1->ne[3]; + + const int nec00 = c0->ne[0]; + const int nec01 = c0->ne[1]; + //const int nec02 = c0->ne[2]; + //const int nec03 = c0->ne[3]; + + const int nec10 = c1->ne[0]; + const int nec11 = c1->ne[1]; + //const int nec12 = c1->ne[2]; + //const int nec13 = c1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nba0 = a->nb[0]; + const int nba1 = a->nb[1]; + const int nba2 = a->nb[2]; + const int nba3 = a->nb[3]; + + const int nbb00 = b0->nb[0]; + const int nbb01 = b0->nb[1]; + const int nbb02 = b0->nb[2]; + const int nbb03 = b0->nb[3]; + + const int nbb10 = b1->nb[0]; + //const int nbb11 = b1->nb[1]; + //const int nbb12 = b1->nb[2]; + //const int nbb13 = b1->nb[3]; + + const int nbc00 = c0->nb[0]; + const int nbc01 = c0->nb[1]; + const int nbc02 = c0->nb[2]; + const int nbc03 = c0->nb[3]; + + const int nbc10 = c1->nb[0]; + //const int nbc11 = c1->nb[1]; + //const int nbc12 = c1->nb[2]; + //const int nbc13 = c1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = nea0; + //const int N = nea1; + const int M = neb01; + + GGML_ASSERT(ne0 == nea0); + GGML_ASSERT(ne1 == nea1); + GGML_ASSERT(ne2 == nea2); + + GGML_ASSERT(nba0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb10 == sizeof(float)); + 
GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbc10 == sizeof(float)); + + GGML_ASSERT(neb00 == D); + GGML_ASSERT(neb01 == M); + GGML_ASSERT(neb10 == M); + GGML_ASSERT(neb11 == 1); + + GGML_ASSERT(nec00 == M); + GGML_ASSERT(nec01 == D); + GGML_ASSERT(nec10 == D); + GGML_ASSERT(nec11 == 1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by a rows using ggml_vec_dot_f32 + + // total rows in a + const int nr = nea1*nea2*nea3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // a indices + const int ia3 = ir/(nea2*nea1); + const int ia2 = (ir - ia3*nea2*nea1)/nea1; + const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); + + float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); + + for (int ic = 0; ic < neb01; ++ic) { + // b0 indices + const int ib03 = ia3; + const int ib02 = ia2; + const int ib01 = ic; + + // S indices + const int i1 = ib01; + + ggml_vec_dot_f16(nea0, + S + i1, + (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), + (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); + } + + ggml_vec_add_f32(neb01, S, S, (float *) b1->data); + //ggml_vec_gelu_f32(neb01, S, S); + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); + + for (int i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + ggml_vec_gelu_f16(neb01, S16, S16); + + { + // dst indices + const int i1 = ia1; + const int i2 = ia2; + const int i3 = ia3; + + for (int ic = 0; ic < nec01; ++ic) { + + ggml_vec_dot_f16(neb01, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + 
i3*nb3)), + (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), + S16); + } + + ggml_vec_add_f32(nec01, + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) c1->data); + } + } +} + +static void ggml_compute_forward_flash_ff( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b0, + const struct ggml_tensor * b1, + const struct ggml_tensor * c0, + const struct ggml_tensor * c1, + struct ggml_tensor * dst) { + switch (b0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(false); // TODO + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +///////////////////////////////// + +static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { + assert(params); + + switch (tensor->op) { + case GGML_OP_DUP: + { + ggml_compute_forward_dup(params, tensor->src0, tensor); + } break; + case GGML_OP_ADD: + { + ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SUB: + { + ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_MUL: + { + ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_DIV: + { + ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SQR: + { + ggml_compute_forward_sqr(params, tensor->src0, tensor); + } break; + case GGML_OP_SQRT: + { + ggml_compute_forward_sqrt(params, tensor->src0, tensor); + } break; + case GGML_OP_SUM: + { + ggml_compute_forward_sum(params, tensor->src0, tensor); + } break; + case GGML_OP_MEAN: + { + ggml_compute_forward_mean(params, tensor->src0, tensor); + } break; + case GGML_OP_REPEAT: + 
{ + ggml_compute_forward_repeat(params, tensor->src0, tensor); + } break; + case GGML_OP_ABS: + { + ggml_compute_forward_abs(params, tensor->src0, tensor); + } break; + case GGML_OP_SGN: + { + ggml_compute_forward_sgn(params, tensor->src0, tensor); + } break; + case GGML_OP_NEG: + { + ggml_compute_forward_neg(params, tensor->src0, tensor); + } break; + case GGML_OP_STEP: + { + ggml_compute_forward_step(params, tensor->src0, tensor); + } break; + case GGML_OP_RELU: + { + ggml_compute_forward_relu(params, tensor->src0, tensor); + } break; + case GGML_OP_GELU: + { + ggml_compute_forward_gelu(params, tensor->src0, tensor); + } break; + case GGML_OP_NORM: + { + ggml_compute_forward_norm(params, tensor->src0, tensor); + } break; + case GGML_OP_MUL_MAT: + { + ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SCALE: + { + ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_CPY: + { + ggml_compute_forward_cpy(params, tensor->src0, tensor); + } break; + case GGML_OP_RESHAPE: + { + ggml_compute_forward_reshape(params, tensor->src0, tensor); + } break; + case GGML_OP_VIEW: + { + ggml_compute_forward_view(params, tensor->src0); + } break; + case GGML_OP_PERMUTE: + { + ggml_compute_forward_permute(params, tensor->src0); + } break; + case GGML_OP_TRANSPOSE: + { + ggml_compute_forward_transpose(params, tensor->src0); + } break; + case GGML_OP_GET_ROWS: + { + ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_DIAG_MASK_INF: + { + ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SOFT_MAX: + { + ggml_compute_forward_soft_max(params, tensor->src0, tensor); + } break; + case GGML_OP_ROPE: + { + ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_CONV_1D_1S: + { + ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor); + } 
break; + case GGML_OP_CONV_1D_2S: + { + ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_FLASH_ATTN: + { + int32_t t = ggml_get_i32_1d(tensor->opt[1], 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor); + } break; + case GGML_OP_FLASH_FF: + { + ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor); + } break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) { + struct ggml_tensor * src0 = tensor->src0; + struct ggml_tensor * src1 = tensor->src1; + + switch (tensor->op) { + case GGML_OP_DUP: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_ADD: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_SUB: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_MUL: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, src1, tensor->grad), + inplace); + } + if (src1->grad) { + src1->grad = + ggml_add_impl(ctx, + src1->grad, + ggml_mul(ctx, src0, tensor->grad), + inplace); + } + } break; + case GGML_OP_DIV: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_div(ctx, tensor->grad, src1), + inplace); + } + if (src1->grad) { + src1->grad 
= + ggml_sub_impl(ctx, + src1->grad, + ggml_mul(ctx, + tensor->grad, + ggml_div(ctx, tensor, src1)), + inplace); + } + } break; + case GGML_OP_SQR: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, + ggml_mul(ctx, src0, tensor->grad), + ggml_repeat(ctx, ggml_new_f32(ctx, 2.0f), src0)), + inplace); + } + } break; + case GGML_OP_SQRT: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_div(ctx, + ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor), + tensor), + inplace); + } + } break; + case GGML_OP_SUM: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_repeat(ctx, tensor->grad, src0->grad), + inplace); + } + } break; + case GGML_OP_MEAN: + { + assert(false); // TODO: implement + } break; + case GGML_OP_REPEAT: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_sum(ctx, tensor->grad), + inplace); + } + } break; + case GGML_OP_ABS: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, + ggml_sgn(ctx, src0), + tensor->grad), + inplace); + } + } break; + case GGML_OP_SGN: + { + if (src0->grad) { + // noop + } + } break; + case GGML_OP_NEG: + { + if (src0->grad) { + src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_STEP: + { + if (src0->grad) { + // noop + } + } break; + case GGML_OP_RELU: + { + if (src0->grad) { + src0->grad = ggml_sub_impl(ctx, + src0->grad, + ggml_mul(ctx, + ggml_step(ctx, src0), + tensor->grad), + inplace); + } + } break; + case GGML_OP_GELU: + { + assert(false); // TODO: not implemented + } break; + case GGML_OP_NORM: + { + assert(false); // TODO: not implemented + } break; + case GGML_OP_MUL_MAT: + { + if (src0->grad) { + // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad); + assert(false); + } + if (src1->grad) { + src1->grad = + ggml_add_impl(ctx, + src1->grad, + // TODO: fix transpose, the node will break the 
graph connections + ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad), + inplace); + } + } break; + case GGML_OP_SCALE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CPY: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_RESHAPE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_VIEW: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_PERMUTE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_TRANSPOSE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_GET_ROWS: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_DIAG_MASK_INF: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_SOFT_MAX: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_ROPE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_1D_1S: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_1D_2S: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_FLASH_ATTN: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_FLASH_FF: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { + if (node->grad == NULL) { + // this usually happens when we generate intermediate nodes from constants in the backward pass + // it can also happen during forward pass, if the user performs computations with constants + if (node->op != GGML_OP_NONE) { + //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); + } + } + + // check if already visited + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return; + } + } + + 
for (int i = 0; i < cgraph->n_leafs; i++) { + if (cgraph->leafs[i] == node) { + return; + } + } + + if (node->src0) { + ggml_visit_parents(cgraph, node->src0); + } + + if (node->src1) { + ggml_visit_parents(cgraph, node->src1); + } + + for (int i = 0; i < GGML_MAX_OPT; ++i) { + if (node->opt[i]) { + ggml_visit_parents(cgraph, node->opt[i]); + } + } + + if (node->op == GGML_OP_NONE && node->grad == NULL) { + // reached a leaf node, not part of the gradient graph (e.g. a constant) + assert(cgraph->n_leafs < GGML_MAX_NODES); + + cgraph->leafs[cgraph->n_leafs] = node; + cgraph->n_leafs++; + } else { + assert(cgraph->n_nodes < GGML_MAX_NODES); + + cgraph->nodes[cgraph->n_nodes] = node; + cgraph->grads[cgraph->n_nodes] = node->grad; + cgraph->n_nodes++; + } +} + +static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { + if (!expand) { + cgraph->n_nodes = 0; + cgraph->n_leafs = 0; + } + + const int n0 = cgraph->n_nodes; + UNUSED(n0); + + ggml_visit_parents(cgraph, tensor); + + const int n_new = cgraph->n_nodes - n0; + GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); + + if (n_new > 0) { + // the last added node should always be starting point + assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor); + } +} + +void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { + ggml_build_forward_impl(cgraph, tensor, true); +} + +struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { + struct ggml_cgraph result = { + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.n_threads =*/ 0, + /*.work_size =*/ 0, + /*.work =*/ NULL, + /*.nodes =*/ { NULL }, + /*.grads =*/ { NULL }, + /*.leafs =*/ { NULL }, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + + ggml_build_forward_impl(&result, tensor, false); + + return result; +} + +struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { + struct ggml_cgraph 
result = *gf; + + assert(gf->n_nodes > 0); + + // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph + if (keep) { + for (int i = 0; i < gf->n_nodes; i++) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->grad) { + node->grad = ggml_dup_tensor(ctx, node); + gf->grads[i] = node->grad; + } + } + } + + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = gf->nodes[i]; + + // because we detached the grad nodes from the original graph, we can afford inplace operations + if (node->grad) { + ggml_compute_backward(ctx, node, keep); + } + } + + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->is_param) { + GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); + ggml_build_forward_impl(&result, node->grad, true); + } + } + + return result; +} + +// +// thread data +// +// synchronization is done via busy loops +// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops +// + +#ifdef __APPLE__ + +//#include <os/lock.h> + +//typedef os_unfair_lock ggml_lock_t; +// +//#define ggml_lock_init(x) UNUSED(x) +//#define ggml_lock_destroy(x) UNUSED(x) +//#define ggml_lock_lock os_unfair_lock_lock +//#define ggml_lock_unlock os_unfair_lock_unlock +// +//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT + +typedef int ggml_lock_t; + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#define ggml_lock_lock(x) UNUSED(x) +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 + +typedef pthread_t ggml_thread_t; + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#else + +//typedef pthread_spinlock_t ggml_lock_t; + +//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) +//#define ggml_lock_destroy pthread_spin_destroy +//#define ggml_lock_lock pthread_spin_lock +//#define 
ggml_lock_unlock pthread_spin_unlock + +typedef int ggml_lock_t; + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#define ggml_lock_lock(x) UNUSED(x) +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 + +typedef pthread_t ggml_thread_t; + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#endif + +struct ggml_compute_state_shared { + ggml_lock_t spin; + + int n_threads; + + // synchronization primitives + atomic_int n_ready; + atomic_bool has_work; + atomic_bool stop; // stop all threads +}; + +struct ggml_compute_state { + ggml_thread_t thrd; + + struct ggml_compute_params params; + struct ggml_tensor * node; + + struct ggml_compute_state_shared * shared; +}; + +static thread_ret_t ggml_graph_compute_thread(void * data) { + struct ggml_compute_state * state = (struct ggml_compute_state *) data; + + const int n_threads = state->shared->n_threads; + + while (true) { + if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) { + atomic_store(&state->shared->has_work, false); + } else { + while (atomic_load(&state->shared->has_work)) { + if (atomic_load(&state->shared->stop)) { + return 0; + } + ggml_lock_lock (&state->shared->spin); + ggml_lock_unlock(&state->shared->spin); + } + } + + atomic_fetch_sub(&state->shared->n_ready, 1); + + // wait for work + while (!atomic_load(&state->shared->has_work)) { + if (atomic_load(&state->shared->stop)) { + return 0; + } + ggml_lock_lock (&state->shared->spin); + ggml_lock_unlock(&state->shared->spin); + } + + // check if we should stop + if (atomic_load(&state->shared->stop)) { + break; + } + + if (state->node) { + ggml_compute_forward(&state->params, state->node); + state->node = NULL; + } else { + break; + } + } + + return 0; +} + +void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { + if (cgraph->n_threads <= 0) { + cgraph->n_threads = 8; + } + + const int n_threads = cgraph->n_threads; + + struct 
ggml_compute_state_shared state_shared = { + /*.spin =*/ GGML_LOCK_INITIALIZER, + /*.n_threads =*/ n_threads, + /*.n_ready =*/ 0, + /*.has_work =*/ false, + /*.stop =*/ false, + }; + struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL; + + // create thread pool + if (n_threads > 1) { + ggml_lock_init(&state_shared.spin); + + atomic_store(&state_shared.has_work, true); + + for (int j = 0; j < n_threads - 1; j++) { + workers[j] = (struct ggml_compute_state) { + .thrd = 0, + .params = { + .type = GGML_TASK_COMPUTE, + .ith = j + 1, + .nth = n_threads, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? cgraph->work->data : NULL, + }, + .node = NULL, + .shared = &state_shared, + }; + int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); + assert(rc == 0); + UNUSED(rc); + } + } + + // initialize tasks + work buffer + { + size_t work_size = 0; + + // thread scheduling for the different operations + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + switch (node->op) { + case GGML_OP_DUP: + { + node->n_tasks = 1; + } break; + case GGML_OP_ADD: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SUM: + case GGML_OP_MEAN: + case GGML_OP_REPEAT: + case GGML_OP_ABS: + case GGML_OP_SGN: + case GGML_OP_NEG: + case GGML_OP_STEP: + case GGML_OP_RELU: + { + node->n_tasks = 1; + } break; + case GGML_OP_GELU: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_NORM: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_MUL_MAT: + { + // TODO: use different scheduling for different matrix sizes + node->n_tasks = n_threads; + + size_t cur = 0; + + // TODO: better way to determine if the matrix is transposed + if (node->src0->nb[1] < node->src0->nb[0]) { + cur = 
ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1) + } else { + if (node->src0->type == GGML_TYPE_F16 && + node->src1->type == GGML_TYPE_F32) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]); + } else { + cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); + } +#else + cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); +#endif + } else if (node->src0->type == GGML_TYPE_F32 && + node->src1->type == GGML_TYPE_F32) { + cur = 0; + } else { + GGML_ASSERT(false); + } + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_SCALE: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_CPY: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS: + case GGML_OP_DIAG_MASK_INF: + { + node->n_tasks = 1; + } break; + case GGML_OP_SOFT_MAX: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_ROPE: + { + node->n_tasks = 1; + } break; + case GGML_OP_CONV_1D_1S: + case GGML_OP_CONV_1D_2S: + { + node->n_tasks = n_threads; + + GGML_ASSERT(node->src0->ne[3] == 1); + GGML_ASSERT(node->src1->ne[2] == 1); + GGML_ASSERT(node->src1->ne[3] == 1); + + size_t cur = 0; + const int nk = node->src0->ne[0]; + + if (node->src0->type == GGML_TYPE_F16 && + node->src1->type == GGML_TYPE_F32) { + cur = sizeof(ggml_fp16_t)*( + nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + + ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] + ); + } else if (node->src0->type == GGML_TYPE_F32 && + node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*( + nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + + ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] + ); + } else { + GGML_ASSERT(false); + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_ATTN: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + if (node->src1->type == 
GGML_TYPE_F32) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_FF: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_NONE: + { + node->n_tasks = 1; + } break; + case GGML_OP_COUNT: + { + assert(false); + } break; + } + } + + if (cgraph->work != NULL && work_size > cgraph->work_size) { + assert(false); // TODO: better handling + } + + if (work_size > 0 && cgraph->work == NULL) { + cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1); + + GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size); + cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size); + } + } + + const int64_t perf_start_cycles = ggml_perf_cycles(); + const int64_t perf_start_time_us = ggml_perf_time_us(); + + for (int i = 0; i < cgraph->n_nodes; i++) { + GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes); + + struct ggml_tensor * node = cgraph->nodes[i]; + + // TODO: this could be used to avoid unnecessary computations, but it needs to be improved + //if (node->grad == NULL && 
node->perf_runs > 0) { + // continue; + //} + + const int64_t perf_node_start_cycles = ggml_perf_cycles(); + const int64_t perf_node_start_time_us = ggml_perf_time_us(); + + // INIT + struct ggml_compute_params params = { + /*.type =*/ GGML_TASK_INIT, + /*.ith =*/ 0, + /*.nth =*/ node->n_tasks, + /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0, + /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL, + }; + + ggml_compute_forward(¶ms, node); + + // COMPUTE + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + // launch thread pool + for (int j = 0; j < n_threads - 1; j++) { + workers[j].params = (struct ggml_compute_params) { + .type = GGML_TASK_COMPUTE, + .ith = j + 1, + .nth = n_threads, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? 
cgraph->work->data : NULL, + }; + workers[j].node = node; + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) > 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_store(&state_shared.has_work, true); + } + + params.type = GGML_TASK_COMPUTE; + ggml_compute_forward(¶ms, node); + + // wait for thread pool + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) != 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + } + + // FINALIZE + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + // launch thread pool + for (int j = 0; j < n_threads - 1; j++) { + workers[j].params = (struct ggml_compute_params) { + .type = GGML_TASK_FINALIZE, + .ith = j + 1, + .nth = n_threads, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? 
cgraph->work->data : NULL, + }; + workers[j].node = node; + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) > 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_store(&state_shared.has_work, true); + } + + params.type = GGML_TASK_FINALIZE; + ggml_compute_forward(¶ms, node); + + // wait for thread pool + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) != 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + } + + // performance stats (node) + { + int64_t perf_cycles_cur = ggml_perf_cycles() - perf_node_start_cycles; + int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us; + + node->perf_runs++; + node->perf_cycles += perf_cycles_cur; + node->perf_time_us += perf_time_us_cur; + } + } + + // join thread pool + if (n_threads > 1) { + atomic_store(&state_shared.stop, true); + atomic_store(&state_shared.has_work, true); + + for (int j = 0; j < n_threads - 1; j++) { + int rc = ggml_thread_join(workers[j].thrd, NULL); + assert(rc == 0); + UNUSED(rc); + } + + ggml_lock_destroy(&state_shared.spin); + } + + // performance stats (graph) + { + int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles; + int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us; + + cgraph->perf_runs++; + cgraph->perf_cycles += perf_cycles_cur; + cgraph->perf_time_us += perf_time_us_cur; + + GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", + __func__, cgraph->perf_runs, + (double) perf_cycles_cur / (double) ggml_cycles_per_ms(), + (double) cgraph->perf_cycles / (double) 
ggml_cycles_per_ms() / (double) cgraph->perf_runs, + (double) perf_time_us_cur / 1000.0, + (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); + } +} + +void ggml_graph_reset(struct ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * grad = cgraph->grads[i]; + + if (grad) { + ggml_set_zero(grad); + } + } +} + +void ggml_graph_print(const struct ggml_cgraph * cgraph) { + int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0}; + + GGML_PRINT("=== GRAPH ===\n"); + + GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads); + GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size); + + GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + perf_total_per_op_us[node->op] += node->perf_time_us; + + GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", + i, + node->ne[0], node->ne[1], node->ne[2], + GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? 
"g" : " ", node->perf_runs, + (double) node->perf_cycles / (double) ggml_cycles_per_ms(), + (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, + (double) node->perf_time_us / 1000.0, + (double) node->perf_time_us / 1000.0 / node->perf_runs); + } + + GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); + for (int i = 0; i < cgraph->n_leafs; i++) { + struct ggml_tensor * node = cgraph->leafs[i]; + + GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n", + i, + node->ne[0], node->ne[1], + GGML_OP_LABEL[node->op]); + } + + for (int i = 0; i < GGML_OP_COUNT; i++) { + GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0); + } + + GGML_PRINT("========================================\n"); +} + +// check if node is part of the graph +static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + if (cgraph == NULL) { + return true; + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return true; + } + } + + return false; +} + +static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * parent = cgraph->nodes[i]; + + if (parent->grad == node) { + return parent; + } + } + + return NULL; +} + +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { + char color[16]; + + FILE * fp = fopen(filename, "w"); + assert(fp); + + fprintf(fp, "digraph G {\n"); + fprintf(fp, " newrank = true;\n"); + fprintf(fp, " rankdir = LR;\n"); + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + if (ggml_graph_get_parent(gb, node) != NULL) { + continue; + } + + if (node->is_param) { + snprintf(color, sizeof(color), "yellow"); + } else if (node->grad) { + if (ggml_graph_find(gf, node)) { + snprintf(color, sizeof(color), "green"); + } 
else { + snprintf(color, sizeof(color), "lightblue"); + } + } else { + snprintf(color, sizeof(color), "white"); + } + + fprintf(fp, " \"%p\" [ \ +style = filled; fillcolor = %s; shape = record; \ +label=\"%d [%d, %d] | <x>%s", + (void *) node, color, + i, node->ne[0], node->ne[1], + GGML_OP_SYMBOL[node->op]); + + if (node->grad) { + fprintf(fp, " | <g>%s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]); + } else { + fprintf(fp, "\"; ]\n"); + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + snprintf(color, sizeof(color), "pink"); + + if (ggml_nelements(node) == 1) { + fprintf(fp, " \"%p\" [ \ +style = filled; fillcolor = %s; shape = record; \ +label=\"<x>%.1e\"; ]\n", + (void *) node, color, ggml_get_f32_1d(node, 0)); + } else { + fprintf(fp, " \"%p\" [ \ +style = filled; fillcolor = %s; shape = record; \ +label=\"<x>CONST %d [%d, %d]\"; ]\n", + (void *) node, color, + i, node->ne[0], node->ne[1]); + } + } + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + struct ggml_tensor * parent = ggml_graph_get_parent(gb, node); + + if (node->src0) { + struct ggml_tensor * parent0 = ggml_graph_get_parent(gb, node->src0); + + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"x\"; ]\n", + parent0 ? (void *) parent0 : (void *) node->src0, + parent0 ? "g" : "x", + parent ? (void *) parent : (void *) node, + parent ? "g" : "x", + parent ? "empty" : "vee", + parent ? "dashed" : "solid"); + } + + if (node->src1) { + struct ggml_tensor * parent1 = ggml_graph_get_parent(gb, node->src1); + + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"y\"; ]\n", + parent1 ? (void *) parent1 : (void *) node->src1, + parent1 ? "g" : "x", + parent ? (void *) parent : (void *) node, + parent ? "g" : "x", + parent ? "empty" : "vee", + parent ? 
"dashed" : "solid"); + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + if (node->src0) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"x\"; ]\n", + (void *) node->src0, "x", + (void *) node, "x"); + } + + if (node->src1) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"y\"; ]\n", + (void *) node->src1, "x", + (void *) node, "x"); + } + } + + fprintf(fp, "}\n"); + + fclose(fp); + + GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); +} + +//////////////////////////////////////////////////////////////////////////////// + +static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int ne = ggml_nelements(ps[p]) ; + // TODO: add function to set tensor from array + for (int j = 0; j < ne; ++j) { + ggml_set_f32_1d(ps[p], j, x[i++]); + } + } +} + +static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int j = 0; j < ne; ++j) { + x[i++] = ggml_get_f32_1d(ps[p], j); + } + } +} + +static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int j = 0; j < ne; ++j) { + g[i++] = ggml_get_f32_1d(ps[p]->grad, j); + } + } +} + +// +// ADAM +// +// ref: https://arxiv.org/pdf/1412.6980.pdf +// + +static enum ggml_opt_result ggml_opt_adam( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb) { + assert(ggml_is_scalar(f)); + + gf->n_threads = params.n_threads; + gb->n_threads = params.n_threads; + + // these will store the parameters we want to optimize + struct 
ggml_tensor * ps[GGML_MAX_PARAMS]; + + int np = 0; + int nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + assert(np < GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += ggml_nelements(gf->nodes[i]); + } + } + + // constants + const float alpha = params.adam.alpha; + const float beta1 = params.adam.beta1; + const float beta2 = params.adam.beta2; + const float eps = params.adam.eps; + + float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters + float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient + float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared + float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment + float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment + float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat + float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat + + float * pf = params.past > 0 ? 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values + + // initialize + ggml_vec_set_f32(nx, m, 0.0f); + ggml_vec_set_f32(nx, v, 0.0f); + + // update view + ggml_opt_get_params(np, ps, x); + + // compute the function value + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + float fx_prev = ggml_get_f32_1d(f, 0); + if (pf) { + pf[0] = fx_prev; + } + + int n_no_improvement = 0; + float fx_best = fx_prev; + + // run the optimizer + for (int t = 0; t < params.adam.n_iter; ++t) { + GGML_PRINT_DEBUG ("=== iter %d ===\n", t); + + GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0)); + GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0)); + GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0)); + + for (int i = 0; i < np; ++i) { + GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, + ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0)); + } + + const int64_t t_start_wall = ggml_time_us(); + const int64_t t_start_cpu = ggml_cycles(); + UNUSED(t_start_wall); + UNUSED(t_start_cpu); + + { + // update the gradient + ggml_opt_get_grad(np, ps, g1); + + // m_t = beta1*m_t-1 + (1 - beta1)*g_t + ggml_vec_scale_f32(nx, m, beta1); + ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1); + + // g2 = g1^2 + ggml_vec_sqr_f32 (nx, g2, g1); + + // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2 + ggml_vec_scale_f32(nx, v, beta2); + ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2); + + // m^hat = m_t / (1 - beta1^t) + // v^hat = v_t / (1 - beta2^t) + // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps) + ggml_vec_cpy_f32 (nx, mh, m); + ggml_vec_cpy_f32 (nx, vh, v); + + ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1))); + ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1))); + + ggml_vec_sqrt_f32 (nx, vh, vh); + ggml_vec_acc1_f32 (nx, vh, eps); + + ggml_vec_div_f32 (nx, mh, mh, vh); + ggml_vec_sub_f32 (nx, x, x, mh); + + // update the parameters + ggml_opt_set_params(np, ps, 
x); + } + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + const float fx = ggml_get_f32_1d(f, 0); + + // check convergence + if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) { + GGML_PRINT_DEBUG("converged\n"); + + return GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= t) { + const float rate = (pf[t%params.past] - fx)/fx; + + if (fabs(rate) < params.delta) { + return GGML_OPT_OK; + } + } + + pf[t%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx_best > fx) { + fx_best = fx; + n_no_improvement = 0; + } else { + ++n_no_improvement; + + if (n_no_improvement >= params.max_no_improvement) { + return GGML_OPT_OK; + } + } + } + + fx_prev = fx; + + { + const int64_t t_end_cpu = ggml_cycles(); + GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); + UNUSED(t_end_cpu); + + const int64_t t_end_wall = ggml_time_us(); + GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); + UNUSED(t_end_wall); + } + } + + return GGML_OPT_DID_NOT_CONVERGE; +} + +// +// L-BFGS +// +// the L-BFGS implementation below is based on the following implementation: +// +// https://github.com/chokkan/liblbfgs +// + +struct ggml_lbfgs_iteration_data { + float alpha; + float ys; + float * s; + float * y; +}; + +static enum ggml_opt_result linesearch_backtracking( + struct ggml_context * ctx, + const struct ggml_opt_params * params, + int nx, + float * x, + float * fx, + float * g, + float * d, + float * step, + const float * xp, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + const int np, + struct ggml_tensor * ps[]) { + int count = 0; + + float width = 0.0f; + float dg = 0.0f; + float finit = 0.0f; + float dginit = 0.0f; + float dgtest = 0.0f; + + const float dec = 0.5f; + const float inc = 
2.1f; + + if (*step <= 0.) { + return GGML_LINESEARCH_INVALID_PARAMETERS; + } + + // compute the initial gradient in the search direction + ggml_vec_dot_f32(nx, &dginit, g, d); + + // make sure that d points to a descent direction + if (0 < dginit) { + return GGML_LINESEARCH_FAIL; + } + + // initialize local variables + finit = *fx; + dgtest = params->lbfgs.ftol*dginit; + + while (true) { + ggml_vec_cpy_f32(nx, x, xp); + ggml_vec_mad_f32(nx, x, d, *step); + + // evaluate the function and gradient values + { + ggml_opt_set_params(np, ps, x); + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + ggml_opt_get_grad(np, ps, g); + + *fx = ggml_get_f32_1d(f, 0); + } + + ++count; + + if (*fx > finit + (*step)*dgtest) { + width = dec; + } else { + // Armijo condition is satisfied + if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) { + return count; + } + + ggml_vec_dot_f32(nx, &dg, g, d); + + // check the Wolfe condition + if (dg < params->lbfgs.wolfe * dginit) { + width = inc; + } else { + if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) { + // regular Wolfe conditions + return count; + } + + if(dg > -params->lbfgs.wolfe*dginit) { + width = dec; + } else { + // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) + return count; + } + return count; + } + } + + if (*step < params->lbfgs.min_step) { + return GGML_LINESEARCH_MINIMUM_STEP; + } + if (*step > params->lbfgs.max_step) { + return GGML_LINESEARCH_MAXIMUM_STEP; + } + if (params->lbfgs.max_linesearch <= count) { + return GGML_LINESEARCH_MAXIMUM_ITERATIONS; + } + + (*step) *= width; + } + + return GGML_LINESEARCH_FAIL; +} + +static enum ggml_opt_result ggml_opt_lbfgs( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb) { + if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || + params.lbfgs.linesearch == 
GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { + if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) { + return GGML_OPT_INVALID_WOLFE; + } + } + + gf->n_threads = params.n_threads; + gb->n_threads = params.n_threads; + + const int m = params.lbfgs.m; + + // these will store the parameters we want to optimize + struct ggml_tensor * ps[GGML_MAX_PARAMS]; + + int np = 0; + int nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + assert(np < GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += ggml_nelements(gf->nodes[i]); + } + } + + float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters + float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters + float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient + float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient + float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction + + float * pf = params.past > 0 ? 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values + + float fx = 0.0f; // cost function value + float xnorm = 0.0f; // ||x|| + float gnorm = 0.0f; // ||g|| + float step = 0.0f; + + // initialize x from the graph nodes + ggml_opt_get_params(np, ps, x); + + // the L-BFGS memory + struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m); + + for (int i = 0; i < m; ++i) { + lm[i].alpha = 0.0f; + lm[i].ys = 0.0f; + lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; + lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; + } + + // evaluate the function value and its gradient + { + ggml_opt_set_params(np, ps, x); + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + ggml_opt_get_grad(np, ps, g); + + fx = ggml_get_f32_1d(f, 0); + } + + if (pf) { + pf[0] = fx; + } + + float fx_best = fx; + + // search direction = -gradient + ggml_vec_neg_f32(nx, d, g); + + // ||x||, ||g|| + ggml_vec_norm_f32(nx, &xnorm, x); + ggml_vec_norm_f32(nx, &gnorm, g); + + if (xnorm < 1.0f) { + xnorm = 1.0f; + } + + // already optimized + if (gnorm/xnorm <= params.lbfgs.eps) { + return GGML_OPT_OK; + } + + // initial step + ggml_vec_norm_inv_f32(nx, &step, d); + + int j = 0; + int k = 1; + int ls = 0; + int end = 0; + int bound = 0; + int n_no_improvement = 0; + + float ys = 0.0f; + float yy = 0.0f; + float beta = 0.0f; + + while (true) { + // store the current position and gradient vectors + ggml_vec_cpy_f32(nx, xp, x); + ggml_vec_cpy_f32(nx, gp, g); + + ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps); + + if (ls < 0) { + // linesearch failed - go back to the previous point and return + ggml_vec_cpy_f32(nx, x, xp); + ggml_vec_cpy_f32(nx, g, gp); + + return ls; + } + + ggml_vec_norm_f32(nx, &xnorm, x); + ggml_vec_norm_f32(nx, &gnorm, g); + + GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0)); + + if (xnorm < 1.0) { + xnorm = 1.0; 
+ } + if (gnorm/xnorm <= params.lbfgs.eps) { + // converged + return GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= k) { + const float rate = (pf[k%params.past] - fx)/fx; + + if (fabs(rate) < params.delta) { + return GGML_OPT_OK; + } + } + + pf[k%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx < fx_best) { + fx_best = fx; + n_no_improvement = 0; + } else { + n_no_improvement++; + + if (n_no_improvement >= params.max_no_improvement) { + return GGML_OPT_OK; + } + } + } + + if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) { + // reached the maximum number of iterations + return GGML_OPT_DID_NOT_CONVERGE; + } + + // update vectors s and y: + // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. + // y_{k+1} = g_{k+1} - g_{k}. + // + ggml_vec_sub_f32(nx, lm[end].s, x, xp); + ggml_vec_sub_f32(nx, lm[end].y, g, gp); + + // compute scalars ys and yy: + // ys = y^t \cdot s -> 1 / \rho. + // yy = y^t \cdot y. + // + ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s); + ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y); + + lm[end].ys = ys; + + // find new search direction + // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS + + bound = (m <= k) ? 
m : k; + k++; + end = (end + 1)%m; + + // initialize search direction with -g + ggml_vec_neg_f32(nx, d, g); + + j = end; + for (int i = 0; i < bound; ++i) { + j = (j + m - 1) % m; + // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} + ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d); + lm[j].alpha /= lm[j].ys; + // q_{i} = q_{i+1} - \alpha_{i} y_{i} + ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha); + } + + ggml_vec_scale_f32(nx, d, ys/yy); + + for (int i = 0; i < bound; ++i) { + // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} + ggml_vec_dot_f32(nx, &beta, lm[j].y, d); + beta /= lm[j].ys; + // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} + ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta); + j = (j + 1)%m; + } + + step = 1.0; + } + + return GGML_OPT_DID_NOT_CONVERGE; +} + +struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { + struct ggml_opt_params result; + + switch (type) { + case GGML_OPT_ADAM: + { + result = (struct ggml_opt_params) { + .type = GGML_OPT_ADAM, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 100, + + .print_forward_graph = true, + .print_backward_graph = true, + + .adam = { + .n_iter = 10000, + .alpha = 0.001f, + .beta1 = 0.9f, + .beta2 = 0.999f, + .eps = 1e-8f, + .eps_f = 1e-5f, + .eps_g = 1e-3f, + }, + }; + } break; + case GGML_OPT_LBFGS: + { + result = (struct ggml_opt_params) { + .type = GGML_OPT_LBFGS, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 0, + + .print_forward_graph = true, + .print_backward_graph = true, + + .lbfgs = { + .m = 6, + .n_iter = 100, + .max_linesearch = 20, + + .eps = 1e-5f, + .ftol = 1e-4f, + .wolfe = 0.9f, + .min_step = 1e-20f, + .max_step = 1e+20f, + + .linesearch = GGML_LINESEARCH_DEFAULT, + }, + }; + } break; + } + + return result; +} + +enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f) { + bool free_ctx = false; + if (ctx == NULL) { + struct 
ggml_init_params params_ctx = { + .mem_size = 16*1024*1024, + .mem_buffer = NULL, + }; + + ctx = ggml_init(params_ctx); + if (ctx == NULL) { + return GGML_OPT_NO_CONTEXT; + } + + free_ctx = true; + } + + enum ggml_opt_result result = GGML_OPT_OK; + + // build forward + backward compute graphs + struct ggml_cgraph gf = ggml_build_forward (f); + struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false); + + switch (params.type) { + case GGML_OPT_ADAM: + { + result = ggml_opt_adam(ctx, params, f, &gf, &gb); + } break; + case GGML_OPT_LBFGS: + { + result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb); + } break; + } + + if (params.print_forward_graph) { + ggml_graph_print (&gf); + ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot"); + } + + if (params.print_backward_graph) { + ggml_graph_print (&gb); + ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot"); + } + + if (free_ctx) { + ggml_free(ctx); + } + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +int ggml_cpu_has_avx(void) { +#if defined(__AVX__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx2(void) { +#if defined(__AVX2__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512(void) { +#if defined(__AVX512F__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fma(void) { +#if defined(__FMA__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_neon(void) { +#if defined(__ARM_NEON) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_arm_fma(void) { +#if defined(__ARM_FEATURE_FMA) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_f16c(void) { +#if defined(__F16C__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fp16_va(void) { +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_wasm_simd(void) { +#if defined(__wasm_simd128__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_blas(void) { +#if 
defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + return 1; +#else + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/Whisper/source/ggml.h b/Whisper/source/ggml.h new file mode 100644 index 0000000..a217d2d --- /dev/null +++ b/Whisper/source/ggml.h @@ -0,0 +1,737 @@ +#pragma once + +// +// GGML Tensor Library +// +// This documentation is still a work in progress. +// If you wish some specific topics to be covered, feel free to drop a comment: +// +// https://github.com/ggerganov/whisper.cpp/issues/40 +// +// ## Overview +// +// This library implements: +// +// - a set of tensor operations +// - automatic differentiation +// - basic optimization algorithms +// +// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, +// but is not limited to, the following: +// +// - linear regression +// - support vector machines +// - neural networks +// +// The library allows the user to define a certain function using the available tensor operations. This function +// definition is represented internally via a computation graph. Each tensor operation in the function definition +// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the +// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized +// using one of the available optimization algorithms. 
+// +// For example, here we define the function: f(x) = a*x^2 + b +// +// { +// struct ggml_init_params params = { +// .mem_size = 16*1024*1024, +// .mem_buffer = NULL, +// }; +// +// // memory allocation happens here +// struct ggml_context * ctx = ggml_init(params); +// +// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// +// ggml_set_param(ctx, x); // x is an input variable +// +// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * x2 = ggml_mul(ctx, x, x); +// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); +// +// ... +// } +// +// Notice that the function definition above does not involve any actual computation. The computation is performed only +// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: +// +// { +// ... +// +// struct ggml_cgraph gf = ggml_build_forward(f); +// +// // set the input variable and parameter values +// ggml_set_f32(x, 2.0f); +// ggml_set_f32(a, 3.0f); +// ggml_set_f32(b, 4.0f); +// +// ggml_graph_compute(ctx, &gf); +// +// printf("f = %f\n", ggml_get_f32_1d(f, 0)); +// +// ... +// } +// +// The actual computation is performed in the ggml_graph_compute() function. +// +// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the +// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know +// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory buffer +// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was +// actually needed. +// +// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic +// differentiation and optimization algorithms. 
+// +// The described approach allows the user to define the function graph once and then compute its forward or backward graphs +// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way +// the user can avoid the memory allocation overhead at runtime. +// +// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first-class +// citizens, but in theory the library can be extended to support FP8 and integer data types. +// +// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary +// and binary operations. Most of the available operations fall into one of these two categories. With time, it became +// clear that the library needs to support more complex operations. The way to support these operations is not clear +// yet, but a few examples are demonstrated in the following operations: +// +// - ggml_permute() +// - ggml_conv_1d_1s() +// - ggml_conv_1d_2s() +// +// For each tensor operator, the library implements a forward and backward computation function. The forward function +// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the +// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a +// calculus class, or watch the following video: +// +// What is Automatic Differentiation? +// https://www.youtube.com/watch?v=wG_nF1awSSY +// +// +// ## Tensor data (struct ggml_tensor) +// +// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of +// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains +// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. 
For example: +// +// { +// struct ggml_tensor * c = ggml_add(ctx, a, b); +// +// assert(c->src0 == a); +// assert(c->src1 == b); +// } +// +// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the +// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows +// storing tensors that are not contiguous in memory, which is useful for operations such as transposition and +// permutation. All tensor operations have to take the stride into account and not assume that the tensor is +// contiguous in memory. +// +// The data of the tensor is accessed via the "data" pointer. For example: +// +// { +// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3); +// +// // a[1, 2] = 1.0f; +// *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f; +// +// // a[0, 2] = 2.0f; +// *(float *) ((char *) a->data + 2*a->nb[1] + 0*a->nb[0]) = 2.0f; +// +// ... +// } +// +// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. 
+// +// ## The matrix multiplication operator (ggml_mul_mat) +// +// TODO +// +// +// ## Multi-threading +// +// TODO +// +// +// ## Overview of ggml.c +// +// TODO +// +// +// ## SIMD optimizations +// +// TODO +// +// +// ## Debugging ggml +// +// TODO +// +// + +#ifdef __cplusplus +extern "C" { +#endif + +#include <stdint.h> +#include <stddef.h> +#include <stdbool.h> + +#define GGML_MAX_DIMS 4 +#define GGML_MAX_NODES 4096 +#define GGML_MAX_PARAMS 16 +#define GGML_MAX_CONTEXTS 64 +#define GGML_MAX_OPT 4 + +#ifdef __ARM_NEON +// we use the built-in 16-bit float type +typedef __fp16 ggml_fp16_t; +#else +typedef uint16_t ggml_fp16_t; +#endif + +// convert FP16 <-> FP32 +float ggml_fp16_to_fp32(ggml_fp16_t x); +ggml_fp16_t ggml_fp32_to_fp16(float x); + +struct ggml_object; +struct ggml_context; + +enum ggml_type { + GGML_TYPE_I8, + GGML_TYPE_I16, + GGML_TYPE_I32, + GGML_TYPE_F16, + GGML_TYPE_F32, + GGML_TYPE_COUNT, +}; + +// available tensor operations: +enum ggml_op { + GGML_OP_NONE = 0, + + GGML_OP_DUP, + GGML_OP_ADD, + GGML_OP_SUB, + GGML_OP_MUL, + GGML_OP_DIV, + GGML_OP_SQR, + GGML_OP_SQRT, + GGML_OP_SUM, + GGML_OP_MEAN, + GGML_OP_REPEAT, + GGML_OP_ABS, + GGML_OP_SGN, + GGML_OP_NEG, + GGML_OP_STEP, + GGML_OP_RELU, + GGML_OP_GELU, + GGML_OP_NORM, // normalize + + GGML_OP_MUL_MAT, + + GGML_OP_SCALE, + GGML_OP_CPY, + GGML_OP_RESHAPE, + GGML_OP_VIEW, + GGML_OP_PERMUTE, + GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, + GGML_OP_DIAG_MASK_INF, + GGML_OP_SOFT_MAX, + GGML_OP_ROPE, + GGML_OP_CONV_1D_1S, + GGML_OP_CONV_1D_2S, + + GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_FF, + + GGML_OP_COUNT, +}; + +// n-dimensional tensor +struct ggml_tensor { + enum ggml_type type; + + int n_dims; + int ne[GGML_MAX_DIMS]; // number of elements + size_t nb[GGML_MAX_DIMS]; // stride in bytes: + // nb[0] = sizeof(type) + // nb[1] = nb[0] * ne[0] + padding + // nb[i] = nb[i-1] * ne[i-1] + + // compute data + enum ggml_op op; + + bool is_param; + + struct ggml_tensor * grad; + struct ggml_tensor * src0; + 
struct ggml_tensor * src1; + struct ggml_tensor * opt[GGML_MAX_OPT]; + + // thread scheduling + int n_tasks; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + + void * data; + char padding[8]; +}; + +// computation graph +struct ggml_cgraph { + int n_nodes; + int n_leafs; + int n_threads; + + size_t work_size; + struct ggml_tensor * work; + + struct ggml_tensor * nodes[GGML_MAX_NODES]; + struct ggml_tensor * grads[GGML_MAX_NODES]; + struct ggml_tensor * leafs[GGML_MAX_NODES]; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; +}; + +struct ggml_init_params { + // memory pool + size_t mem_size; // bytes + void * mem_buffer; // if NULL, memory will be allocated internally +}; + +void ggml_time_init(void); // call this once at the beginning of the program +int64_t ggml_time_ms(void); +int64_t ggml_time_us(void); +int64_t ggml_cycles(void); +int64_t ggml_cycles_per_ms(void); + +void ggml_print_object (const struct ggml_object * obj); +void ggml_print_objects(const struct ggml_context * ctx); + +int ggml_nelements(const struct ggml_tensor * tensor); +size_t ggml_nbytes (const struct ggml_tensor * tensor); + +size_t ggml_type_size (enum ggml_type type); +size_t ggml_element_size(const struct ggml_tensor * tensor); + +struct ggml_context * ggml_init(struct ggml_init_params params); +void ggml_free(struct ggml_context * ctx); + +size_t ggml_used_mem(const struct ggml_context * ctx); + +struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int *ne); + +struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0); + +struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1); + +struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2); + +struct ggml_tensor * ggml_new_tensor_4d( + struct 
ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2, + int ne3); + +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); +struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); + +struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); +struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src); + +struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); +struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); + +int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); +void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); + +float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); +void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); + + void * ggml_get_data (const struct ggml_tensor * tensor); +float * ggml_get_data_f32(const struct ggml_tensor * tensor); + +// +// operations on tensors with backpropagation +// + +struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// return scalar +// TODO: compute sum along rows +struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, 
+ struct ggml_tensor * a); + +// mean along rows +struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// if a is the same shape as b, and a is not parameter, return a +// otherwise, return a new tensor: repeat(a) to fit in b +struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// TODO: double-check this computation is correct +struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// normalize along rows +// TODO: eps is hardcoded to 1e-5 for now +struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// A: m rows, n columns +// B: p rows, n columns (i.e. 
we transpose it internally) +// result is m columns, p rows +struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// +// operations on tensors without backpropagation +// + +// in-place, returns view(a) +struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// a -> b, return view(b) +struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// return view(a), b specifies the new shape +// TODO: when we start computing gradient, make a copy instead of view +struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// return view(a) +// TODO: when we start computing gradient, make a copy instead of view +struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1); + +// return view(a) +// TODO: when we start computing gradient, make a copy instead of view +struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2); + +// offset in bytes +struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + size_t offset); + +struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + size_t nb1, // row stride in bytes + size_t offset); + +struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3); + +// alias for ggml_permute(ctx, a, 1, 0, 2, 3) +struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// set elements above the diagonal to -INF +// in-place, returns view(a) 
+struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + +// in-place, returns view(a) +struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// rotary position embedding +// in-place, returns view(a) +// if mode == 1, skip n_past elements +// TODO: avoid creating a new tensor every time +struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode); + +// padding = 1 +// TODO: we don't support extra parameters for now +// that's why we are hard-coding the stride, padding, and dilation +// not great .. +struct ggml_tensor * ggml_conv_1d_1s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_conv_1d_2s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked); + +struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1); + +// +// automatic differentiation +// + +void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor); + +void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + +struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); +struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep); + +void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph); +void ggml_graph_reset (struct ggml_cgraph * cgraph); + +// print info and performance information for the graph +void ggml_graph_print(const struct ggml_cgraph * cgraph); + +// dump the graph into a file using the dot format +void 
ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + +// +// optimization +// + +// optimization methods +enum ggml_opt_type { + GGML_OPT_ADAM, + GGML_OPT_LBFGS, +}; + +// linesearch methods +enum ggml_linesearch { + GGML_LINESEARCH_DEFAULT = 1, + + GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, + GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, + GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, +}; + +// optimization return values +enum ggml_opt_result { + GGML_OPT_OK = 0, + GGML_OPT_DID_NOT_CONVERGE, + GGML_OPT_NO_CONTEXT, + GGML_OPT_INVALID_WOLFE, + GGML_OPT_FAIL, + + GGML_LINESEARCH_FAIL = -128, + GGML_LINESEARCH_MINIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_ITERATIONS, + GGML_LINESEARCH_INVALID_PARAMETERS, +}; + +// optimization parameters +// +// see ggml.c (ggml_opt_default_params) for default values +// +struct ggml_opt_params { + enum ggml_opt_type type; + + int n_threads; + + // delta-based convergence test + // + // if past == 0 - disabled + // if past > 0: + // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) + // + int past; + float delta; + + // maximum number of iterations without improvement + // + // if 0 - disabled + // if > 0: + // assume convergence if no cost improvement in this number of iterations + // + int max_no_improvement; + + bool print_forward_graph; + bool print_backward_graph; + + // ADAM parameters + struct { + int n_iter; + + float alpha; // learning rate + float beta1; + float beta2; + float eps; // epsilon for numerical stability + float eps_f; // epsilon for convergence test + float eps_g; // epsilon for convergence test + } adam; + + // LBFGS parameters + struct { + int m; // number of corrections to approximate the inv. 
Hessian + int n_iter; + int max_linesearch; + + float eps; // convergence tolerance + float ftol; // line search tolerance + float wolfe; + float min_step; + float max_step; + + enum ggml_linesearch linesearch; + } lbfgs; +}; + +struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); + +// optimize the function defined by the tensor f +enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f); + +// +// system info +// + +int ggml_cpu_has_avx(void); +int ggml_cpu_has_avx2(void); +int ggml_cpu_has_avx512(void); +int ggml_cpu_has_fma(void); +int ggml_cpu_has_neon(void); +int ggml_cpu_has_arm_fma(void); +int ggml_cpu_has_f16c(void); +int ggml_cpu_has_fp16_va(void); +int ggml_cpu_has_wasm_simd(void); +int ggml_cpu_has_blas(void); + +#ifdef __cplusplus +} +#endif diff --git a/Whisper/source/whisper.cpp b/Whisper/source/whisper.cpp new file mode 100644 index 0000000..268774d --- /dev/null +++ b/Whisper/source/whisper.cpp @@ -0,0 +1,3601 @@ +#define WHISPER_BUILD +#include "whisper.h" + +#include "ggml.h" + +#include <algorithm> +#include <cassert> +#define _USE_MATH_DEFINES +#include <cmath> +#include <cstdio> +#include <cstring> +#include <fstream> +#include <map> +#include <string> +#include <thread> +#include <vector> +#include <regex> + +#define USE_FLASH_ATTN +//#define USE_FLASH_FF + +// available whisper models +enum e_model { + MODEL_UNKNOWN, + MODEL_TINY, + MODEL_BASE, + MODEL_SMALL, + MODEL_MEDIUM, + MODEL_LARGE, +}; + +static const std::map<std::string, std::pair<int, std::string>> g_lang = { + { "en", { 0, "english", } }, + { "zh", { 1, "chinese", } }, + { "de", { 2, "german", } }, + { "es", { 3, "spanish", } }, + { "ru", { 4, "russian", } }, + { "ko", { 5, "korean", } }, + { "fr", { 6, "french", } }, + { "ja", { 7, "japanese", } }, + { "pt", { 8, "portuguese", } }, + { "tr", { 9, "turkish", } }, + { "pl", { 10, "polish", } }, + { "ca", { 11, "catalan", } }, + { "nl", { 12, "dutch", 
} }, + { "ar", { 13, "arabic", } }, + { "sv", { 14, "swedish", } }, + { "it", { 15, "italian", } }, + { "id", { 16, "indonesian", } }, + { "hi", { 17, "hindi", } }, + { "fi", { 18, "finnish", } }, + { "vi", { 19, "vietnamese", } }, + { "iw", { 20, "hebrew", } }, + { "uk", { 21, "ukrainian", } }, + { "el", { 22, "greek", } }, + { "ms", { 23, "malay", } }, + { "cs", { 24, "czech", } }, + { "ro", { 25, "romanian", } }, + { "da", { 26, "danish", } }, + { "hu", { 27, "hungarian", } }, + { "ta", { 28, "tamil", } }, + { "no", { 29, "norwegian", } }, + { "th", { 30, "thai", } }, + { "ur", { 31, "urdu", } }, + { "hr", { 32, "croatian", } }, + { "bg", { 33, "bulgarian", } }, + { "lt", { 34, "lithuanian", } }, + { "la", { 35, "latin", } }, + { "mi", { 36, "maori", } }, + { "ml", { 37, "malayalam", } }, + { "cy", { 38, "welsh", } }, + { "sk", { 39, "slovak", } }, + { "te", { 40, "telugu", } }, + { "fa", { 41, "persian", } }, + { "lv", { 42, "latvian", } }, + { "bn", { 43, "bengali", } }, + { "sr", { 44, "serbian", } }, + { "az", { 45, "azerbaijani", } }, + { "sl", { 46, "slovenian", } }, + { "kn", { 47, "kannada", } }, + { "et", { 48, "estonian", } }, + { "mk", { 49, "macedonian", } }, + { "br", { 50, "breton", } }, + { "eu", { 51, "basque", } }, + { "is", { 52, "icelandic", } }, + { "hy", { 53, "armenian", } }, + { "ne", { 54, "nepali", } }, + { "mn", { 55, "mongolian", } }, + { "bs", { 56, "bosnian", } }, + { "kk", { 57, "kazakh", } }, + { "sq", { 58, "albanian", } }, + { "sw", { 59, "swahili", } }, + { "gl", { 60, "galician", } }, + { "mr", { 61, "marathi", } }, + { "pa", { 62, "punjabi", } }, + { "si", { 63, "sinhala", } }, + { "km", { 64, "khmer", } }, + { "sn", { 65, "shona", } }, + { "yo", { 66, "yoruba", } }, + { "so", { 67, "somali", } }, + { "af", { 68, "afrikaans", } }, + { "oc", { 69, "occitan", } }, + { "ka", { 70, "georgian", } }, + { "be", { 71, "belarusian", } }, + { "tg", { 72, "tajik", } }, + { "sd", { 73, "sindhi", } }, + { "gu", { 74, "gujarati", } }, + { 
"am", { 75, "amharic", } }, + { "yi", { 76, "yiddish", } }, + { "lo", { 77, "lao", } }, + { "uz", { 78, "uzbek", } }, + { "fo", { 79, "faroese", } }, + { "ht", { 80, "haitian creole", } }, + { "ps", { 81, "pashto", } }, + { "tk", { 82, "turkmen", } }, + { "nn", { 83, "nynorsk", } }, + { "mt", { 84, "maltese", } }, + { "sa", { 85, "sanskrit", } }, + { "lb", { 86, "luxembourgish", } }, + { "my", { 87, "myanmar", } }, + { "bo", { 88, "tibetan", } }, + { "tl", { 89, "tagalog", } }, + { "mg", { 90, "malagasy", } }, + { "as", { 91, "assamese", } }, + { "tt", { 92, "tatar", } }, + { "haw", { 93, "hawaiian", } }, + { "ln", { 94, "lingala", } }, + { "ha", { 95, "hausa", } }, + { "ba", { 96, "bashkir", } }, + { "jw", { 97, "javanese", } }, + { "su", { 98, "sundanese", } }, +}; + +static const size_t MB = 1024*1024; + +static const std::map<e_model, size_t> MEM_REQ_MODEL = { + { MODEL_TINY, 74ull*MB }, + { MODEL_BASE, 142ull*MB }, + { MODEL_SMALL, 466ull*MB }, + { MODEL_MEDIUM, 1464ull*MB }, + { MODEL_LARGE, 2952ull*MB }, +}; + +static const std::map<e_model, size_t> MEM_REQ_MEMORY = { + { MODEL_TINY, 12ull*MB }, + { MODEL_BASE, 24ull*MB }, + { MODEL_SMALL, 70ull*MB }, + { MODEL_MEDIUM, 184ull*MB }, + { MODEL_LARGE, 306ull*MB }, +}; + +static const std::map<e_model, size_t> MEM_REQ_ENCODE = { + { MODEL_TINY, 80ull*MB }, + { MODEL_BASE, 128ull*MB }, + { MODEL_SMALL, 300ull*MB }, + { MODEL_MEDIUM, 680ull*MB }, + { MODEL_LARGE, 1100ull*MB }, +}; + +static const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = { + { MODEL_TINY, 104ull*MB }, + { MODEL_BASE, 138ull*MB }, + { MODEL_SMALL, 208ull*MB }, + { MODEL_MEDIUM, 280ull*MB }, + { MODEL_LARGE, 354ull*MB }, +}; + +static const std::map<e_model, size_t> MEM_REQ_DECODE = { + { MODEL_TINY, 200ull*MB }, + { MODEL_BASE, 202ull*MB }, + { MODEL_SMALL, 204ull*MB }, + { MODEL_MEDIUM, 206ull*MB }, + { MODEL_LARGE, 208ull*MB }, +}; + +static const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = { + { MODEL_TINY, 32ull*MB }, + { 
MODEL_BASE, 44ull*MB }, + { MODEL_SMALL, 64ull*MB }, + { MODEL_MEDIUM, 84ull*MB }, + { MODEL_LARGE, 110ull*MB }, +}; + +struct whisper_mel { + int n_len; + int n_mel; + + std::vector<float> data; +}; + +struct whisper_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector<float> data; +}; + +struct whisper_vocab { + using id = int32_t; + using token = std::string; + + int n_vocab = 51864; + + std::map<token, id> token_to_id; + std::map<id, token> id_to_token; + + id token_eot = 50256; + id token_sot = 50257; + id token_prev = 50360; + id token_solm = 50361; // ?? + id token_not = 50362; // no timestamps + id token_beg = 50363; + + // available tasks + static const id token_translate = 50358; + static const id token_transcribe = 50359; + + bool is_multilingual() const { + return n_vocab == 51865; + } +}; + +struct whisper_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector<whisper_token_data> tokens; +}; + +// medium +// hparams: { +// 'n_mels': 80, +// 'n_vocab': 51864, +// 'n_audio_ctx': 1500, +// 'n_audio_state': 1024, +// 'n_audio_head': 16, +// 'n_audio_layer': 24, +// 'n_text_ctx': 448, +// 'n_text_state': 1024, +// 'n_text_head': 16, +// 'n_text_layer': 24 +// } +// +// default hparams (Whisper tiny) +struct whisper_hparams { + int32_t n_vocab = 51864; + int32_t n_audio_ctx = 1500; + int32_t n_audio_state = 384; + int32_t n_audio_head = 6; + int32_t n_audio_layer = 4; + int32_t n_text_ctx = 448; + int32_t n_text_state = 384; + int32_t n_text_head = 6; + int32_t n_text_layer = 4; + int32_t n_mels = 80; + int32_t f16 = 1; +}; + +// audio encoding layer +struct whisper_layer_encoder { + // encoder.blocks.*.attn_ln + struct ggml_tensor * attn_ln_0_w; + struct ggml_tensor * attn_ln_0_b; + + // encoder.blocks.*.attn.out + struct ggml_tensor * attn_ln_1_w; + struct ggml_tensor * attn_ln_1_b; + + // encoder.blocks.*.attn.query + struct ggml_tensor * attn_q_w; + struct ggml_tensor * attn_q_b; + + // encoder.blocks.*.attn.key + struct 
ggml_tensor * attn_k_w; + + // encoder.blocks.*.attn.value + struct ggml_tensor * attn_v_w; + struct ggml_tensor * attn_v_b; + + // encoder.blocks.*.mlp_ln + struct ggml_tensor * mlp_ln_w; + struct ggml_tensor * mlp_ln_b; + + // encoder.blocks.*.mlp.0 + struct ggml_tensor * mlp_0_w; + struct ggml_tensor * mlp_0_b; + + // encoder.blocks.*.mlp.2 + struct ggml_tensor * mlp_1_w; + struct ggml_tensor * mlp_1_b; +}; + +// token decoding layer +struct whisper_layer_decoder { + // decoder.blocks.*.attn_ln + struct ggml_tensor * attn_ln_0_w; + struct ggml_tensor * attn_ln_0_b; + + // decoder.blocks.*.attn.out + struct ggml_tensor * attn_ln_1_w; + struct ggml_tensor * attn_ln_1_b; + + // decoder.blocks.*.attn.query + struct ggml_tensor * attn_q_w; + struct ggml_tensor * attn_q_b; + + // decoder.blocks.*.attn.key + struct ggml_tensor * attn_k_w; + + // decoder.blocks.*.attn.value + struct ggml_tensor * attn_v_w; + struct ggml_tensor * attn_v_b; + + // decoder.blocks.*.cross_attn_ln + struct ggml_tensor * cross_attn_ln_0_w; + struct ggml_tensor * cross_attn_ln_0_b; + + // decoder.blocks.*.cross_attn.out + struct ggml_tensor * cross_attn_ln_1_w; + struct ggml_tensor * cross_attn_ln_1_b; + + // decoder.blocks.*.cross_attn.query + struct ggml_tensor * cross_attn_q_w; + struct ggml_tensor * cross_attn_q_b; + + // decoder.blocks.*.cross_attn.key + struct ggml_tensor * cross_attn_k_w; + + // decoder.blocks.*.cross_attn.value + struct ggml_tensor * cross_attn_v_w; + struct ggml_tensor * cross_attn_v_b; + + // decoder.blocks.*.mlp_ln + struct ggml_tensor * mlp_ln_w; + struct ggml_tensor * mlp_ln_b; + + // decoder.blocks.*.mlp.0 + struct ggml_tensor * mlp_0_w; + struct ggml_tensor * mlp_0_b; + + // decoder.blocks.*.mlp.2 + struct ggml_tensor * mlp_1_w; + struct ggml_tensor * mlp_1_b; +}; + +struct whisper_model { + e_model type = MODEL_UNKNOWN; + + whisper_hparams hparams; + whisper_filters filters; + + // encoder.positional_embedding + struct ggml_tensor * e_pe; + + // encoder.conv1 + 
// Reads exactly sizeof(T) raw bytes from `fin` into `dest`.
//
// Returns true when all bytes were read, false when the stream failed or ended
// early. On failure `dest` may be unchanged or only partially overwritten, so
// callers that need a defined value should initialize `dest` beforehand and
// check the result (or the stream state). Existing callers that ignore the
// return value keep working unchanged.
template<typename T>
static bool read_safe(std::ifstream& fin, T& dest)
{
    fin.read(reinterpret_cast<char*>(&dest), sizeof(T));
    return !fin.fail();
}
pre-computed mel filters +// - vocab +// - weights +// +// see the convert-pt-to-ggml.py script for details +// +static bool whisper_model_load(const std::string & fname, whisper_context & wctx) { + logDebug( u8"%s: loading model from '%s'", __func__, fname.c_str() ); + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + logError( u8"%s: failed to open '%s'", __func__, fname.c_str() ); + return false; + } + + // verify magic + { + uint32_t magic; + read_safe(fin, magic); + if (magic != 0x67676d6c) { + logError( u8"%s: invalid model file '%s' (bad magic)", __func__, fname.c_str() ); + return false; + } + } + + //load hparams + { + auto & hparams = model.hparams; + + read_safe(fin, hparams.n_vocab); + read_safe(fin, hparams.n_audio_ctx); + read_safe(fin, hparams.n_audio_state); + read_safe(fin, hparams.n_audio_head); + read_safe(fin, hparams.n_audio_layer); + read_safe(fin, hparams.n_text_ctx); + read_safe(fin, hparams.n_text_state); + read_safe(fin, hparams.n_text_head); + read_safe(fin, hparams.n_text_layer); + read_safe(fin, hparams.n_mels); + read_safe(fin, hparams.f16); + + assert(hparams.n_text_state == hparams.n_audio_state); + + if (hparams.n_audio_layer == 4) { + model.type = e_model::MODEL_TINY; + } + + if (hparams.n_audio_layer == 6) { + model.type = e_model::MODEL_BASE; + } + + if (hparams.n_audio_layer == 12) { + model.type = e_model::MODEL_SMALL; + } + + if (hparams.n_audio_layer == 24) { + model.type = e_model::MODEL_MEDIUM; + } + + if (hparams.n_audio_layer == 32) { + model.type = e_model::MODEL_LARGE; + } + + logDebug( u8"%s: n_vocab = %d", __func__, hparams.n_vocab); + logDebug( u8"%s: n_audio_ctx = %d", __func__, hparams.n_audio_ctx); + logDebug( u8"%s: n_audio_state = %d", __func__, hparams.n_audio_state); + logDebug( u8"%s: n_audio_head = %d", __func__, hparams.n_audio_head); + logDebug( u8"%s: n_audio_layer = %d", __func__, hparams.n_audio_layer); + logDebug( u8"%s: 
n_text_ctx = %d", __func__, hparams.n_text_ctx); + logDebug( u8"%s: n_text_state = %d", __func__, hparams.n_text_state); + logDebug( u8"%s: n_text_head = %d", __func__, hparams.n_text_head); + logDebug( u8"%s: n_text_layer = %d", __func__, hparams.n_text_layer); + logDebug( u8"%s: n_mels = %d", __func__, hparams.n_mels); + logDebug( u8"%s: f16 = %d", __func__, hparams.f16); + logDebug( u8"%s: type = %d", __func__, model.type); + + wctx.buf_model = new std::vector<uint8_t>(); + wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type)); + wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type)); + wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); + wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); + } + + // load mel filters + { + auto & filters = wctx.model.filters; + + read_safe(fin, filters.n_mel); + read_safe(fin, filters.n_fft); + + filters.data.resize(filters.n_mel * filters.n_fft); + fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float)); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(fin, n_vocab); + + //if (n_vocab != model.hparams.n_vocab) { + // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); + // return false; + //} + + std::string word; + std::vector<char> tmp; + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + read_safe(fin, len); + + if (len > 0) { + tmp.resize(len); + fin.read(&tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } else { + // seems like we have an empty-string token in multi-language models (i = 50256) + //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + + //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); + } + + vocab.n_vocab = model.hparams.n_vocab; 
+ if (vocab.is_multilingual()) { + vocab.token_eot++; + vocab.token_sot++; + vocab.token_prev++; + vocab.token_solm++; + vocab.token_not++; + vocab.token_beg++; + } + + if (n_vocab < model.hparams.n_vocab) { + logDebug( u8"%s: adding %d extra tokens", __func__, model.hparams.n_vocab - n_vocab ); + for (int i = n_vocab; i < model.hparams.n_vocab; i++) { + if (i > vocab.token_beg) { + word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; + } else if (i == vocab.token_eot) { + word = "[_EOT_]"; + } else if (i == vocab.token_sot) { + word = "[_SOT_]"; + } else if (i == vocab.token_prev) { + word = "[_PREV_]"; + } else if (i == vocab.token_not) { + word = "[_NOT_]"; + } else if (i == vocab.token_beg) { + word = "[_BEG_]"; + } else { + word = "[_extra_token_" + std::to_string(i) + "]"; + } + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + } + + { + // this is the total memory required to run the inference + const size_t mem_required = + wctx.buf_model->size() + + wctx.buf_memory.size() + + wctx.buf_compute.size() + + wctx.buf_compute_layer.size(); + + logDebug( u8"%s: mem_required = %7.2f MB", __func__, mem_required / 1024.0 / 1024.0 ); + } + + // for the big tensors, we have the option to store the data in 16-bit floats + // in order to save memory and also to speed up the computation + const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + size_t ctx_size = 0; + + { + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + // encoder + { + // TODO: F16 .. maybe not? 
+ ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe; + + ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b + + ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b + + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w; + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b; + } + + // decoder + { + // TODO: F16 .. maybe not? + ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe; + + ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te; + + ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w; + ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b; + } + + // encoder layers + { + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b + + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w + ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b + + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b + + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w + 
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b + } + + // decoder layers + { + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b + + ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w + ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b + + ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b + + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b + // + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w + ctx_size += n_text_layer*( 
n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b + } + + ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead + + logDebug( u8"%s: ggml ctx size = %7.2f MB", __func__, ctx_size / ( 1024.0 * 1024.0 ) ); + } + + // create the ggml context + { + struct ggml_init_params params; + params.mem_size = wctx.buf_model->size(); + params.mem_buffer = wctx.buf_model->data(); + + model.ctx = ggml_init(params); + if (!model.ctx) { + logError( u8"%s: ggml_init() failed", __func__ ); + return false; + } + } + + // prepare memory for the weights + { + auto & ctx = model.ctx; + + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + model.layers_encoder.resize(n_audio_layer); + model.layers_decoder.resize(n_text_layer); + + // encoder + { + model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); + + model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state); + model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + + model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state); + model.e_conv_2_b = ggml_new_tensor_2d(ctx, 
GGML_TYPE_F32, 1, n_audio_state); + + model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.positional_embedding"] = model.e_pe; + + model.tensors["encoder.conv1.weight"] = model.e_conv_1_w; + model.tensors["encoder.conv1.bias"] = model.e_conv_1_b; + + model.tensors["encoder.conv2.weight"] = model.e_conv_2_w; + model.tensors["encoder.conv2.bias"] = model.e_conv_2_b; + + model.tensors["encoder.ln_post.weight"] = model.e_ln_w; + model.tensors["encoder.ln_post.bias"] = model.e_ln_b; + + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers_encoder[i]; + + layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state); + layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state); + + layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state); + layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + + layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.blocks." 
+ std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["encoder.blocks." 
+ std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + } + } + + // decoder + { + model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx); + + model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab); + + model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.positional_embedding"] = model.d_pe; + + model.tensors["decoder.token_embedding.weight"] = model.d_te; + + model.tensors["decoder.ln.weight"] = model.d_ln_w; + model.tensors["decoder.ln.bias"] = model.d_ln_b; + + for (int i = 0; i < n_text_layer; ++i) { + auto & layer = model.layers_decoder[i]; + + layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state); + layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state); + + layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state); + layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.cross_attn_ln_0_b = 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["decoder.blocks." 
+ std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w; + model.tensors["decoder.blocks." 
+ std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b; + } + } + } + + // create the ggml memory context + { + struct ggml_init_params params; + params.mem_size = wctx.buf_memory.size(); + params.mem_buffer = wctx.buf_memory.data(); + + model.ctx_mem = ggml_init(params); + if (!model.ctx_mem) { + logError( u8"%s: ggml_init() failed", __func__ ); + return false; + } + } + + // key + value memory + { + auto & ctx = model.ctx_mem; + + const auto & hparams = model.hparams; + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + const int n_text_ctx = hparams.n_text_ctx; + + // key/value memory for the self-attention layer + { + const int n_mem = n_text_layer*n_text_ctx; + const int n_elements = n_text_state*n_mem; + + model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + } + + // key/value memory for the cross-attention layer + { + const int n_audio_ctx = hparams.n_audio_ctx; + + const int n_mem = n_text_layer*n_audio_ctx; + const int n_elements = n_text_state*n_mem; + + model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + } + + const size_t memory_size = + ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) + + ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); + + logDebug( u8"%s: memory size = %7.2f MB", __func__, memory_size/1024.0/1024.0); + } + + // load weights + { + size_t total_size = 0; + + model.n_loaded = 0; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + read_safe(fin, n_dims); + read_safe(fin, length); + read_safe(fin, ftype); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[3] = { 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(fin, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector<char> 
tmp(length); // create a buffer + fin.read( &tmp[0], tmp.size() ); // read to buffer + name.assign(&tmp[0], tmp.size()); + + if (model.tensors.find(name) == model.tensors.end()) { + logError( u8"%s: unknown tensor '%s' in model file", __func__, name.data() ); + return false; + } + + auto tensor = model.tensors[name.data()]; + if (ggml_nelements(tensor) != nelements) { + logError( u8"%s: tensor '%s' has wrong size in model file", __func__, name.data()); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + logError( u8"%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]", + __func__, name.data(), tensor->ne[ 0 ], tensor->ne[ 1 ], tensor->ne[ 2 ], ne[ 0 ], ne[ 1 ], ne[ 2 ] ); + return false; + } + + const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t); + + if (nelements*bpe != ggml_nbytes(tensor)) { + logError( u8"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes( tensor ), nelements* bpe ); + return false; + } + + fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor)); + + //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? 
"float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); + total_size += ggml_nbytes(tensor); + model.n_loaded++; + } + + logDebug( u8"%s: model size = %7.2f MB", __func__, total_size / 1024.0 / 1024.0 ); + + if (model.n_loaded == 0) { + logWarning( u8"%s: WARN no tensors loaded from model file - assuming empty model for testing", __func__); + } else if (model.n_loaded != (int) model.tensors.size()) { + logError( u8"%s: ERROR not all tensors loaded from model file - expected %zu, got %d", __func__, model.tensors.size(), model.n_loaded ); + return false; + } + } + + fin.close(); + + return true; +} + +// evaluate the encoder +// +// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder +// part of the transformer model and returns the encoded features +// +// - model: the model +// - n_threads: number of threads to use +// - mel_offset: offset in the mel spectrogram (i.e. audio offset) +// +static bool whisper_encode( + whisper_context & wctx, + const int n_threads, + const int mel_offset) { + const auto & model = wctx.model; + const auto & mel_inp = wctx.mel; + const auto & hparams = model.hparams; + + const int n_ctx = wctx.exp_n_audio_ctx > 0 ? 
wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + const int n_layer = hparams.n_audio_layer; + + const int n_mels = hparams.n_mels; + assert(mel_inp.n_mel == n_mels); + + struct ggml_init_params params; + params.mem_size = wctx.buf_compute.size(); + params.mem_buffer = wctx.buf_compute.data(); + + struct ggml_context * ctx0 = ggml_init(params); + + struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); + assert(mel->type == GGML_TYPE_F32); + { + float * dst = (float *) mel->data; + memset(dst, 0, ggml_nbytes(mel)); + + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); + + for (int j = 0; j < mel_inp.n_mel; ++j) { + for (int i = i0; i < i1; ++i) { + dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; + } + } + } + Tracing::delayTensor( "enc.input", mel ); + + struct ggml_tensor * cur; + + // convolution + gelu + { + cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); + Tracing::delayTensor( "enc.conv1", cur ); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, + model.e_conv_1_b, + cur), + cur); + + cur = ggml_gelu(ctx0, cur); + Tracing::delayTensor( "enc.temp1", cur ); + + cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, + model.e_conv_2_b, + cur), + cur); + + cur = ggml_gelu(ctx0, cur); + } + + // =================================================================== + // NOTE: experimenting with partial evaluation of the encoder (ignore) + //static int iter = -1; + //const int n_iter = 1500/n_ctx; + + //iter = (iter + 1) % n_iter; + + //if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + //} + + static int iter = 0; + + const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); + const size_t 
e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; + + struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + + cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur)); + // =================================================================== + + // original: + //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_encoder[il]; + + // create separate context for each layer to reduce memory usage + + struct ggml_init_params paramsL; + paramsL.mem_size = wctx.buf_compute_layer.size(); + paramsL.mem_buffer = wctx.buf_compute_layer.data(); + + struct ggml_context * ctxL = ggml_init(paramsL); + + Tracing::delayTensor( { "enc.layer[ %i ].in", il }, inpL ); + + // norm + { + cur = ggml_norm(ctxL, inpL); + if( il == 0 ) + Tracing::delayTensor( "enc-norm", cur ); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctxL, + ggml_mul(ctxL, + ggml_repeat(ctxL, layer.attn_ln_0_w, cur), + cur), + ggml_repeat(ctxL, layer.attn_ln_0_b, cur)); + } + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, + layer.attn_q_w, + cur); + if( il == 0 ) + Tracing::delayTensor( "enc-Qcur", Qcur ); + + Qcur = ggml_add(ctxL, + ggml_repeat(ctxL, + layer.attn_q_b, + Qcur), + Qcur); + + //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + + // note: no bias for Key + struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, + layer.attn_k_w, + cur); + if( il == 0 ) + Tracing::delayTensor( "enc-Kcur", Kcur ); + + //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, + layer.attn_v_w, + cur); + if( il == 0 ) + Tracing::delayTensor( "enc-Vcur", Vcur ); + + Vcur = ggml_add(ctxL, + ggml_repeat(ctxL, + layer.attn_v_b, + Vcur), + Vcur); + + // ------ + +#ifdef USE_FLASH_ATTN + 
struct ggml_tensor * Q = + ggml_permute(ctxL, + ggml_cpy(ctxL, + Qcur, + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctxL, + ggml_cpy(ctxL, + Kcur, + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * V = + ggml_cpy(ctxL, + ggml_permute(ctxL, + ggml_reshape_3d(ctxL, + Vcur, + n_state/n_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); + if( il == 0 ) + Tracing::delayTensor( "enc-KQV", KQV ); +#else + struct ggml_tensor * Q = + ggml_permute(ctxL, + ggml_cpy(ctxL, + Qcur, + ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctxL, + ggml_cpy(ctxL, + Kcur, + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); + + struct ggml_tensor * KQ_scaled = + ggml_scale(ctxL, + KQ, + ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + ); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled); + + //struct ggml_tensor * V_trans = + // ggml_permute(ctxL, + // ggml_cpy(ctxL, + // Vcur, + // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), + // 1, 2, 0, 3); + + //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); + + struct ggml_tensor * V = + ggml_cpy(ctxL, + ggml_permute(ctxL, + ggml_reshape_3d(ctxL, + Vcur, + n_state/n_head, n_head, n_ctx), + 0, 2, 1, 3), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head) + ); + + struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max); +#endif + + struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctxL, + KQV_merged, + 
ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx)); + } + + // projection + { + cur = ggml_mul_mat(ctxL, + layer.attn_ln_1_w, + cur); + + cur = ggml_add(ctxL, + ggml_repeat(ctxL, layer.attn_ln_1_b, cur), + cur); + } + + // add the input + cur = ggml_add(ctxL, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctxL, inpFF); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctxL, + ggml_mul(ctxL, + ggml_repeat(ctxL, layer.mlp_ln_w, cur), + cur), + ggml_repeat(ctxL, layer.mlp_ln_b, cur)); + } + +#ifdef USE_FLASH_FF + cur = ggml_flash_ff(ctxL, + ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); +#else + // fully connected + cur = ggml_mul_mat(ctxL, + layer.mlp_0_w, + cur); + + cur = ggml_add(ctxL, + ggml_repeat(ctxL, layer.mlp_0_b, cur), + cur); + + // GELU activation + cur = ggml_gelu(ctxL, cur); + + // projection + cur = ggml_mul_mat(ctxL, + layer.mlp_1_w, + cur); + + cur = ggml_add(ctxL, + ggml_repeat(ctxL, layer.mlp_1_b, cur), + cur); +#endif + } + + // output from this layer + struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF); + + { + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + ggml_build_forward_expand(&gf, inpO); + ggml_graph_compute (ctxL, &gf); + Tracing::writeDelayedTensors(); + //ggml_graph_print(&gf); + } + + // TODO: this is a hack to have per-layer computation graphs - need to come up with something better + // input for next layer (inpO -> inpL) + memcpy(inpL->data, inpO->data, ggml_nbytes(inpL)); + inpL->op = GGML_OP_NONE; + inpL->src0 = nullptr; + inpL->src1 = nullptr; + + //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0); + + ggml_free(ctxL); + } + Tracing::tensor( "enc.layers", inpL ); + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur); + + // cur = ln_f_g*cur + ln_f_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + 
ggml_repeat(ctx0, model.e_ln_w, cur), + cur), + ggml_repeat(ctx0, model.e_ln_b, cur)); + } + + // run the computation + { + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + ggml_build_forward_expand(&gf, cur); + ggml_graph_compute (ctx0, &gf); + + //ggml_graph_print(&gf); + } + + Tracing::tensor( "encode-out", cur ); + + // cur + //{ + // printf("ne0 = %d\n", cur->ne[0]); + // printf("ne1 = %d\n", cur->ne[1]); + // for (int i = 0; i < 10; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + // } + // printf("... "); + // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + // } + // printf("\n"); + //} + + // pre-compute cross-attention memory + { + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + // TODO: hack to disconnect the encoded features from the previous graph + cur->op = GGML_OP_NONE; + cur->src0 = nullptr; + cur->src1 = nullptr; + + for (int il = 0; il < model.hparams.n_text_layer; ++il) { + auto & layer = model.layers_decoder[il]; + + struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); + + Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); + + Vcross = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.cross_attn_v_b, + Vcross), + Vcross); + + //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx)); + struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, 
(ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); + } + + ggml_graph_compute(ctx0, &gf); + } + + //////////////////////////////////////////////////////////////////////////// + + //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0); + + ggml_free(ctx0); + + return true; +} + +// evaluate the decoder +// +// given text prompt + audio features -> predicts the probabilities for the next token +// +// - model: the model +// - n_threads: number of threads to use +// - tokens: text prompt +// - n_tokens: number of tokens in the prompt +// - n_past: number of past tokens to prefix the prompt with +// +static bool whisper_decode( + whisper_context & wctx, + const int n_threads, + const whisper_token * tokens, + const int n_tokens, + const int n_past) { + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + auto & logits_out = wctx.logits; + auto & probs_out = wctx.probs; + + const int n_vocab = hparams.n_vocab; + + const int n_ctx = hparams.n_text_ctx; + const int n_state = hparams.n_text_state; + const int n_head = hparams.n_text_head; + const int n_layer = hparams.n_text_layer; + + const int N = n_tokens; + const int M = wctx.exp_n_audio_ctx > 0 ? 
wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + + struct ggml_init_params params; + params.mem_size = wctx.buf_compute.size(); + params.mem_buffer = wctx.buf_compute.data(); + + struct ggml_context * ctx0 = ggml_init(params); + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, tokens, N*ggml_element_size(embd)); + + struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + for (int i = 0; i < N; ++i) { + ((int32_t *) position->data)[i] = n_past + i; + } + + // token encoding + position encoding + struct ggml_tensor * cur = + ggml_add(ctx0, + ggml_get_rows(ctx0, model.d_te, embd), + ggml_get_rows(ctx0, model.d_pe, position)); + Tracing::delayTensor( "dec-rows", cur ); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_decoder[il]; + + struct ggml_init_params paramsL; + paramsL.mem_size = wctx.buf_compute_layer.size(); + paramsL.mem_buffer = wctx.buf_compute_layer.data(); + + struct ggml_context * ctxL = ggml_init(paramsL); + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + // norm + { + cur = ggml_norm(ctxL, inpL); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctxL, + ggml_mul(ctxL, + ggml_repeat(ctxL, layer.attn_ln_0_w, cur), + cur), + ggml_repeat(ctxL, layer.attn_ln_0_b, cur)); + } + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctxL, + ggml_repeat(ctxL, + layer.attn_q_b, + Qcur), + Qcur); + + Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + + // note: no bias for Key + struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, + layer.attn_k_w, + cur); + + Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctxL, + ggml_repeat(ctxL, + layer.attn_v_b, + Vcur), + Vcur); + + // store key and value 
to memory + { + struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v)); + } + + // ------ + + struct ggml_tensor * Q = + ggml_permute(ctxL, + ggml_cpy(ctxL, + Qcur, + ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctxL, + ggml_reshape_3d(ctxL, + ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state), + n_state/n_head, n_head, n_past + N), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); + + //struct ggml_tensor * KQ_scaled = + // ggml_scale(ctxL, + // KQ, + // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + // ); + + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked); + if( 0 == il ) Tracing::delayTensor( "dec-KQ", KQ_soft_max ); + + struct ggml_tensor * V_trans = + ggml_permute(ctxL, + ggml_reshape_3d(ctxL, + ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state), + n_state/n_head, n_head, n_past + N), + 1, 2, 0, 3); + + struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); + if( 0 == il ) Tracing::delayTensor( "dec-KQV", KQV ); + + struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctxL, + KQV_merged, + ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); + } + + { + cur = ggml_mul_mat(ctxL, + layer.attn_ln_1_w, + cur); + + cur = ggml_add(ctxL, + ggml_repeat(ctxL, layer.attn_ln_1_b, cur), + cur); + } + + // add the input + struct ggml_tensor * inpCA = 
ggml_add(ctxL, cur, inpL); + + // norm + { + cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctxL, + ggml_mul(ctxL, + ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur), + cur), + ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur)); + } + + // cross-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, + layer.cross_attn_q_w, + cur); + + Qcur = ggml_add(ctxL, + ggml_repeat(ctxL, + layer.cross_attn_q_b, + Qcur), + Qcur); + + Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + + // Kcross is already scaled + struct ggml_tensor * Kcross = + ggml_reshape_3d(ctxL, + ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state), + n_state/n_head, n_head, M); + + struct ggml_tensor * Vcross = + ggml_reshape_3d(ctxL, + ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state), + n_state/n_head, n_head, M); + + // ------ + + struct ggml_tensor * Q = + ggml_permute(ctxL, + ggml_cpy(ctxL, + Qcur, + ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); + + //struct ggml_tensor * KQ_scaled = + // ggml_scale(ctxL, + // KQ, + // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + // ); + + // no masking for cross-attention + //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ); + + struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3); + + struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); + if( 0 == il ) Tracing::delayTensor( "dec-KQV", KQV ); + + struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_state, N) + cur = ggml_cpy(ctxL, + 
KQV_merged, + ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); + } + + // projection + { + cur = ggml_mul_mat(ctxL, + layer.cross_attn_ln_1_w, + cur); + + cur = ggml_add(ctxL, + ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur), + cur); + } + + // add the input + cur = ggml_add(ctxL, cur, inpCA); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctxL, inpFF); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctxL, + ggml_mul(ctxL, + ggml_repeat(ctxL, layer.mlp_ln_w, cur), + cur), + ggml_repeat(ctxL, layer.mlp_ln_b, cur)); + } + + // fully connected + cur = ggml_mul_mat(ctxL, + layer.mlp_0_w, + cur); + + cur = ggml_add(ctxL, + ggml_repeat(ctxL, layer.mlp_0_b, cur), + cur); + + // GELU activation + cur = ggml_gelu(ctxL, cur); + + // projection + cur = ggml_mul_mat(ctxL, + layer.mlp_1_w, + cur); + + cur = ggml_add(ctxL, + ggml_repeat(ctxL, layer.mlp_1_b, cur), + cur); + } + + // output from this layer + struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF); + + { + ggml_build_forward_expand(&gf, inpO); + ggml_graph_compute (ctxL, &gf); + Tracing::writeDelayedTensors(); + //ggml_graph_print(&gf); + } + + // TODO: this is a hack to have per-layer computation graphs - need to come up with something better + // input for next layer (inpO -> inpL) + memcpy(inpL->data, inpO->data, ggml_nbytes(inpL)); + inpL->op = GGML_OP_NONE; + inpL->src0 = nullptr; + inpL->src1 = nullptr; + + if (N > 1) { + //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0); + } + + ggml_free(ctxL); + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.d_ln_w, cur), + cur), + ggml_repeat(ctx0, model.d_ln_b, cur)); + } + + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); + + // logits -> probs + cur = ggml_dup(ctx0, logits); + cur = ggml_soft_max(ctx0, cur); // in-place + + // run the computation + { + struct 
ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + ggml_build_forward_expand(&gf, cur); + ggml_graph_compute (ctx0, &gf); + } + + logits_out.resize(N*n_vocab); + memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); + + probs_out.resize(N*n_vocab); + memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab); + + if (N > 1) { + //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N; + //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token); + //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx); + } + + ggml_free(ctx0); + // Hash::vector( "probs", probs_out ); + Tracing::vector( "probs", probs_out ); + + return true; +} + +// the most basic sampling scheme - select the top token +static whisper_token_data whisper_sample_best( + const whisper_vocab & vocab, + const float * probs, + bool force_timestamp, + bool is_initial) { + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, + }; + + int n_logits = vocab.id_to_token.size(); + + std::vector<std::pair<double, whisper_vocab::id>> probs_id; + probs_id.reserve(n_logits); + + for (int i = 0; i < n_logits; i++) { + probs_id.emplace_back(probs[i], i); + } + + { + double sum_ts = 0.0; + double max_ts = -1.0; + double max_tx = -1.0; + + for (int i = 0; i < vocab.token_beg; i++) { + max_tx = std::max(max_tx, probs_id[i].first); + } + + const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg; + const auto i1 = is_initial ? 
vocab.token_beg + 101 : n_logits; + + // the initial timestamp cannot be larger than 100 + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial) { + for (int i = i0; i < n_logits; ++ i) { + probs_id[i].first = -INFINITY; + } + } + + for (int i = vocab.token_beg; i < i1; i++) { + sum_ts += probs_id[i].first; + if (probs_id[i].first > max_ts) { + max_ts = probs_id[i].first; + result.tid = probs_id[i].second; + } + } + + // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a + // timestamp token + if (sum_ts > max_tx || force_timestamp) { + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 + for (int i = 0; i < vocab.token_beg; i++) { + probs_id[i].first = -INFINITY; + } + } + + result.pt = max_ts/(sum_ts + 1e-10); + result.ptsum = sum_ts; + } + + // find the top K tokens + const int top_k = 4; + + std::partial_sort( + probs_id.begin(), + probs_id.begin() + top_k, probs_id.end(), + [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) { + return a.first > b.first; + }); + + probs_id.resize(top_k); + + //printf("\n"); + //for (int i = 0; i < (int) probs_id.size(); i++) { + // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); + //} + + int res = 0; + while ((probs_id[res].second == vocab.token_sot || + probs_id[res].second == vocab.token_solm || + probs_id[res].second == vocab.token_not) && + res < (int) probs_id.size() - 1) { + res++; + } + + result.id = probs_id[res].second; + result.p = probs_id[res].first; + + return result; +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +static std::string to_timestamp(int64_t t, bool comma = false) { + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + 
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
//
// `out` is resized to 2*in.size() floats; element k occupies
// out[2*k] (real part) and out[2*k + 1] (imaginary part).
// Cost is quadratic in the input length; an empty input yields an empty output.
static void dft(const std::vector<float> & in, std::vector<float> & out) {
    int N = in.size();

    out.resize(N*2);

    for (int k = 0; k < N; k++) {
        float re = 0;
        float im = 0;

        for (int n = 0; n < N; n++) {
            float angle = 2*M_PI*k*n/N;
            re += in[n]*cos(angle);
            im -= in[n]*sin(angle);
        }

        out[k*2 + 0] = re;
        out[k*2 + 1] = im;
    }
}

// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
//
// `out` is resized to 2*in.size() floats, interleaved as (re, im) pairs.
// Even lengths are split into even/odd-indexed halves that are transformed
// recursively; odd lengths (other than 1) fall back to the naive dft().
// An empty input produces an empty output.
static void fft(const std::vector<float> & in, std::vector<float> & out) {
    out.resize(in.size()*2);

    int N = in.size();

    if (N == 0) {
        // 0 is even, so without this guard the function would keep splitting
        // an empty vector into two empty halves and recurse without end
        return;
    }

    if (N == 1) {
        out[0] = in[0];
        out[1] = 0;
        return;
    }

    if (N%2 == 1) {
        dft(in, out);
        return;
    }

    // de-interleave the input into its even- and odd-indexed samples
    std::vector<float> even;
    std::vector<float> odd;
    even.reserve(N/2);
    odd.reserve(N/2);

    for (int i = 0; i < N; i++) {
        if (i % 2 == 0) {
            even.push_back(in[i]);
        } else {
            odd.push_back(in[i]);
        }
    }

    std::vector<float> even_fft;
    std::vector<float> odd_fft;

    fft(even, even_fft);
    fft(odd, odd_fft);

    // butterfly: combine the two half-size transforms using the twiddle
    // factor (re, im) = exp(-2*pi*i*k/N)
    for (int k = 0; k < N/2; k++) {
        float theta = 2*M_PI*k/N;

        float re = cos(theta);
        float im = -sin(theta);

        float re_odd = odd_fft[2*k + 0];
        float im_odd = odd_fft[2*k + 1];

        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;

        out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
        out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
    }
}
const int n_samples, + const int /*sample_rate*/, + const int fft_size, + const int fft_step, + const int n_mel, + const int n_threads, + const whisper_filters & filters, + const bool speed_up, + whisper_mel & mel) { + + // Hanning window + std::vector<float> hann; + hann.resize(fft_size); + for (int i = 0; i < fft_size; i++) { + hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size))); + } + + mel.n_mel = n_mel; + mel.n_len = (n_samples)/fft_step; + mel.data.resize(mel.n_mel*mel.n_len); + + const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2); + + //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); + //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); + + std::vector<std::thread> workers(n_threads); + for (int iw = 0; iw < n_threads; ++iw) { + workers[iw] = std::thread([&](int ith) { + std::vector<float> fft_in; + fft_in.resize(fft_size); + for (int i = 0; i < fft_size; i++) { + fft_in[i] = 0.0; + } + + std::vector<float> fft_out; + fft_out.resize(2*fft_size); + + for (int i = ith; i < mel.n_len; i += n_threads) { + const int offset = i*fft_step; + + // apply Hanning window + for (int j = 0; j < fft_size; j++) { + if (offset + j < n_samples) { + fft_in[j] = hann[j]*samples[offset + j]; + } else { + fft_in[j] = 0.0; + } + } + + // FFT -> mag^2 + fft(fft_in, fft_out); + + for (int j = 0; j < fft_size; j++) { + fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]); + } + for (int j = 1; j < fft_size/2; j++) { + //if (i == 0) { + // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]); + //} + fft_out[j] += fft_out[fft_size - j]; + } + if (i == 0) { + //for (int j = 0; j < fft_size; j++) { + // printf("%d: %e\n", j, fft_out[j]); + //} + } + + if (speed_up) { + // scale down in the frequency domain results in a speed up in the time domain + for (int j = 0; j < n_fft; j++) { + fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]); + } + } + + // mel spectrogram + for (int j = 0; 
j < mel.n_mel; j++) { + double sum = 0.0; + + for (int k = 0; k < n_fft; k++) { + sum += fft_out[k]*filters.data[j*n_fft + k]; + } + if (sum < 1e-10) { + sum = 1e-10; + } + + sum = log10(sum); + + mel.data[j*mel.n_len + i] = sum; + } + } + }, iw); + } + + for (int iw = 0; iw < n_threads; ++iw) { + workers[iw].join(); + } + + // clamping and normalization + double mmax = -1e20; + for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + if (mel.data[i] > mmax) { + mmax = mel.data[i]; + } + } + //printf("%s: max = %f\n", __func__, mmax); + + mmax -= 8.0; + + for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + if (mel.data[i] < mmax) { + mel.data[i] = mmax; + } + + mel.data[i] = (mel.data[i] + 4.0)/4.0; + } + + return true; +} + +// split text into tokens +// +// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 +// +// Regex (Python): +// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" +// +// Regex (C++): +// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)" +// +static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) { + std::vector<std::string> words; + + // first split the text into words + { + std::string str = text; + std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; + + std::regex re(pat); + std::smatch m; + + while (std::regex_search(str, m, re)) { + for (auto x : m) { + words.push_back(x); + } + str = m.suffix(); + } + } + + // find the longest tokens that form the words: + std::vector<whisper_vocab::id> tokens; + for (const auto & word : words) { + if (word.empty()) continue; + + int i = 0; + int n = word.size(); + while (i < n) { + int j = n; + while (j > i) { + auto it = vocab.token_to_id.find(word.substr(i, j-i)); + if (it != vocab.token_to_id.end()) { + tokens.push_back(it->second); + i = j; + break; + } + --j; + } + 
if (i == n) { + break; + } + if (j == i) { + auto sub = word.substr(i, 1); + if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) { + tokens.push_back(vocab.token_to_id.at(sub)); + } else { + logWarning( u8"%s: unknown token '%s'", __func__, sub.data() ); + } + ++i; + } + } + } + + return tokens; +} + +// +// interface implementation +// + +struct whisper_context * whisper_init(const char * path_model) { + ggml_time_init(); + + whisper_context * ctx = new whisper_context; + + const int64_t t_start_us = ggml_time_us(); + + ctx->t_start_us = t_start_us; + + if (!whisper_model_load(path_model, *ctx)) { + logError( u8"%s: failed to load model from '%s'", __func__, path_model ); + delete ctx; + return nullptr; + } + + ctx->t_load_us = ggml_time_us() - t_start_us; + + return ctx; +} + +void whisper_free(struct whisper_context * ctx) { + if (ctx) { + if (ctx->model.ctx) { + ggml_free(ctx->model.ctx); + } + if (ctx->model.ctx_mem) { + ggml_free(ctx->model.ctx_mem); + } + if (ctx->buf_model) { + delete ctx->buf_model; + } + delete ctx; + } +} + +int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { + logError( u8"%s: failed to compute mel spectrogram", __func__ ); + return -1; + } + + ctx->t_mel_us = ggml_time_us() - t_start_us; + + return 0; +} + +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 +int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { + logError( u8"%s: failed to compute 
mel spectrogram", __func__ ); + return -1; + } + + ctx->t_mel_us = ggml_time_us() - t_start_us; + + return 0; +} + +int whisper_set_mel( + struct whisper_context * ctx, + const float * data, + int n_len, + int n_mel) { + if (n_mel != WHISPER_N_MEL) { + logError( u8"%s: invalid number of mel bands: %d (expected %d)", __func__, n_mel, WHISPER_N_MEL ); + return -1; + } + + ctx->mel.n_len = n_len; + ctx->mel.n_mel = n_mel; + + ctx->mel.data.resize(n_len*n_mel); + memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float)); + + return 0; +} + +int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + if (!whisper_encode(*ctx, n_threads, offset)) { + logError( u8"%s: failed to eval", __func__ ); + return -1; + } + + ctx->t_encode_us += ggml_time_us() - t_start_us; + + return 0; +} + +int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) { + logError( u8"%s: failed to eval", __func__ ); + return 1; + } + + ctx->t_decode_us += ggml_time_us() - t_start_us; + + return 0; +} + +struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { + const int64_t t_start_sample_us = ggml_time_us(); + + const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + + return res; +} + +struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) { + const int64_t t_start_sample_us = ggml_time_us(); + + const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + + return res; +} + +int whisper_tokenize(struct whisper_context * ctx, 
const char * text, whisper_token * tokens, int n_max_tokens) { + const auto res = tokenize(ctx->vocab, text); + + if (n_max_tokens < (int) res.size()) { + logError( u8"%s: too many resulting tokens: %d (max %d)", __func__, (int)res.size(), n_max_tokens ); + return -1; + } + + for (int i = 0; i < (int) res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int whisper_lang_max_id() { + auto max_id = 0; + for (const auto & kv : g_lang) { + max_id = std::max(max_id, kv.second.first); + } + + return max_id; +} + +int whisper_lang_id(const char * lang) { + if (!g_lang.count(lang)) { + for (const auto & kv : g_lang) { + if (kv.second.second == lang) { + return kv.second.first; + } + } + + logError( u8"%s: unknown language '%s'", __func__, lang ); + return -1; + } + + return g_lang.at(lang).first; +} + +const char * whisper_lang_str(int id) { + for (const auto & kv : g_lang) { + if (kv.second.first == id) { + return kv.first.c_str(); + } + } + + logError( u8"%s: unknown language id %d", __func__, id ); + return nullptr; +} + +int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs) { + const int seek = offset_ms/10; + + if (seek < 0) { + logError( u8"%s: offset %dms is before the start of the audio", __func__, offset_ms ); + return -1; + } + + if (seek >= ctx->mel.n_len) { + logError( u8"%s: offset %dms is past the end of the audio (%dms)", __func__, offset_ms, ctx->mel.n_len * 10 ); + return -2; + } + + // run the encoder + if (whisper_encode(ctx, seek, n_threads) != 0) { + logError( u8"%s: failed to encode", __func__ ); + return -6; + } + + const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) }; + + if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) { + logError( u8"%s: failed to decode", __func__ ); + return -7; + } + + std::vector<std::pair<float, int>> probs_id; + for (const auto & kv : g_lang) { + const auto token_lang = whisper_token_lang(ctx, 
kv.second.first); + probs_id.emplace_back( ctx->probs[token_lang], kv.second.first ); + } + + // sort descending + { + using pair_type = decltype(probs_id)::value_type; + std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } + + // softmax + { + float sum = 0; + for (const auto & kv : probs_id) { + sum += exp(kv.first); + } + + for (auto & kv : probs_id) { + kv.first = exp(kv.first) / sum; + } + } + + { + for (int i = 0; i < (int) probs_id.size(); i++) { + if (lang_probs) { + lang_probs[probs_id[i].second] = probs_id[i].first; + } + + //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first); + } + } + + return probs_id[0].second; +} + +int whisper_n_len(struct whisper_context * ctx) { + return ctx->mel.n_len; +} + +int whisper_n_vocab(struct whisper_context * ctx) { + return ctx->vocab.n_vocab; +} + +int whisper_n_text_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_ctx; +} + +int whisper_is_multilingual(struct whisper_context * ctx) { + return ctx->vocab.is_multilingual() ? 
1 : 0; +} + +float * whisper_get_probs(struct whisper_context * ctx) { + return ctx->probs.data(); +} + +const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +whisper_token whisper_token_eot(struct whisper_context * ctx) { + return ctx->vocab.token_eot; +} + +whisper_token whisper_token_sot(struct whisper_context * ctx) { + return ctx->vocab.token_sot; +} + +whisper_token whisper_token_prev(struct whisper_context * ctx) { + return ctx->vocab.token_prev; +} + +whisper_token whisper_token_solm(struct whisper_context * ctx) { + return ctx->vocab.token_solm; +} + +whisper_token whisper_token_not(struct whisper_context * ctx) { + return ctx->vocab.token_not; +} + +whisper_token whisper_token_beg(struct whisper_context * ctx) { + return ctx->vocab.token_beg; +} + +whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) { + return whisper_token_sot(ctx) + 1 + lang_id; +} + +whisper_token whisper_token_translate(void) { + return whisper_vocab::token_translate; +} + +whisper_token whisper_token_transcribe(void) { + return whisper_vocab::token_transcribe; +} + +void whisper_print_timings(struct whisper_context * ctx) { + const int64_t t_end_us = ggml_time_us(); + + logInfo( u8"%s: load time = %8.2f ms", __func__, ctx->t_load_us / 1000.0f ); + logInfo( u8"%s: mel time = %8.2f ms", __func__, ctx->t_mel_us / 1000.0f ); + logInfo( u8"%s: sample time = %8.2f ms", __func__, ctx->t_sample_us / 1000.0f ); + logInfo( u8"%s: encode time = %8.2f ms / %.2f ms per layer", __func__, + ctx->t_encode_us / 1000.0f, ctx->t_encode_us / 1000.0f / ctx->model.hparams.n_audio_layer ); + logInfo( u8"%s: decode time = %8.2f ms / %.2f ms per layer", __func__, + ctx->t_decode_us / 1000.0f, ctx->t_decode_us / 1000.0f / ctx->model.hparams.n_text_layer ); + logInfo( u8"%s: total time = %8.2f ms", __func__, ( t_end_us - ctx->t_start_us ) / 1000.0f ); +} + +void whisper_reset_timings(struct 
whisper_context * ctx) { + ctx->t_sample_us = 0; + ctx->t_encode_us = 0; + ctx->t_decode_us = 0; +} + +const char * whisper_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + + return s.c_str(); +} + +//////////////////////////////////////////////////////////////////////////// + +struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { + struct whisper_full_params result; + + switch (strategy) { + case WHISPER_SAMPLING_GREEDY: + { + result = { + /*.strategy =*/ WHISPER_SAMPLING_GREEDY, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.max_tokens =*/ 0, + + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, + + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + + /*.greedy =*/ { + /*.n_past =*/ 0, + }, + + /*.beam_search =*/ { + /*.n_past =*/ -1, + /*.beam_width =*/ -1, + /*.n_best =*/ -1, + }, + + 
/*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + }; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + result = { + /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.max_tokens =*/ 0, + + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, + + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + + /*.greedy =*/ { + /*.n_past =*/ -1, + }, + + /*.beam_search =*/ { + /*.n_past =*/ 0, + /*.beam_width =*/ 10, + /*.n_best =*/ 5, + }, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + }; + } break; + } + + return result; +} + +// forward declarations +static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context * ctx, + int i_segment, + float thold_pt, + float thold_ptsum); + +// wrap the last segment to max_len characters +// returns the number of new segments +static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { + auto segment = ctx->result_all.back(); + + int res = 1; + int acc = 0; + + std::string text; + + for (int i = 0; i < (int) segment.tokens.size(); i++) { + const auto & token = segment.tokens[i]; + if (token.id >= whisper_token_eot(ctx)) { + 
continue; + } + + const auto txt = whisper_token_to_str(ctx, token.id); + + const int cur = strlen(txt); + + if (acc + cur > max_len && i > 0) { + // split here + ctx->result_all.back().text = std::move(text); + ctx->result_all.back().t1 = token.t0; + ctx->result_all.back().tokens.resize(i); + + ctx->result_all.push_back({}); + ctx->result_all.back().t0 = token.t0; + ctx->result_all.back().t1 = segment.t1; + + // add tokens [i, end] to the new segment + ctx->result_all.back().tokens.insert( + ctx->result_all.back().tokens.end(), + segment.tokens.begin() + i, + segment.tokens.end()); + + acc = 0; + text = ""; + + segment = ctx->result_all.back(); + i = -1; + + res++; + } else { + acc += cur; + text += txt; + } + } + + ctx->result_all.back().text = std::move(text); + + return res; +} + +int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples) { + // clear old results + auto & result_all = ctx->result_all; + + result_all.clear(); + + // compute log mel spectrogram + if (params.speed_up) { + if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) { + logError( u8"%s: failed to compute log mel spectrogram", __func__ ); + return -1; + } + } else { + if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { + logError( u8"%s: failed to compute log mel spectrogram", __func__ ); + return -2; + } + } + + // auto-detect language if not specified + if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) { + std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f); + + const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data()); + if (lang_id < 0) { + logError( u8"%s: failed to auto-detect language", __func__ ); + return -3; + } + + params.language = whisper_lang_str(lang_id); + + logInfo( u8"%s: auto-detected language: %s (p = %f)", __func__, params.language, probs[ whisper_lang_id( 
params.language ) ] ); + } + + if (params.token_timestamps) { + ctx->t_beg = 0; + ctx->t_last = 0; + ctx->tid_last = 0; + ctx->energy = get_signal_energy(samples, n_samples, 32); + } + + const int seek_start = params.offset_ms/10; + const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10); + + // if length of spectrogram is less than 1s (100 samples), then return + // basically don't process anything that is less than 1s + // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 + if (seek_end < 100 + seek_start) { + return 0; + } + + // the accumulated text context so far + auto & prompt_past = ctx->prompt_past; + if (params.no_context) { + prompt_past.clear(); + } + + // prepend the prompt tokens to the prompt_past + if (params.prompt_tokens && params.prompt_n_tokens > 0) { + // parse tokens from the pointer + for (int i = 0; i < params.prompt_n_tokens; i++) { + prompt_past.push_back(params.prompt_tokens[i]); + } + std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); + } + + // overwrite audio_ctx + ctx->exp_n_audio_ctx = params.audio_ctx; + + // these tokens determine the task that will be performed + std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) }; + if (whisper_is_multilingual(ctx)) { + const int lang_id = whisper_lang_id(params.language); + prompt_init.push_back(whisper_token_lang(ctx, lang_id)); + if (params.translate) { + prompt_init.push_back(whisper_token_translate()); + } else { + prompt_init.push_back(whisper_token_transcribe()); + } + } + + int progress_prev = 0; + int progress_step = 5; + + std::vector<whisper_token_data> tokens_cur; + tokens_cur.reserve(whisper_n_text_ctx(ctx)); + + std::vector<whisper_token> prompt; + prompt.reserve(whisper_n_text_ctx(ctx)); + + // main loop + int seek = seek_start; + while (true) { + const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); + while (progress_cur >= 
progress_prev + progress_step) { + progress_prev += progress_step; + if (params.print_progress) { + logInfo( u8"%s: progress = %3d%%", __func__, progress_prev ); + } + } + + // of only 1 second left, then stop + if (seek + 100 >= seek_end) { + break; + } + + // if there is a very short audio segment left to process, we remove any past prompt since it tends + // to confuse the decoder and often make it repeat or hallucinate stuff + if (seek > seek_start && seek + 500 >= seek_end) { + prompt_past.clear(); + } + + if (params.encoder_begin_callback) { + if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { + logDebug( u8"%s: encoder_begin_callback returned false - aborting", __func__ ); + break; + } + } + + // encode audio features starting at offset seek + if (whisper_encode(ctx, seek, params.n_threads) != 0) { + logError( u8"%s: failed to encode", __func__ ); + return -4; + } + + int n_past = 0; + prompt.clear(); + + // if we have already generated some text, use it as a prompt to condition the next generation + if (!prompt_past.empty()) { + int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + + prompt = { whisper_token_prev(ctx) }; + prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); + + prompt_past.clear(); + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); + } + + prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + + int seek_delta = 100*WHISPER_CHUNK_SIZE; + + // print the prompt + //printf("\n\n"); + //for (int i = 0; i < prompt.size(); i++) { + // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); + //} + //printf("\n\n"); + + // the accumulated transcription in the current interation + int result_len = 0; + tokens_cur.clear(); + + bool failed = false; + bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? 
+ + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { + logError( u8"%s: failed to decode", __func__ ); + return -5; + } + + n_past += prompt.size(); + prompt.clear(); + + // very basic greedy sampling strategy: + // + // - always take the most probable token + // + // more sophisticated sampling strategies could be implemented here, but we keep it simple + // feel free to experiment! + // + { + const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx); + + // timestamp token - update sliding window + if (token.id > whisper_token_beg(ctx)) { + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (has_ts && seek_delta > seek_delta_new && result_len < i) { + break; + } + + seek_delta = seek_delta_new; + result_len = i + 1; + has_ts = true; + } + + // add it to the context + prompt.push_back(token.id); + tokens_cur.push_back(token); + + //{ + // const auto tt = token.pt > 0.10 ? 
ctx->vocab.id_to_token[token.tid] : "[?]"; + // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); + //} + + // end of segment + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + ) { + if (result_len == 0) { + if (seek + seek_delta + 100 >= seek_end) { + result_len = i + 1; + } else { + failed = true; + break; + } + } + + if (params.single_segment) { + result_len = i + 1; + seek_delta = 100*WHISPER_CHUNK_SIZE; + } + + break; + } + + // TESTS: if no tensors are loaded, it means we are running tests + if (ctx->model.n_loaded == 0) { + seek_delta = 100*WHISPER_CHUNK_SIZE; + break; + } + } + + // sometimes, the decoding can get stuck in a repetition loop + // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance + // the sliding window by 1 second + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + failed = true; + break; + } + } + + if (failed) { + // when we fail to sample timestamp token, retry by clearing the past prompt + // if it fails again, then we advance the window by 1 second + if (!prompt_past.empty()) { + prompt_past.clear(); + } else { + logWarning( u8"%s: failed to generate timestamp token - skipping one second", __func__ ); + seek += 100; + } + continue; + } + + // shrink down to result_len + tokens_cur.resize(result_len); + + for (const auto & r : tokens_cur) { + prompt_past.push_back(r.id); + } + + // store the text from this iteration + if (!tokens_cur.empty()) { + int i0 = 0; + auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + + std::string text; + + for (int i = 0; i < (int) tokens_cur.size(); i++) { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // 
ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + + if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { + } else { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { + const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + if (!text.empty()) { + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j <= i; j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(ctx, params.max_len); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + } + } + text = ""; + while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + i++; + } + i--; + t0 = t1; + i0 = i + 1; + } + } + + if (!text.empty()) { + const auto t1 = seek + seek_delta; + + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 
2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j < (int) tokens_cur.size(); j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(ctx, params.max_len); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + } + } + } + + seek += seek_delta; + } + + return 0; +} + +int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors) { + if (n_processors == 1) { + return whisper_full(ctx, params, samples, n_samples); + } + + int ret = 0; + + // prepare separate contexts for each thread + std::vector<struct whisper_context> ctxs(n_processors - 1); + + for (int i = 0; i < n_processors - 1; ++i) { + ctxs[i] = *ctx; + + auto & model = ctxs[i].model; + + // create the ggml memory context + { + struct ggml_init_params params; + params.mem_size = ctxs[i].buf_memory.size(); + params.mem_buffer = ctxs[i].buf_memory.data(); + + model.ctx_mem = ggml_init(params); + if (!model.ctx_mem) { + logError( u8"%s: ggml_init() failed", __func__ ); + return false; + } + } + + // separate key + value memory for each processor + { + auto & ctx = model.ctx_mem; + + const auto & hparams = model.hparams; + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + const int n_text_ctx = hparams.n_text_ctx; + + // key/value memory for the self-attention layer + { + const int n_mem = 
n_text_layer*n_text_ctx; + const int n_elements = n_text_state*n_mem; + + model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + } + + // key/value memory for the cross-attention layer + { + const int n_audio_ctx = hparams.n_audio_ctx; + + const int n_mem = n_text_layer*n_audio_ctx; + const int n_elements = n_text_state*n_mem; + + model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + } + } + } + + const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; + const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; + + // the calling thread will process the first chunk + // while the other threads will process the remaining chunks + + std::vector<std::thread> workers(n_processors - 1); + for (int i = 0; i < n_processors - 1; ++i) { + const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; + const int n_samples_cur = (i == n_processors - 2) ? 
n_samples - start_samples : n_samples_per_processor; + + auto params_cur = params; + + params_cur.offset_ms = 0; + params_cur.print_progress = false; + params_cur.print_realtime = false; + + params_cur.new_segment_callback = nullptr; + params_cur.new_segment_callback_user_data = nullptr; + + workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur); + } + + { + auto params_cur = params; + + ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor); + } + + for (int i = 0; i < n_processors - 1; ++i) { + workers[i].join(); + } + + const int64_t offset_t = (int64_t) params.offset_ms/10.0; + + // combine results into ctx->result_all + for (int i = 0; i < n_processors - 1; ++i) { + auto & results_i = ctxs[i].result_all; + + for (int j = 0; j < (int) results_i.size(); ++j) { + // correct the segment timestamp taking into account the offset + results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; + results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; + + // make sure that segments are not overlapping + if (!ctx->result_all.empty()) { + results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1); + } + + ctx->result_all.push_back(std::move(results_i[j])); + + // call the new_segment_callback for each segment + if (params.new_segment_callback) { + params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data); + } + } + + ctx->t_mel_us += ctxs[i].t_mel_us; + ctx->t_sample_us += ctxs[i].t_sample_us; + ctx->t_encode_us += ctxs[i].t_encode_us; + ctx->t_decode_us += ctxs[i].t_decode_us; + } + + // average the timings + ctx->t_mel_us /= n_processors; + ctx->t_sample_us /= n_processors; + ctx->t_encode_us /= n_processors; + ctx->t_decode_us /= n_processors; + + // print information about the audio boundaries + logDebug( u8"%s: the audio has been split into %d chunks at the following times:", 
__func__, n_processors ); + for( int i = 0; i < n_processors - 1; ++i ) + logDebug( u8"%s: split %d - %s", __func__, ( i + 1 ), to_timestamp( 100 * ( ( i + 1 ) * n_samples_per_processor ) / WHISPER_SAMPLE_RATE + offset_t ).c_str() ); + logDebug( u8"%s: the transcription quality may be degraded near these boundaries", __func__ ); + + return ret; +} + +int whisper_full_n_segments(struct whisper_context * ctx) { + return ctx->result_all.size(); +} + +int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { + return ctx->result_all[i_segment].t0; +} + +int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) { + return ctx->result_all[i_segment].t1; +} + +const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { + return ctx->result_all[i_segment].text.c_str(); +} + +int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { + return ctx->result_all[i_segment].tokens.size(); +} + +const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str(); +} + +whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token].id; +} + +struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token]; +} + +float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->result_all[i_segment].tokens[i_token].p; +} + +// ================================================================================================= + +// +// Experimental stuff below +// +// Not sure if these should be part of the library at all, because the quality of the results is not +// guaranteed. 
Might get removed at some point unless a robust algorithm implementation is found +// + +// ================================================================================================= + +// +// token-level timestamps +// + +static int timestamp_to_sample(int64_t t, int n_samples) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); +} + +static int64_t sample_to_timestamp(int i_sample) { + return (100*i_sample)/WHISPER_SAMPLE_RATE; +} + +// a cost-function / heuristic that is high for text that takes longer to pronounce +// obviously, can be improved +static float voice_length(const std::string & text) { + float res = 0.0f; + + for (size_t i = 0; i < text.size(); ++i) { + if (text[i] == ' ') { + res += 0.01f; + } else if (text[i] == ',') { + res += 2.00f; + } else if (text[i] == '.') { + res += 3.00f; + } else if (text[i] == '!') { + res += 3.00f; + } else if (text[i] == '?') { + res += 3.00f; + } else if (text[i] >= '0' && text[i] <= '9') { + res += 3.00f; + } else { + res += 1.00f; + } + } + + return res; +} + +// average the fabs of the signal +static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) { + const int hw = n_samples_per_half_window; + + std::vector<float> result(n_samples); + + for (int i = 0; i < n_samples; i++) { + float sum = 0; + for (int j = -hw; j <= hw; j++) { + if (i + j >= 0 && i + j < n_samples) { + sum += fabs(signal[i + j]); + } + } + result[i] = sum/(2*hw + 1); + } + + return result; +} + +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context * ctx, + int i_segment, + float thold_pt, + float thold_ptsum) { + auto & segment = ctx->result_all[i_segment]; + auto & tokens = segment.tokens; + + const int n_samples = ctx->energy.size(); + + if (n_samples == 0) { + logWarning( u8"%s: no signal data available", __func__ ); + return; + } + + const int64_t t0 = segment.t0; + const int64_t t1 = segment.t1; + + const int n 
= tokens.size(); + + if (n == 0) { + return; + } + + if (n == 1) { + tokens[0].t0 = t0; + tokens[0].t1 = t1; + + return; + } + + auto & t_beg = ctx->t_beg; + auto & t_last = ctx->t_last; + auto & tid_last = ctx->tid_last; + + for (int j = 0; j < n; ++j) { + auto & token = tokens[j]; + + if (j == 0) { + if (token.id == whisper_token_beg(ctx)) { + tokens[j ].t0 = t0; + tokens[j ].t1 = t0; + tokens[j + 1].t0 = t0; + + t_beg = t0; + t_last = t0; + tid_last = whisper_token_beg(ctx); + } else { + tokens[j ].t0 = t_last; + } + } + + const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); + + tokens[j].id = token.id; + tokens[j].tid = token.tid; + tokens[j].p = token.p; + tokens[j].pt = token.pt; + tokens[j].ptsum = token.ptsum; + + tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id)); + + if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { + if (j > 0) { + tokens[j - 1].t1 = tt; + } + tokens[j].t0 = tt; + tid_last = token.tid; + } + } + + tokens[n - 2].t1 = t1; + tokens[n - 1].t0 = t1; + tokens[n - 1].t1 = t1; + + t_last = t1; + + // find intervals of tokens with unknown timestamps + // fill the timestamps by proportionally splitting the interval based on the token voice lengths + { + int p0 = 0; + int p1 = 0; + + while (true) { + while (p1 < n && tokens[p1].t1 < 0) { + p1++; + } + + if (p1 >= n) { + p1--; + } + + if (p1 > p0) { + double psum = 0.0; + for (int j = p0; j <= p1; j++) { + psum += tokens[j].vlen; + } + + //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + + const double dt = tokens[p1].t1 - tokens[p0].t0; + + // split the time proportionally to the voice length + for (int j = p0 + 1; j <= p1; j++) { + const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + + tokens[j - 1].t1 = ct; + tokens[j ].t0 = ct; + } + } + + p1++; + p0 = p1; + if (p1 >= n) { + break; + } + } + } + + // fix up (just in case) + for (int j = 0; j < n - 1; j++) { + if (tokens[j].t1 < 0) { + tokens[j + 1].t0 
= tokens[j].t1; + } + + if (j > 0) { + if (tokens[j - 1].t1 > tokens[j].t0) { + tokens[j].t0 = tokens[j - 1].t1; + tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); + } + } + } + + // VAD + // expand or contract tokens based on voice activity + { + const int hw = WHISPER_SAMPLE_RATE/8; + + for (int j = 0; j < n; j++) { + if (tokens[j].id >= whisper_token_eot(ctx)) { + continue; + } + + int s0 = timestamp_to_sample(tokens[j].t0, n_samples); + int s1 = timestamp_to_sample(tokens[j].t1, n_samples); + + const int ss0 = std::max(s0 - hw, 0); + const int ss1 = std::min(s1 + hw, n_samples); + + const int ns = ss1 - ss0; + + float sum = 0.0f; + + for (int k = ss0; k < ss1; k++) { + sum += ctx->energy[k]; + } + + const float thold = 0.5*sum/ns; + + { + int k = s0; + if (ctx->energy[k] > thold && j > 0) { + while (k > 0 && ctx->energy[k] > thold) { + k--; + } + tokens[j].t0 = sample_to_timestamp(k); + if (tokens[j].t0 < tokens[j - 1].t1) { + tokens[j].t0 = tokens[j - 1].t1; + } else { + s0 = k; + } + } else { + while (ctx->energy[k] < thold && k < s1) { + k++; + } + s0 = k; + tokens[j].t0 = sample_to_timestamp(k); + } + } + + { + int k = s1; + if (ctx->energy[k] > thold) { + while (k < n_samples - 1 && ctx->energy[k] > thold) { + k++; + } + tokens[j].t1 = sample_to_timestamp(k); + if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) { + tokens[j].t1 = tokens[j + 1].t0; + } else { + s1 = k; + } + } else { + while (ctx->energy[k] < thold && k > s0) { + k--; + } + s1 = k; + tokens[j].t1 = sample_to_timestamp(k); + } + } + } + } + + // fixed token expand (optional) + //{ + // const int t_expand = 0; + + // for (int j = 0; j < n; j++) { + // if (j > 0) { + // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + // } + // if (j < n - 1) { + // tokens[j].t1 = tokens[j].t1 + t_expand; + // } + // } + //} + + // debug info + //for (int j = 0; j < n; ++j) { + // const auto & token = tokens[j]; + // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? 
whisper_token_to_str(ctx, token.tid) : "[?]"; + // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, + // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id)); + + // if (tokens[j].id >= whisper_token_eot(ctx)) { + // continue; + // } + //} +} diff --git a/Whisper/source/whisper.h b/Whisper/source/whisper.h new file mode 100644 index 0000000..92c14da --- /dev/null +++ b/Whisper/source/whisper.h @@ -0,0 +1,330 @@ +#ifndef WHISPER_H +#define WHISPER_H + +#include <stdint.h> +#include <stdbool.h> + +#ifdef WHISPER_SHARED +# ifdef _WIN32 +# ifdef WHISPER_BUILD +# define WHISPER_API __declspec(dllexport) +# else +# define WHISPER_API __declspec(dllimport) +# endif +# else +# define WHISPER_API __attribute__ ((visibility ("default"))) +# endif +#else +# define WHISPER_API +#endif + +#define WHISPER_SAMPLE_RATE 16000 +#define WHISPER_N_FFT 400 +#define WHISPER_N_MEL 80 +#define WHISPER_HOP_LENGTH 160 +#define WHISPER_CHUNK_SIZE 30 + +#ifdef __cplusplus +extern "C" { +#endif + + // + // C interface + // + // The following interface is thread-safe as long as the same whisper_context is not used by multiple threads + // concurrently. + // + // Basic usage: + // + // #include "whisper.h" + // + // ... + // + // struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin"); + // + // if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + // fprintf(stderr, "failed to process audio\n"); + // return 7; + // } + // + // const int n_segments = whisper_full_n_segments(ctx); + // for (int i = 0; i < n_segments; ++i) { + // const char * text = whisper_full_get_segment_text(ctx, i); + // printf("%s", text); + // } + // + // whisper_free(ctx); + // + // ... + // + // This is a demonstration of the most straightforward usage of the library. + // "pcmf32" contains the RAW audio data in 32-bit floating point format. 
+ // + // The interface also allows for more fine-grained control over the computation, but it requires a deeper + // understanding of how the model works. + // + + struct whisper_context; + + typedef int whisper_token; + + typedef struct whisper_token_data { + whisper_token id; // token id + whisper_token tid; // forced timestamp token id + + float p; // probability of the token + float pt; // probability of the timestamp token + float ptsum; // sum of probabilities of all timestamp tokens + + // token-level timestamp data + // do not use if you haven't computed token-level timestamps + int64_t t0; // start time of the token + int64_t t1; // end time of the token + + float vlen; // voice length of the token + } whisper_token_data; + + // Allocates all memory needed for the model and loads the model from the given file. + // Returns NULL on failure. + WHISPER_API struct whisper_context * whisper_init(const char * path_model); + + // Frees all memory allocated by the model. + WHISPER_API void whisper_free(struct whisper_context * ctx); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the provided whisper context. + // Returns 0 on success + WHISPER_API int whisper_pcm_to_mel( + struct whisper_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the provided whisper context. + // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. + // n_mel must be 80 + // Returns 0 on success + WHISPER_API int whisper_set_mel( + struct whisper_context * ctx, + const float * data, + int n_len, + int n_mel); + + // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. + // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. 
+ // Returns 0 on success + WHISPER_API int whisper_encode( + struct whisper_context * ctx, + int offset, + int n_threads); + + // Run the Whisper decoder to obtain the logits and probabilities for the next token. + // Make sure to call whisper_encode() first. + // tokens + n_tokens is the provided context for the decoder. + // n_past is the number of tokens to use from previous decoder calls. + // Returns 0 on success + WHISPER_API int whisper_decode( + struct whisper_context * ctx, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + + // Token sampling methods. + // These are provided for convenience and can be used after each call to whisper_decode(). + // You can also implement your own sampling method using the whisper_get_probs() function. + // whisper_sample_best() returns the token with the highest probability + // whisper_sample_timestamp() returns the most probable timestamp token + WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); + WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. + // Returns the number of tokens on success, no more than n_max_tokens + // Returns -1 on failure + // TODO: not sure if correct + WHISPER_API int whisper_tokenize( + struct whisper_context * ctx, + const char * text, + whisper_token * tokens, + int n_max_tokens); + + // Largest language id (i.e. number of available languages - 1) + WHISPER_API int whisper_lang_max_id(); + + // Return the id of the specified language, returns -1 if not found + // Examples: + // "de" -> 2 + // "german" -> 2 + WHISPER_API int whisper_lang_id(const char * lang); + + // Return the short string of the specified language id (e.g. 
2 -> "de"), returns nullptr if not found + WHISPER_API const char * whisper_lang_str(int id); + + // Use mel data at offset_ms to try and auto-detect the spoken language + // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first + // Returns the top language id or negative on failure + // If not null, fills the lang_probs array with the probabilities of all languages + // The array must be whisper_lang_max_id() + 1 in size + // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 + WHISPER_API int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs); + + WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length + WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); + + // The probabilities for the next token + WHISPER_API float * whisper_get_probs(struct whisper_context * ctx); + + // Token Id -> String. 
Uses the vocabulary in the provided context + WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); + + // Special tokens + WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); + + // Task tokens + WHISPER_API whisper_token whisper_token_translate (void); + WHISPER_API whisper_token whisper_token_transcribe(void); + + // Performance information + WHISPER_API void whisper_print_timings(struct whisper_context * ctx); + WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); + + // Print system information + WHISPER_API const char * whisper_print_system_info(void); + + //////////////////////////////////////////////////////////////////////////// + + // Available sampling strategies + enum whisper_sampling_strategy { + WHISPER_SAMPLING_GREEDY, // Always select the most probable token + WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet! 
+ }; + + // Text segment callback + // Called on every newly generated text segment + // Use the whisper_full_...() functions to obtain the text segments + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); + + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); + + // Parameters for the whisper_full() function + // If you change the order or add new parameters, make sure to update the default values in whisper.cpp: + // whisper_full_default_params() + struct whisper_full_params { + enum whisper_sampling_strategy strategy; + + int n_threads; + int n_max_text_ctx; + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + + bool translate; + bool no_context; + bool single_segment; // force single segment output (useful for streaming) + bool print_special; + bool print_progress; + bool print_realtime; + bool print_timestamps; + + // [EXPERIMENTAL] token-level timestamps + bool token_timestamps; // enable token-level timestamps + float thold_pt; // timestamp token probability threshold (~0.01) + float thold_ptsum; // timestamp token sum probability threshold (~0.01) + int max_len; // max segment length in characters + int max_tokens; // max tokens per segment (0 = no limit) + + // [EXPERIMENTAL] speed-up techniques + bool speed_up; // speed-up the audio by 2x using Phase Vocoder + int audio_ctx; // overwrite the audio context size (0 = use default) + + // tokens to provide the whisper model as initial prompt + // these are prepended to any existing text context from a previous call + const whisper_token * prompt_tokens; + int prompt_n_tokens; + + // for auto-detection, set to nullptr, "" or "auto" + const char * language; + + struct { + int n_past; + } greedy; + + struct { + int n_past; + int beam_width; + 
int n_best; + } beam_search; + + whisper_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; + + whisper_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; + }; + + WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); + + // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + // Uses the specified decoding strategy to obtain the text. + WHISPER_API int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples); + + // Split the input audio in chunks and process each chunk separately using whisper_full() + // It seems this approach can offer some speedup in some cases. + // However, the transcription accuracy can be worse at the beginning and end of each chunk. + WHISPER_API int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors); + + // Number of generated text segments. + // A segment can be a few words, a sentence, or even a paragraph. + WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); + + // Get the start and end time of the specified segment. + WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); + + // Get the text of the specified segment. + WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); + + // Get number of tokens in the specified segment. + WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); + + // Get the token text of the specified token in the specified segment. 
+ WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); + + // Get token data for the specified token in the specified segment. + // This contains probabilities, timestamps, etc. + WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment. + WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/Whisper/stdafx.cpp b/Whisper/stdafx.cpp new file mode 100644 index 0000000..1577c4e --- /dev/null +++ b/Whisper/stdafx.cpp @@ -0,0 +1 @@ +#include "stdafx.h"
\ No newline at end of file diff --git a/Whisper/stdafx.h b/Whisper/stdafx.h new file mode 100644 index 0000000..c84e10d --- /dev/null +++ b/Whisper/stdafx.h @@ -0,0 +1,43 @@ +#pragma once +#define _USE_MATH_DEFINES +#include <stdint.h> +#include <assert.h> +#include <array> +#include <vector> +#include <algorithm> +#include <emmintrin.h> // SSE 2 +#include <smmintrin.h> // SSE 4.1 + +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +// Setup Windows SDK to only enable features available since Windows 8.0 +#include <WinSDKVer.h> +#define _WIN32_WINNT _WIN32_WINNT_WIN8 +#define NTDDI_VERSION NTDDI_WIN8 +#include <sdkddkver.h> + +#include <windows.h> +#include <d3d11.h> +#include <atlcomcli.h> +#include "Utils/Logger.h" +#include "Utils/miscUtils.h" + +// Build both legacy and DirectCompute implementations +#define BUILD_BOTH_VERSIONS 0 + +// Build hybrid model which uses DirectCompute only for the encode step of the algorithm, and decodes on CPU, using AVX SIMD and the Windows' built-in thread pool. +#define BUILD_HYBRID_VERSION 0 + +// Enable debug traces. Should be disabled in production, the feature comes with a huge performance overhead. +// When enabled, while computing things it streams gigabytes of data into that binary file. +// See Tools / compareTraces project for a command-line app to compare these traces. +#define SAVE_DEBUG_TRACE 0 + +// In addition to collecting total GPU times per compute shader, also collect and print performance data about individual invocations of some of the most expensive shaders +// The feature is relatively cheap in terms of performance overhead, but pretty much useless in production, and clutters debug console with all these numbers +#define PROFILER_COLLECT_TAGS 0 + +// Reshape some of the tensors to a better VRAM layout while loading a model +// So far, the feature is only used on AMD GPUs. On AMD Vega integrated GPUs it helps by up to 30%. +// Should be enabled in production build +#define RESHAPED_MATRIX_MULTIPLY 1
\ No newline at end of file diff --git a/Whisper/whisper.def b/Whisper/whisper.def new file mode 100644 index 0000000..69ada14 --- /dev/null +++ b/Whisper/whisper.def @@ -0,0 +1,7 @@ +LIBRARY +EXPORTS setupLogger +EXPORTS loadModel +EXPORTS initMediaFoundation +EXPORTS findLanguageKeyW +EXPORTS findLanguageKeyA +EXPORTS getSupportedLanguages
\ No newline at end of file diff --git a/Whisper/whisperCom.cpp b/Whisper/whisperCom.cpp new file mode 100644 index 0000000..a0205ec --- /dev/null +++ b/Whisper/whisperCom.cpp @@ -0,0 +1,1070 @@ +#include "stdafx.h" +#include "ML/Tensor.h" +#include "API/iMediaFoundation.cl.h" +#include "API/iContext.cl.h" +#include "API/sFullParams.h" +#include "Utils/ReadStream.h" +#include "ML/testUtils.h" +#include "Utils/Trace/tracing.h" +#include "modelFactory.h" +#if BUILD_BOTH_VERSIONS + +namespace +{ + LPCTSTR traceFilePath = LR"(C:\Temp\2remove\Whisper\ref.bin)"; + using ComLight::iReadStream; +} + +struct whisper_context; +struct ggml_tensor; + +class GpuEncTest +{ + DirectCompute::Tensor mel, gpuResult; + + DirectCompute::Tensor tempGpu; + const ggml_tensor* tempRef = nullptr; +public: + GpuEncTest( const whisper_context& wctx, const int mel_offset ); + void compare( const ggml_tensor* expected ) const; + void compareMel( const ggml_tensor* expected ) const; +}; + +class GpuDecTest +{ + std::vector<float> logits, probs; + const ggml_tensor* tempRef = nullptr; + +public: + + GpuDecTest( const whisper_context& wctx, const int* tokens, const int n_tokens, const int n_past ); + + void postpone( const ggml_tensor* t ); + void comparePostponed(); + void compare( const std::vector<float>& cpuLogits, const std::vector<float>& cpuProbs ) const; +}; + +static DirectCompute::Tensor gpuEncode( const whisper_context& wctx, const int mel_offset ); + +#include "source/whisper.cpp" +#include "API/iContext.cl.h" +#include "../ComLightLib/comLightServer.h" +#include "ML/mlStartup.h" +#include "Whisper/WhisperContext.h" +#include "Whisper/ModelLoader.h" +#include "Whisper/WhisperModel.h" +#include "source.compat/convertThings.h" + +namespace Whisper +{ + inline HRESULT isZero( int i ) + { + return ( 0 == i ) ? 
S_OK : E_FAIL; + } + + class Context : public ComLight::ObjectRoot<iContext>, + public iModel + { + virtual HRESULT COMLIGHTCALL isMultilingual() override final + { + return whisper_is_multilingual( &ctx ) ? S_OK : S_FALSE; + } + virtual const char* COMLIGHTCALL stringFromToken( whisper_token token ) override final + { + return whisper_token_to_str( &ctx, token ); + } + virtual HRESULT COMLIGHTCALL getSpecialTokens( SpecialTokens& rdi ) + { + rdi.TranscriptionEnd = whisper_token_eot( &ctx ); + rdi.TranscriptionStart = whisper_token_sot( &ctx ); + rdi.PreviousWord = whisper_token_prev( &ctx ); + rdi.SentenceStart = whisper_token_solm( &ctx ); + rdi.Not = whisper_token_not( &ctx ); + rdi.TranscriptionBegin = whisper_token_beg( &ctx ); + rdi.TaskTranslate = whisper_token_translate(); + rdi.TaskTranscribe = whisper_token_transcribe(); + return S_OK; + } + + // Performance information + virtual HRESULT COMLIGHTCALL timingsPrint() override final + { + whisper_print_timings( &ctx ); + return S_OK; + } + virtual HRESULT COMLIGHTCALL timingsReset() override final + { + whisper_reset_timings( &ctx ); + return S_OK; + } + + virtual HRESULT COMLIGHTCALL fullDefaultParams( eSamplingStrategy strategy, sFullParams* rdi ) + { + static_assert( (int)eSamplingStrategy::Greedy == whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY ); + static_assert( (int)eSamplingStrategy::BeamSearch == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH ); + const whisper_sampling_strategy wss = (whisper_sampling_strategy)(int)strategy; + whisper_full_params wfp = whisper_full_default_params( wss ); + + *rdi = makeNewParams( wfp ); + return S_OK; + } + + HRESULT COMLIGHTCALL runFull( const sFullParams& params, const iAudioBuffer* buffer ) override final + { + whisper_full_params wfp = makeOldParams( params, this ); + const float* const samples = buffer->getPcmMono(); + const uint32_t n_samples = buffer->countSamples(); + return isZero( whisper_full( &ctx, wfp, samples, (int)n_samples ) ); + } + + 
HRESULT COMLIGHTCALL runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ) override final + { + logError( u8"The CPU reference implementation doesn’t support streaming" ); + return E_NOTIMPL; + } + HRESULT COMLIGHTCALL runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ) override final + { + logError( u8"The CPU reference implementation doesn’t support audio capture" ); + return E_NOTIMPL; + } + + HRESULT COMLIGHTCALL getResults( eResultFlags flags, iTranscribeResult** pp ) const override final + { + makeNewResults( &ctx, flags, pp ); + return S_OK; + } + + HRESULT loadImpl( iReadStream* stm ); + + virtual HRESULT COMLIGHTCALL createContext( iContext** pp ) override final + { + if( nullptr == pp ) + return E_POINTER; + *pp = this; + ( *pp )->AddRef(); + return S_OK; + } + + virtual HRESULT COMLIGHTCALL getModel( iModel** pp ) override final + { + if( nullptr == pp ) + return E_POINTER; + *pp = this; + ( *pp )->AddRef(); + return S_OK; + } + + public: + + Context() + { + if( nullptr != traceFilePath ) + Tracing::traceCreate( traceFilePath ); + } + + mutable whisper_context ctx; + + HRESULT load( iReadStream* stm ); + + ~Context() + { + Tracing::traceClose(); + + if( ctx.model.ctx ) + { + ggml_free( ctx.model.ctx ); + ctx.model.ctx = nullptr; + } + if( ctx.model.ctx_mem ) + { + ggml_free( ctx.model.ctx_mem ); + ctx.model.ctx_mem = nullptr; + } + if( ctx.buf_model ) + { + delete ctx.buf_model; + ctx.buf_model = nullptr; + } + } + + BEGIN_COM_MAP() + COM_INTERFACE_ENTRY( iModel ); + END_COM_MAP() + }; + + inline HRESULT readBytes( iReadStream* stm, void* rdi, size_t cb ) + { + if( cb > INT_MAX ) + return DISP_E_OVERFLOW; + if( cb == 0 ) + return S_FALSE; + int n; + CHECK( stm->read( rdi, (int)cb, n ) ); + if( n != (int)cb ) + return E_EOF; + return S_OK; + } + + template<typename T> + inline HRESULT readStruct( iReadStream* stm, T& dest ) + { + return readBytes( stm, 
&dest, sizeof( T ) ); + } + template<typename E> + inline HRESULT readVector( iReadStream* stm, std::vector<E>& vec ) + { + const size_t cb = sizeof( E ) * vec.size(); + if( cb > 0 ) + return readBytes( stm, vec.data(), cb ); + return S_FALSE; + } + + inline HRESULT readString( iReadStream* stm, std::string& str ) + { + uint32_t len; + CHECK( readStruct( stm, len ) ); + if( len > 0 ) + { + str.resize( len ); + return readBytes( stm, str.data(), len ); + } + else + { + str.clear(); + return S_FALSE; + } + } + + // load the model from a ggml file + // file format: + // - hparams + // - pre-computed mel filters + // - vocab + // - weights + // see the convert-pt-to-ggml.py script for details + HRESULT Context::loadImpl( iReadStream* stm ) + { + // WhisperModel wm; + // return wm.load( stm ); + + // Copy-pasted from whisper_model_load() function + auto& model = ctx.model; + auto& vocab = ctx.vocab; + + // verify magic + { + uint32_t magic; + int cbRead; + CHECK( stm->read( &magic, 4, cbRead ) ); + if( magic != 0x67676d6c ) + { + logError( u8"Invalid model file, bad magic" ); + return E_INVALIDARG; + } + } + + //load hparams + { + auto& hparams = model.hparams; + CHECK( readStruct( stm, hparams ) ); + assert( hparams.n_text_state == hparams.n_audio_state ); + + if( hparams.n_audio_layer == 4 ) + model.type = e_model::MODEL_TINY; + if( hparams.n_audio_layer == 6 ) + model.type = e_model::MODEL_BASE; + if( hparams.n_audio_layer == 12 ) + model.type = e_model::MODEL_SMALL; + if( hparams.n_audio_layer == 24 ) + model.type = e_model::MODEL_MEDIUM; + if( hparams.n_audio_layer == 32 ) + model.type = e_model::MODEL_LARGE; + + logDebug( u8"%s: n_vocab = %d", __func__, hparams.n_vocab ); + logDebug( u8"%s: n_audio_ctx = %d", __func__, hparams.n_audio_ctx ); + logDebug( u8"%s: n_audio_state = %d", __func__, hparams.n_audio_state ); + logDebug( u8"%s: n_audio_head = %d", __func__, hparams.n_audio_head ); + logDebug( u8"%s: n_audio_layer = %d", __func__, hparams.n_audio_layer ); + 
logDebug( u8"%s: n_text_ctx = %d", __func__, hparams.n_text_ctx ); + logDebug( u8"%s: n_text_state = %d", __func__, hparams.n_text_state ); + logDebug( u8"%s: n_text_head = %d", __func__, hparams.n_text_head ); + logDebug( u8"%s: n_text_layer = %d", __func__, hparams.n_text_layer ); + logDebug( u8"%s: n_mels = %d", __func__, hparams.n_mels ); + logDebug( u8"%s: f16 = %d", __func__, hparams.f16 ); + logDebug( u8"%s: type = %d", __func__, model.type ); + + ctx.buf_model = new std::vector<uint8_t>(); + ctx.buf_model->resize( MEM_REQ_MODEL.at( model.type ) ); + ctx.buf_memory.resize( MEM_REQ_MEMORY.at( model.type ) ); + ctx.buf_compute.resize( std::max( MEM_REQ_ENCODE.at( model.type ), MEM_REQ_DECODE.at( model.type ) ) ); + ctx.buf_compute_layer.resize( std::max( MEM_REQ_ENCODE_LAYER.at( model.type ), MEM_REQ_DECODE_LAYER.at( model.type ) ) ); + } + + // load mel filters + { + auto& filters = ctx.model.filters; + CHECK( readStruct( stm, filters.n_mel ) ); + CHECK( readStruct( stm, filters.n_fft ) ); + filters.data.resize( filters.n_mel * filters.n_fft ); + CHECK( readVector( stm, filters.data ) ); + } + + // load vocab + { + int32_t n_vocab = 0; + CHECK( readStruct( stm, n_vocab ) ); + + //if (n_vocab != model.hparams.n_vocab) { + // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); + // return false; + //} + + std::string word; + for( int i = 0; i < n_vocab; i++ ) + { + CHECK( readString( stm, word ) ); + vocab.token_to_id[ word ] = i; + vocab.id_to_token[ i ] = word; + } + + vocab.n_vocab = model.hparams.n_vocab; + if( vocab.is_multilingual() ) + { + vocab.token_eot++; + vocab.token_sot++; + vocab.token_prev++; + vocab.token_solm++; + vocab.token_not++; + vocab.token_beg++; + } + + if( n_vocab < model.hparams.n_vocab ) + { + logDebug( u8"%s: adding %d extra tokens", __func__, model.hparams.n_vocab - n_vocab ); + for( int i = n_vocab; i < model.hparams.n_vocab; i++ ) + { + if( i > 
vocab.token_beg ) + word = "[_TT_" + std::to_string( i - vocab.token_beg ) + "]"; + else if( i == vocab.token_eot ) + word = "[_EOT_]"; + else if( i == vocab.token_sot ) + word = "[_SOT_]"; + else if( i == vocab.token_prev ) + word = "[_PREV_]"; + else if( i == vocab.token_not ) + word = "[_NOT_]"; + else if( i == vocab.token_beg ) + word = "[_BEG_]"; + else + word = "[_extra_token_" + std::to_string( i ) + "]"; + + vocab.token_to_id[ word ] = i; + vocab.id_to_token[ i ] = word; + } + } + } + + { + // this is the total memory required to run the inference + const size_t mem_required = + ctx.buf_model->size() + + ctx.buf_memory.size() + + ctx.buf_compute.size() + + ctx.buf_compute_layer.size(); + logDebug( u8"%s: mem_required = %7.2f MB", __func__, mem_required / 1024.0 / 1024.0 ); + } + + // for the big tensors, we have the option to store the data in 16-bit floats + // in order to save memory and also to speed up the computation + const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + size_t ctx_size = 0; + size_t ctx_mem_size = 0; + + { + const auto& hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + // encoder + { + // TODO: F16 .. maybe not? 
+ ctx_size += n_audio_ctx * n_audio_state * ggml_type_size( GGML_TYPE_F32 ); // e_pe; + + ctx_size += 3 * n_mels * n_audio_state * ggml_type_size( wtype ); // e_conv_1_w + ctx_size += n_audio_state * ggml_type_size( GGML_TYPE_F32 ); // e_conv_1_b + + ctx_size += 3 * n_audio_state * n_audio_state * ggml_type_size( wtype ); // e_conv_2_w + ctx_size += n_audio_state * ggml_type_size( GGML_TYPE_F32 ); // e_conv_2_b + + ctx_size += n_audio_state * ggml_type_size( GGML_TYPE_F32 ); // e_ln_w; + ctx_size += n_audio_state * ggml_type_size( GGML_TYPE_F32 ); // e_ln_b; + } + + // decoder + { + // TODO: F16 .. maybe not? + ctx_size += n_text_ctx * n_text_state * ggml_type_size( GGML_TYPE_F32 ); // d_pe; + + ctx_size += n_vocab * n_text_state * ggml_type_size( wtype ); // d_te; + + ctx_size += n_text_state * ggml_type_size( GGML_TYPE_F32 ); // d_ln_w; + ctx_size += n_text_state * ggml_type_size( GGML_TYPE_F32 ); // d_ln_b; + } + + // encoder layers + { + ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_ln_w + ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_ln_b + + ctx_size += n_audio_layer * ( 4 * n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // mlp_0_w + ctx_size += n_audio_layer * ( 4 * n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_0_b + + ctx_size += n_audio_layer * ( 4 * n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // mlp_1_w + ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_1_b + + ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_0_w + ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_0_b + + ctx_size += n_audio_layer * ( n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // attn_q_w + ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_q_b + + ctx_size += n_audio_layer * ( n_audio_state * 
n_audio_state * ggml_type_size( wtype ) ); // attn_k_w + + ctx_size += n_audio_layer * ( n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // attn_v_w + ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_v_b + + ctx_size += n_audio_layer * ( n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // attn_ln_1_w + ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_1_b + } + + // decoder layers + { + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_ln_w + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_ln_b + + ctx_size += n_text_layer * ( 4 * n_text_state * n_text_state * ggml_type_size( wtype ) ); // mlp_0_w + ctx_size += n_text_layer * ( 4 * n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_0_b + + ctx_size += n_text_layer * ( 4 * n_text_state * n_text_state * ggml_type_size( wtype ) ); // mlp_1_w + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_1_b + + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_0_w + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_0_b + + ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // attn_q_w + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_q_b + + ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // attn_k_w + + ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // attn_v_w + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_v_b + + ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // attn_ln_1_w + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_1_b + // + ctx_size += n_text_layer * ( 
n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // cross_attn_ln_0_w + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // cross_attn_ln_0_b + + ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // cross_attn_q_w + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // cross_attn_q_b + + ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // cross_attn_k_w + + ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // cross_attn_v_w + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // cross_attn_v_b + + ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // cross_attn_ln_1_w + ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // cross_attn_ln_1_b + } + + ctx_mem_size += n_text_layer * n_text_ctx * n_text_state * ggml_type_size( GGML_TYPE_F16 ); // memory_k + ctx_mem_size += n_text_layer * n_text_ctx * n_text_state * ggml_type_size( GGML_TYPE_F16 ); // memory_v + + ctx_mem_size += n_text_layer * n_audio_ctx * n_text_state * ggml_type_size( GGML_TYPE_F16 ); // memory_cross_k + ctx_mem_size += n_text_layer * n_audio_ctx * n_text_state * ggml_type_size( GGML_TYPE_F16 ); // memory_cross_v + + ctx_size += ( 15 + 15 * n_audio_layer + 24 * n_text_layer ) * 256; // object overhead + + logDebug( u8"%s: ggml ctx size = %7.2f MB", __func__, ctx_size / ( 1024.0 * 1024.0 ) ); + } + + // create the ggml context + { + struct ggml_init_params params; + params.mem_size = ctx.buf_model->size(); + params.mem_buffer = ctx.buf_model->data(); + + model.ctx = ggml_init( params ); + if( !model.ctx ) + { + logError( u8"%s: ggml_init() failed", __func__ ); + return E_INVALIDARG; + } + } + + std::map<std::string, struct ggml_tensor*> tensors; + DirectCompute::ModelLoader loader{ model.hparams.n_audio_layer, model.hparams.n_text_layer }; + + // 
prepare memory for the weights + { + auto& ctx = model.ctx; + const auto& hparams = model.hparams; + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + model.layers_encoder.resize( n_audio_layer ); + model.layers_decoder.resize( n_text_layer ); + + // encoder + { + model.e_pe = ggml_new_tensor_2d( ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx ); + loader.add( model.e_pe, loader.model.enc.positionalEmbedding ); + + model.e_conv_1_w = ggml_new_tensor_3d( ctx, wtype, 3, n_mels, n_audio_state ); + model.e_conv_1_b = ggml_new_tensor_2d( ctx, GGML_TYPE_F32, 1, n_audio_state ); + loader.add( model.e_conv_1_w, model.e_conv_1_b, loader.model.enc.conv1 ); + + model.e_conv_2_w = ggml_new_tensor_3d( ctx, wtype, 3, n_audio_state, n_audio_state ); + model.e_conv_2_b = ggml_new_tensor_2d( ctx, GGML_TYPE_F32, 1, n_audio_state ); + loader.add( model.e_conv_2_w, model.e_conv_2_b, loader.model.enc.conv2 ); + + model.e_ln_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state ); + model.e_ln_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state ); + loader.add( model.e_ln_w, model.e_ln_b, loader.model.enc.lnPost ); + + // map by name + tensors[ "encoder.positional_embedding" ] = model.e_pe; + + tensors[ "encoder.conv1.weight" ] = model.e_conv_1_w; + tensors[ "encoder.conv1.bias" ] = model.e_conv_1_b; + + tensors[ "encoder.conv2.weight" ] = model.e_conv_2_w; + tensors[ "encoder.conv2.bias" ] = model.e_conv_2_b; + + tensors[ "encoder.ln_post.weight" ] = model.e_ln_w; + tensors[ "encoder.ln_post.bias" ] = model.e_ln_b; + + for( int i = 0; i < n_audio_layer; ++i ) + { + auto& layer = model.layers_encoder[ i ]; + auto& gpu = loader.model.enc.layers[ i ]; + + 
layer.mlp_ln_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state ); + layer.mlp_ln_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state ); + loader.add( layer.mlp_ln_w, layer.mlp_ln_b, gpu.mlpLn ); + + layer.mlp_0_w = ggml_new_tensor_2d( ctx, wtype, n_audio_state, 4 * n_audio_state ); + layer.mlp_0_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, 4 * n_audio_state ); + loader.add( layer.mlp_0_w, layer.mlp_0_b, gpu.mlp0 ); + + layer.mlp_1_w = ggml_new_tensor_2d( ctx, wtype, 4 * n_audio_state, n_audio_state ); + layer.mlp_1_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state ); + loader.add( layer.mlp_1_w, layer.mlp_1_b, gpu.mlp1 ); + + layer.attn_ln_0_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state ); + layer.attn_ln_0_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state ); + loader.add( layer.attn_ln_0_w, layer.attn_ln_0_b, gpu.attnLn0 ); + + layer.attn_q_w = ggml_new_tensor_2d( ctx, wtype, n_audio_state, n_audio_state ); + layer.attn_q_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state ); + loader.add( layer.attn_q_w, layer.attn_q_b, gpu.attnQuery ); + + layer.attn_k_w = ggml_new_tensor_2d( ctx, wtype, n_audio_state, n_audio_state ); + loader.add( layer.attn_k_w, gpu.attnKey ); + + layer.attn_v_w = ggml_new_tensor_2d( ctx, wtype, n_audio_state, n_audio_state ); + layer.attn_v_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state ); + loader.add( layer.attn_v_w, layer.attn_v_b, gpu.attnValue ); + + layer.attn_ln_1_w = ggml_new_tensor_2d( ctx, wtype, n_audio_state, n_audio_state ); + layer.attn_ln_1_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state ); + loader.add( layer.attn_ln_1_w, layer.attn_ln_1_b, gpu.attnLn1 ); + + // map by name + tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp_ln.weight" ] = layer.mlp_ln_w; + tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp_ln.bias" ] = layer.mlp_ln_b; + + tensors[ "encoder.blocks." 
+ std::to_string( i ) + ".mlp.0.weight" ] = layer.mlp_0_w; + tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp.0.bias" ] = layer.mlp_0_b; + + tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp.2.weight" ] = layer.mlp_1_w; + tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp.2.bias" ] = layer.mlp_1_b; + + tensors[ "encoder.blocks." + std::to_string( i ) + ".attn_ln.weight" ] = layer.attn_ln_0_w; + tensors[ "encoder.blocks." + std::to_string( i ) + ".attn_ln.bias" ] = layer.attn_ln_0_b; + + tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.query.weight" ] = layer.attn_q_w; + tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.query.bias" ] = layer.attn_q_b; + + tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.key.weight" ] = layer.attn_k_w; + + tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.value.weight" ] = layer.attn_v_w; + tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.value.bias" ] = layer.attn_v_b; + + tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.out.weight" ] = layer.attn_ln_1_w; + tensors[ "encoder.blocks." 
+ std::to_string( i ) + ".attn.out.bias" ] = layer.attn_ln_1_b; + } + } + + // decoder + { + model.d_pe = ggml_new_tensor_2d( ctx, GGML_TYPE_F32, n_text_state, n_text_ctx ); + loader.add( model.d_pe, loader.model.dec.positionalEmbedding ); + + model.d_te = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_vocab ); + loader.add( model.d_te, loader.model.dec.tokenEmbedding ); + + model.d_ln_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + model.d_ln_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + loader.add( model.d_ln_w, model.d_ln_b, loader.model.dec.ln ); + + // map by name + tensors[ "decoder.positional_embedding" ] = model.d_pe; + + tensors[ "decoder.token_embedding.weight" ] = model.d_te; + + tensors[ "decoder.ln.weight" ] = model.d_ln_w; + tensors[ "decoder.ln.bias" ] = model.d_ln_b; + + for( int i = 0; i < n_text_layer; ++i ) { + auto& layer = model.layers_decoder[ i ]; + auto& gpu = loader.model.dec.layers[ i ]; + + layer.mlp_ln_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + layer.mlp_ln_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + loader.add( layer.mlp_ln_w, layer.mlp_ln_b, gpu.mlpLn ); + + layer.mlp_0_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, 4 * n_text_state ); + layer.mlp_0_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, 4 * n_text_state ); + loader.add( layer.mlp_0_w, layer.mlp_0_b, gpu.mlp0 ); + + layer.mlp_1_w = ggml_new_tensor_2d( ctx, wtype, 4 * n_text_state, n_text_state ); + layer.mlp_1_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + loader.add( layer.mlp_1_w, layer.mlp_1_b, gpu.mlp1 ); + + layer.attn_ln_0_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + layer.attn_ln_0_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + loader.add( layer.attn_ln_0_w, layer.attn_ln_0_b, gpu.attnLn0 ); + + layer.attn_q_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state ); + layer.attn_q_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + 
loader.add( layer.attn_q_w, layer.attn_q_b, gpu.attnQuery ); + + layer.attn_k_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state ); + loader.add( layer.attn_k_w, gpu.attnKey ); + + layer.attn_v_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state ); + layer.attn_v_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + loader.add( layer.attn_v_w, layer.attn_v_b, gpu.attnValue ); + + layer.attn_ln_1_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state ); + layer.attn_ln_1_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + loader.add( layer.attn_ln_1_w, layer.attn_ln_1_b, gpu.attnLn1 ); + + layer.cross_attn_ln_0_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + layer.cross_attn_ln_0_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + loader.add( layer.cross_attn_ln_0_w, layer.cross_attn_ln_0_b, gpu.crossAttnLn0 ); + + layer.cross_attn_q_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state ); + layer.cross_attn_q_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + loader.add( layer.cross_attn_q_w, layer.cross_attn_q_b, gpu.crossAttnQuery ); + + layer.cross_attn_k_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state ); + loader.add( layer.cross_attn_k_w, gpu.crossAttnKey ); + + layer.cross_attn_v_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state ); + layer.cross_attn_v_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + loader.add( layer.cross_attn_v_w, layer.cross_attn_v_b, gpu.crossAttnValue ); + + layer.cross_attn_ln_1_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state ); + layer.cross_attn_ln_1_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state ); + loader.add( layer.cross_attn_ln_1_w, layer.cross_attn_ln_1_b, gpu.crossAttnLn1 ); + + // map by name + tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp_ln.weight" ] = layer.mlp_ln_w; + tensors[ "decoder.blocks." 
+ std::to_string( i ) + ".mlp_ln.bias" ] = layer.mlp_ln_b; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp.0.weight" ] = layer.mlp_0_w; + tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp.0.bias" ] = layer.mlp_0_b; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp.2.weight" ] = layer.mlp_1_w; + tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp.2.bias" ] = layer.mlp_1_b; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".attn_ln.weight" ] = layer.attn_ln_0_w; + tensors[ "decoder.blocks." + std::to_string( i ) + ".attn_ln.bias" ] = layer.attn_ln_0_b; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.query.weight" ] = layer.attn_q_w; + tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.query.bias" ] = layer.attn_q_b; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.key.weight" ] = layer.attn_k_w; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.value.weight" ] = layer.attn_v_w; + tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.value.bias" ] = layer.attn_v_b; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.out.weight" ] = layer.attn_ln_1_w; + tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.out.bias" ] = layer.attn_ln_1_b; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn_ln.weight" ] = layer.cross_attn_ln_0_w; + tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn_ln.bias" ] = layer.cross_attn_ln_0_b; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.query.weight" ] = layer.cross_attn_q_w; + tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.query.bias" ] = layer.cross_attn_q_b; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.key.weight" ] = layer.cross_attn_k_w; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.value.weight" ] = layer.cross_attn_v_w; + tensors[ "decoder.blocks." 
+ std::to_string( i ) + ".cross_attn.value.bias" ] = layer.cross_attn_v_b; + + tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.out.weight" ] = layer.cross_attn_ln_1_w; + tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.out.bias" ] = layer.cross_attn_ln_1_b; + } + } + } + + // create the ggml memory context + { + struct ggml_init_params params; + params.mem_size = ctx.buf_memory.size(); + params.mem_buffer = ctx.buf_memory.data(); + model.ctx_mem = ggml_init( params ); + if( !model.ctx_mem ) + { + logError( u8"%s: ggml_init() failed", __func__ ); + return E_INVALIDARG; + } + } + + // key + value memory + { + auto& ctx = model.ctx_mem; + + const auto& hparams = model.hparams; + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + const int n_text_ctx = hparams.n_text_ctx; + + // key/value memory for the self-attention layer + { + const int n_mem = n_text_layer * n_text_ctx; + const int n_elements = n_text_state * n_mem; + + model.memory_k = ggml_new_tensor_1d( ctx, GGML_TYPE_F16, n_elements ); + model.memory_v = ggml_new_tensor_1d( ctx, GGML_TYPE_F16, n_elements ); + } + + // key/value memory for the cross-attention layer + { + const int n_audio_ctx = hparams.n_audio_ctx; + + const int n_mem = n_text_layer * n_audio_ctx; + const int n_elements = n_text_state * n_mem; + + model.memory_cross_k = ggml_new_tensor_1d( ctx, GGML_TYPE_F16, n_elements ); + model.memory_cross_v = ggml_new_tensor_1d( ctx, GGML_TYPE_F16, n_elements ); + } + + const size_t memory_size = + ggml_nbytes( model.memory_k ) + ggml_nbytes( model.memory_v ) + + ggml_nbytes( model.memory_cross_k ) + ggml_nbytes( model.memory_cross_v ); + + logDebug( u8"%s: memory size = %7.2f MB", __func__, memory_size / 1024.0 / 1024.0 ); + } + + // load weights + { + size_t total_size = 0; + int n_loaded = 0; + std::string name; + + while( true ) + { + int32_t n_dims; + int32_t length; + int32_t ftype; + + HRESULT hr = readStruct( stm, n_dims 
); + if( hr == E_EOF ) + break; + CHECK( hr ); + CHECK( readStruct( stm, length ) ); + CHECK( readStruct( stm, ftype ) ); + + int32_t nelements = 1; + int32_t ne[ 3 ] = { 1, 1, 1 }; + for( int i = 0; i < n_dims; ++i ) + { + CHECK( readStruct( stm, ne[ i ] ) ); + nelements *= ne[ i ]; + } + + name.resize( length ); + CHECK( readBytes( stm, name.data(), length ) ); + + if( tensors.find( name.data() ) == tensors.end() ) + { + logError( u8"%s: unknown tensor '%s' in model file", __func__, name.data() ); + return E_INVALIDARG; + } + + auto tensor = tensors[ name.data() ]; + if( ggml_nelements( tensor ) != nelements ) + { + logError( u8"%s: tensor '%s' has wrong size in model file", __func__, name.data() ); + return E_INVALIDARG; + } + + if( tensor->ne[ 0 ] != ne[ 0 ] || tensor->ne[ 1 ] != ne[ 1 ] || tensor->ne[ 2 ] != ne[ 2 ] ) + { + logError( u8"%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]", + __func__, name.data(), tensor->ne[ 0 ], tensor->ne[ 1 ], tensor->ne[ 2 ], ne[ 0 ], ne[ 1 ], ne[ 2 ] ); + return E_INVALIDARG; + } + + const size_t bpe = ( ftype == 0 ) ? sizeof( float ) : sizeof( ggml_fp16_t ); + + if( nelements * bpe != ggml_nbytes( tensor ) ) + { + logError( u8"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu", + __func__, name.data(), ggml_nbytes( tensor ), nelements * bpe ); + return E_INVALIDARG; + } + + CHECK( readBytes( stm, tensor->data, ggml_nbytes( tensor ) ) ); + + //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? 
"float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); + total_size += ggml_nbytes( tensor ); + n_loaded++; + // loader.tryLoad( tensor ); + } + + logDebug( u8"%s: model size = %7.2f MB", __func__, total_size / 1024.0 / 1024.0 ); + if( n_loaded == 0 ) + { + logError( u8"%s: no tensors loaded from model file", __func__ ); + return E_INVALIDARG; + } + else if( n_loaded != (int)tensors.size() ) + { + logError( u8"%s: not all tensors loaded from model file - expected %zu, got %d", __func__, tensors.size(), n_loaded ); + return E_INVALIDARG; + } + model.n_loaded = n_loaded; + } + + return S_OK; + } + + HRESULT Context::load( iReadStream* stm ) + { + const int64_t t_start_us = ggml_time_us(); + ctx.t_start_us = t_start_us; + HRESULT hr = loadImpl( stm ); + ctx.t_load_us = ggml_time_us() - t_start_us; + return hr; + } + + HRESULT __stdcall loadReferenceCpuModel( const wchar_t* path, iModel** pp ) + { + if( nullptr == path || nullptr == pp ) + return E_POINTER; + + ComLight::Object<ReadStream> stream; + CHECK( stream.open( path ) ); + + ggml_time_init(); + ComLight::CComPtr<ComLight::Object<Context>> obj; + CHECK( ComLight::Object<Context>::create( obj ) ); + CHECK( obj->load( &stream ) ); + obj.detach( pp ); + return S_OK; + } +} + +#include "Whisper/WhisperContext.h" +#include "Whisper/ModelBuffers.h" +#include "ML/testUtils.h" +using namespace DirectCompute; + +static DirectCompute::Tensor gpuEncode( const whisper_context& wctx, const int mel_offset ) +{ + return DirectCompute::Tensor{}; +#if 0 + using namespace DirectCompute; + WhisperContext& ctx = WhisperContext::current(); + + Tensor cur; + sEncodeParams whisperParams; + const auto& mel_inp = wctx.mel; + { + const auto& model = wctx.model; + const auto& hparams = model.hparams; + whisperParams.n_len = (uint32_t)mel_inp.n_len; + whisperParams.n_mel = (uint32_t)mel_inp.n_mel; + + const int n_ctx = wctx.exp_n_audio_ctx > 0 ? 
wctx.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx; + assert( n_ctx > 0 ); + whisperParams.n_ctx = (uint32_t)n_ctx; + + const int n_mels = hparams.n_mels; + assert( n_mels > 0 ); + whisperParams.n_mels = (uint32_t)n_mels; + + assert( mel_offset >= 0 ); + whisperParams.mel_offset = (uint32_t)mel_offset; + + const int layersCount = hparams.n_audio_layer; + assert( layersCount > 0 ); + whisperParams.layersCount = (uint32_t)layersCount; + + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + assert( n_state >= 0 ); + assert( n_head >= 0 ); + + whisperParams.n_state = (uint32_t)n_state; + whisperParams.n_head = (uint32_t)n_head; + + int n_audio_ctx = hparams.n_audio_ctx; + assert( n_audio_ctx > 0 ); + whisperParams.n_audio_ctx = (uint32_t)n_audio_ctx; + + int n_text_state = hparams.n_text_state; + assert( n_text_state > 0 ); + whisperParams.n_text_state = (uint32_t)n_text_state; + + int n_text_layer = hparams.n_text_layer; + assert( n_text_layer > 0 ); + whisperParams.n_text_layer = (uint32_t)n_text_layer; + + int n_text_ctx = hparams.n_text_ctx; + assert( n_text_ctx > 0 ); + whisperParams.n_text_ctx = (uint32_t)n_text_ctx; + } + + return ctx.encode( mel_inp.data, whisperParams ); +#endif +} + +GpuEncTest::GpuEncTest( const whisper_context& wctx, const int mel_offset ) +{ + return; + gpuResult = gpuEncode( wctx, mel_offset ); +} + +void GpuEncTest::compare( const ggml_tensor* expected ) const +{ + return; + WhisperContext& ctx = WhisperContext::current(); + ctx.dbgPrintDifference( expected, gpuResult, "GpuEncTest.compare", false ); +} + +void GpuEncTest::compareMel( const ggml_tensor* expected ) const +{ + return; + WhisperContext& ctx = WhisperContext::current(); + ctx.dbgPrintDifference( expected, mel, "GpuEncTest.compareMel", false ); +} + +/* +void GpuEncTest::comparePostponed() +{ + if( nullptr == tempRef ) + return; + + WhisperContext& ctx = WhisperContext::current(); + ctx.dbgPrintDifference( tempRef, tempGpu, 
"comparePostponed" ); + tempRef = nullptr; +} */ + +__declspec( noinline ) GpuDecTest::GpuDecTest( const whisper_context& wctx, const int* tokens, const int n_tokens, const int n_past ) +{ +#if 1 + return; +#else + sDecodeParams dp; + { + WhisperContext& ctx = WhisperContext::current(); + const auto& model = wctx.model; + const auto& hparams = model.hparams; + dp.n_state = hparams.n_text_state; + dp.n_head = hparams.n_text_head; + dp.n_ctx = hparams.n_text_ctx; + dp.n_past = n_past; + dp.M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + dp.n_text_layer = hparams.n_text_layer; + dp.n_vocab = hparams.n_vocab; + } + + WhisperContext& ctx = WhisperContext::current(); + ctx.decode( tokens, n_tokens, dp, logits, probs ); +#endif +} + +void __declspec( noinline ) GpuDecTest::compare( const std::vector<float>& cpuLogits, const std::vector<float>& cpuProbs ) const +{ + return; + + if( cpuLogits.size() != logits.size() ) + { + printf( "GpuDecTest.compare fail, size different\n" ); + return; + } + + computeDiff( logits.data(), cpuLogits.data(), logits.size() ).print( "GpuDecTest.compare logits" ); + computeDiff( probs.data(), cpuProbs.data(), probs.size() ).print( "GpuDecTest.compare probs" ); +} + +void __declspec( noinline ) GpuDecTest::postpone( const ggml_tensor* t ) +{ + return; + + if( nullptr != tempRef ) + return; + tempRef = t; +} + +void __declspec( noinline ) GpuDecTest::comparePostponed() +{ +#if 1 + return; +#else + if( nullptr == tempRef ) + return; + WhisperContext& ctx = WhisperContext::current(); + ID3D11ShaderResourceView* srv = ctx.dbgDecodeTest; + if( nullptr == srv ) + return; + + ctx.dbgPrintDifference( tempRef, ctx.dbgDecodeTest, "GpuDecTest.comparePostponed" ); + tempRef = nullptr; +#endif +} +#else +HRESULT __stdcall Whisper::loadReferenceCpuModel( const wchar_t* path, Whisper::iModel** pp ) +{ + logError( u8"This build of the DLL doesn’t implement the reference CPU-running Whisper model." ); + return E_NOTIMPL; +} +#endif
\ No newline at end of file diff --git a/WhisperCpp.sln b/WhisperCpp.sln new file mode 100644 index 0000000..4c2ae51 --- /dev/null +++ b/WhisperCpp.sln @@ -0,0 +1,110 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.4.33122.133 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ComLightLib", "ComLightLib\ComLightLib.vcxproj", "{52F486E7-830C-45D8-BE47-E76B5AAB2772}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Whisper", "Whisper\Whisper.vcxproj", "{701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ComputeShaders", "ComputeShaders\ComputeShaders.vcxproj", "{1C39D386-96D0-47A1-BBFA-68BBDB24439C}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "WhisperNet", "WhisperNet\WhisperNet.csproj", "{F213558F-FEA2-4F66-A07F-69727E9EC81D}" + ProjectSection(ProjectDependencies) = postProject + {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71} = {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71} + EndProjectSection +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TranscribeCS", "Examples\TranscribeCS\TranscribeCS.csproj", "{0533B86C-D0E8-4190-9717-7DBD9EC8C11F}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "OldMain", "Examples\OldMain\OldMain.vcxproj", "{596F9770-9AEB-49D3-86CA-4200197DF12B}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CompressShaders", "Tools\CompressShaders\CompressShaders.csproj", "{4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2}" + ProjectSection(ProjectDependencies) = postProject + {1C39D386-96D0-47A1-BBFA-68BBDB24439C} = {1C39D386-96D0-47A1-BBFA-68BBDB24439C} + EndProjectSection +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Tools", "Tools", "{90D16EBB-08A4-4C9B-9991-B1B2E036838C}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Examples", "Examples", 
"{B988C132-115D-4157-99FE-0D891CE45A82}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "main", "Examples\main\main.vcxproj", "{4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MicrophoneCS", "Examples\MicrophoneCS\MicrophoneCS.csproj", "{A49305C0-7022-45A6-89B4-4BD33138C98A}" + ProjectSection(ProjectDependencies) = postProject + {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71} = {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "compareTraces", "Tools\compareTraces\compareTraces.vcxproj", "{8478A77C-D851-4C63-9511-1770CC82D33E}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "WhisperDesktop", "Examples\WhisperDesktop\WhisperDesktop.vcxproj", "{CD9E49F0-75A3-4F91-AC71-336109EE39C6}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{17835CA3-D7F6-4BEF-9471-12C015764A2C}" + ProjectSection(SolutionItems) = preProject + Readme.md = Readme.md + EndProjectSection +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|x64 = Debug|x64 + Release|x64 = Release|x64 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {52F486E7-830C-45D8-BE47-E76B5AAB2772}.Debug|x64.ActiveCfg = Debug|x64 + {52F486E7-830C-45D8-BE47-E76B5AAB2772}.Debug|x64.Build.0 = Debug|x64 + {52F486E7-830C-45D8-BE47-E76B5AAB2772}.Release|x64.ActiveCfg = Release|x64 + {52F486E7-830C-45D8-BE47-E76B5AAB2772}.Release|x64.Build.0 = Release|x64 + {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}.Debug|x64.ActiveCfg = Debug|x64 + {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}.Debug|x64.Build.0 = Debug|x64 + {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}.Release|x64.ActiveCfg = Release|x64 + {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}.Release|x64.Build.0 = Release|x64 + {1C39D386-96D0-47A1-BBFA-68BBDB24439C}.Debug|x64.ActiveCfg = Debug|x64 + 
{1C39D386-96D0-47A1-BBFA-68BBDB24439C}.Debug|x64.Build.0 = Debug|x64 + {1C39D386-96D0-47A1-BBFA-68BBDB24439C}.Release|x64.ActiveCfg = Release|x64 + {1C39D386-96D0-47A1-BBFA-68BBDB24439C}.Release|x64.Build.0 = Release|x64 + {F213558F-FEA2-4F66-A07F-69727E9EC81D}.Debug|x64.ActiveCfg = Debug|Any CPU + {F213558F-FEA2-4F66-A07F-69727E9EC81D}.Debug|x64.Build.0 = Debug|Any CPU + {F213558F-FEA2-4F66-A07F-69727E9EC81D}.Release|x64.ActiveCfg = Release|Any CPU + {F213558F-FEA2-4F66-A07F-69727E9EC81D}.Release|x64.Build.0 = Release|Any CPU + {0533B86C-D0E8-4190-9717-7DBD9EC8C11F}.Debug|x64.ActiveCfg = Debug|x64 + {0533B86C-D0E8-4190-9717-7DBD9EC8C11F}.Debug|x64.Build.0 = Debug|x64 + {0533B86C-D0E8-4190-9717-7DBD9EC8C11F}.Release|x64.ActiveCfg = Release|x64 + {596F9770-9AEB-49D3-86CA-4200197DF12B}.Debug|x64.ActiveCfg = Debug|x64 + {596F9770-9AEB-49D3-86CA-4200197DF12B}.Debug|x64.Build.0 = Debug|x64 + {596F9770-9AEB-49D3-86CA-4200197DF12B}.Release|x64.ActiveCfg = Release|x64 + {596F9770-9AEB-49D3-86CA-4200197DF12B}.Release|x64.Build.0 = Release|x64 + {4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2}.Debug|x64.ActiveCfg = Debug|Any CPU + {4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2}.Debug|x64.Build.0 = Debug|Any CPU + {4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2}.Release|x64.ActiveCfg = Release|Any CPU + {4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2}.Release|x64.Build.0 = Release|Any CPU + {4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2}.Debug|x64.ActiveCfg = Debug|x64 + {4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2}.Debug|x64.Build.0 = Debug|x64 + {4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2}.Release|x64.ActiveCfg = Release|x64 + {4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2}.Release|x64.Build.0 = Release|x64 + {A49305C0-7022-45A6-89B4-4BD33138C98A}.Debug|x64.ActiveCfg = Debug|x64 + {A49305C0-7022-45A6-89B4-4BD33138C98A}.Debug|x64.Build.0 = Debug|x64 + {A49305C0-7022-45A6-89B4-4BD33138C98A}.Release|x64.ActiveCfg = Release|x64 + {8478A77C-D851-4C63-9511-1770CC82D33E}.Debug|x64.ActiveCfg = Debug|x64 + 
{8478A77C-D851-4C63-9511-1770CC82D33E}.Debug|x64.Build.0 = Debug|x64 + {8478A77C-D851-4C63-9511-1770CC82D33E}.Release|x64.ActiveCfg = Release|x64 + {8478A77C-D851-4C63-9511-1770CC82D33E}.Release|x64.Build.0 = Release|x64 + {CD9E49F0-75A3-4F91-AC71-336109EE39C6}.Debug|x64.ActiveCfg = Debug|x64 + {CD9E49F0-75A3-4F91-AC71-336109EE39C6}.Debug|x64.Build.0 = Debug|x64 + {CD9E49F0-75A3-4F91-AC71-336109EE39C6}.Release|x64.ActiveCfg = Release|x64 + {CD9E49F0-75A3-4F91-AC71-336109EE39C6}.Release|x64.Build.0 = Release|x64 + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {0533B86C-D0E8-4190-9717-7DBD9EC8C11F} = {B988C132-115D-4157-99FE-0D891CE45A82} + {596F9770-9AEB-49D3-86CA-4200197DF12B} = {B988C132-115D-4157-99FE-0D891CE45A82} + {4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2} = {90D16EBB-08A4-4C9B-9991-B1B2E036838C} + {4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2} = {B988C132-115D-4157-99FE-0D891CE45A82} + {A49305C0-7022-45A6-89B4-4BD33138C98A} = {B988C132-115D-4157-99FE-0D891CE45A82} + {8478A77C-D851-4C63-9511-1770CC82D33E} = {90D16EBB-08A4-4C9B-9991-B1B2E036838C} + {CD9E49F0-75A3-4F91-AC71-336109EE39C6} = {B988C132-115D-4157-99FE-0D891CE45A82} + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {07D5F1CF-1FAD-4F40-806A-B148CD609961} + EndGlobalSection +EndGlobal diff --git a/WhisperNet/API/CaptureDeviceId.cs b/WhisperNet/API/CaptureDeviceId.cs new file mode 100644 index 0000000..9636e53 --- /dev/null +++ b/WhisperNet/API/CaptureDeviceId.cs @@ -0,0 +1,24 @@ +using Whisper.Internal; + +namespace Whisper +{ + /// <summary>Identifiers for an audio capture device</summary> + public record struct CaptureDeviceId + { + /// <summary>The display name is suitable for showing to the user, but might not be unique.</summary> + public string displayName; + + /// <summary>Endpoint ID for an audio capture device.<br/> + /// It uniquely identifies the 
device on the system, but is not a readable string.</summary> + public string endpoint; + + internal CaptureDeviceId( in sCaptureDevice rsi ) + { + displayName = rsi.displayName ?? "<display name unavailable>"; + endpoint = rsi.endpoint ?? throw new ApplicationException( "The device has no endpoint ID" ); + } + + /// <summary>Returns a String which represents the object instance</summary> + public override string ToString() => $"Capture device: \"{displayName}\""; + } +}
\ No newline at end of file diff --git a/WhisperNet/API/Parameters.cs b/WhisperNet/API/Parameters.cs new file mode 100644 index 0000000..d2b53f9 --- /dev/null +++ b/WhisperNet/API/Parameters.cs @@ -0,0 +1,95 @@ +// CS1591 is the "Missing XML comment for publicly visible type or member" warning +// TODO: document the remaining public members of this file, then remove the #pragma below. +#pragma warning disable CS1591 + +namespace Whisper +{ + /// <summary>Available sampling strategies</summary> + public enum eSamplingStrategy: int + { + /// <summary>Always select the most probable token</summary> + Greedy, + /// <summary>TODO: not implemented yet!</summary> + BeamSearch, + }; + + [Flags] // Bit flags stored in Parameters.flags; set or clear them with Parameters.setFlag + public enum eFullParamsFlags: uint + { + None = 0, // no flag at all; Parameters.setFlag rejects this value + Translate = 1, + NoContext = 2, + SingleSegment = 4, + PrintSpecial = 8, + PrintProgress = 0x10, + PrintRealtime = 0x20, + PrintTimestamps = 0x40, + + // Experimental + TokenTimestamps = 0x100, + SpeedupAudio = 0x200, + }; + + /// <summary>Transcribe parameters</summary> + public struct Parameters + { + /// <summary>Sampling strategy</summary> + public eSamplingStrategy strategy; + + /// <summary>Count of CPU worker threads to use</summary> + /// <remarks>So far, the GPU model only uses CPU threads for MEL spectrograms</remarks> + public int cpuThreads; + + public int n_max_text_ctx; + /// <summary>start offset in ms</summary> + public int offset_ms; + /// <summary>audio duration to process in ms</summary> + public int duration_ms; + public eFullParamsFlags flags; + + /// <summary>Set or clear the specified flag in the <see cref="flags" /> field of this structure</summary> <exception cref="ArgumentException">The <paramref name="flag" /> argument is <see cref="eFullParamsFlags.None" /></exception> + public void setFlag( eFullParamsFlags flag, bool set ) + { + if( flag != eFullParamsFlags.None ) + { + if( set ) + flags |= flag; + else + flags &= ~flag; + return; + } + throw new ArgumentException(); // only reached for eFullParamsFlags.None, which has no bits to set or clear + } + + /// <summary>Language</summary> + public eLanguage language; + + // [EXPERIMENTAL] token-level timestamps + /// <summary>timestamp token probability threshold (~0.01)</summary> + public float thold_pt; + /// 
<summary>timestamp token sum probability threshold (~0.01)</summary> + public float thold_ptsum; + /// <summary>max segment length in characters</summary> + public int max_len; + /// <summary>max tokens per segment (0 = no limit)</summary> + public int max_tokens; + + public struct sGreedy + { + public int n_past; + } + public sGreedy greedy; // presumably only relevant for eSamplingStrategy.Greedy — TODO confirm + + public struct sBeamSearch + { + public int n_past; + public int beam_width; + public int n_best; + } + public sBeamSearch beamSearch; // presumably only relevant for eSamplingStrategy.BeamSearch — TODO confirm + + // [EXPERIMENTAL] speed-up techniques + /// <summary>overwrite the audio context size (0 = use default)</summary> + public int audioContextSize; + } +}
\ No newline at end of file diff --git a/WhisperNet/API/SpecialTokens.cs b/WhisperNet/API/SpecialTokens.cs new file mode 100644 index 0000000..d672369 --- /dev/null +++ b/WhisperNet/API/SpecialTokens.cs @@ -0,0 +1,23 @@ +namespace Whisper +{ + /// <summary>IDs of the special tokens defined in the model</summary> + public readonly struct SpecialTokens + { + /// <summary>The end of a transcription</summary> + public readonly int TranscriptionEnd; // token_eot + /// <summary>Start of a transcription</summary> + public readonly int TranscriptionStart; // token_sot + /// <summary>Presumably the marker which starts the previous-context (prompt) section fed to the decoder; despite the field name, it is not an individual previous word</summary> + public readonly int PreviousWord; // token_prev; NOTE(review): looks like the tokenizer's "startofprev" marker — confirm against upstream + /// <summary>Presumably the "start of LM" marker; despite the field name, it is not a sentence boundary</summary> + public readonly int SentenceStart; // token_solm; NOTE(review): looks like the tokenizer's "startoflm" marker — confirm against upstream + /// <summary>Presumably the "no timestamps" marker; it does not stand for the English word "not"</summary> + public readonly int Not; // token_not; NOTE(review): upstream whisper.cpp appears to comment this token as "no timestamps" — confirm + /// <summary>Presumably the first timestamp token, i.e. the ID where the timestamp tokens begin</summary> + public readonly int TranscriptionBegin; // token_beg; NOTE(review): confirm against upstream whisper.cpp + /// <summary>token_translate</summary> + public readonly int TaskTranslate; + /// <summary>token_transcribe</summary> + public readonly int TaskTranscribe; + } +}
// File: WhisperNet/API/eCaptureStatus.cs
namespace Whisper
{
	/// <summary>Status of the voice capture</summary>
	/// <remarks>The values are bit flags stored in a single byte, several of them can be combined in one status value.</remarks>
	[Flags]
	public enum eCaptureStatus: byte
	{
		/// <summary>Doing nothing</summary>
		None = 0,
		/// <summary>Capturing the audio</summary>
		Listening = 1,
		/// <summary>A voice is detected in the captured audio, recording</summary>
		Voice = 2,
		/// <summary>Transcribing a recorded piece of the audio</summary>
		Transcribing = 4,
		/// <summary>The computer is unable to transcribe the audio quickly enough,<br/>
		/// and the capture is dropping the incoming audio samples.</summary>
		Stalled = 0x80,
	}
}
\ No newline at end of file diff --git a/WhisperNet/API/eLanguage.cs b/WhisperNet/API/eLanguage.cs new file mode 100644 index 0000000..1241077 --- /dev/null +++ b/WhisperNet/API/eLanguage.cs @@ -0,0 +1,206 @@ +// This file is generated by a tool, from the `languageCodez.tsv` file in this repository +namespace Whisper +{ + /// <summary>Supported languages</summary> + public enum eLanguage: uint + { + /// <summary>Afrikaans</summary> + Afrikaans = 0x6661, + /// <summary>Albanian</summary> + Albanian = 0x7173, + /// <summary>Amharic</summary> + Amharic = 0x6D61, + /// <summary>Arabic</summary> + Arabic = 0x7261, + /// <summary>Armenian</summary> + Armenian = 0x7968, + /// <summary>Assamese</summary> + Assamese = 0x7361, + /// <summary>Azerbaijani</summary> + Azerbaijani = 0x7A61, + /// <summary>Bashkir</summary> + Bashkir = 0x6162, + /// <summary>Basque</summary> + Basque = 0x7565, + /// <summary>Belarusian</summary> + Belarusian = 0x6562, + /// <summary>Bengali</summary> + Bengali = 0x6E62, + /// <summary>Bosnian</summary> + Bosnian = 0x7362, + /// <summary>Breton</summary> + Breton = 0x7262, + /// <summary>Bulgarian</summary> + Bulgarian = 0x6762, + /// <summary>Catalan</summary> + Catalan = 0x6163, + /// <summary>Chinese</summary> + Chinese = 0x687A, + /// <summary>Croatian</summary> + Croatian = 0x7268, + /// <summary>Czech</summary> + Czech = 0x7363, + /// <summary>Danish</summary> + Danish = 0x6164, + /// <summary>Dutch</summary> + Dutch = 0x6C6E, + /// <summary>English</summary> + English = 0x6E65, + /// <summary>Estonian</summary> + Estonian = 0x7465, + /// <summary>Faroese</summary> + Faroese = 0x6F66, + /// <summary>Finnish</summary> + Finnish = 0x6966, + /// <summary>French</summary> + French = 0x7266, + /// <summary>Galician</summary> + Galician = 0x6C67, + /// <summary>Georgian</summary> + Georgian = 0x616B, + /// <summary>German</summary> + German = 0x6564, + /// <summary>Greek</summary> + Greek = 0x6C65, + /// <summary>Gujarati</summary> + Gujarati = 
0x7567, + /// <summary>Haitian Creole</summary> + HaitianCreole = 0x7468, + /// <summary>Hausa</summary> + Hausa = 0x6168, + /// <summary>Hawaiian</summary> + Hawaiian = 0x776168, + /// <summary>Hebrew</summary> + Hebrew = 0x7769, + /// <summary>Hindi</summary> + Hindi = 0x6968, + /// <summary>Hungarian</summary> + Hungarian = 0x7568, + /// <summary>Icelandic</summary> + Icelandic = 0x7369, + /// <summary>Indonesian</summary> + Indonesian = 0x6469, + /// <summary>Italian</summary> + Italian = 0x7469, + /// <summary>Japanese</summary> + Japanese = 0x616A, + /// <summary>Javanese</summary> + Javanese = 0x776A, + /// <summary>Kannada</summary> + Kannada = 0x6E6B, + /// <summary>Kazakh</summary> + Kazakh = 0x6B6B, + /// <summary>Khmer</summary> + Khmer = 0x6D6B, + /// <summary>Korean</summary> + Korean = 0x6F6B, + /// <summary>Lao</summary> + Lao = 0x6F6C, + /// <summary>Latin</summary> + Latin = 0x616C, + /// <summary>Latvian</summary> + Latvian = 0x766C, + /// <summary>Lingala</summary> + Lingala = 0x6E6C, + /// <summary>Lithuanian</summary> + Lithuanian = 0x746C, + /// <summary>Luxembourgish</summary> + Luxembourgish = 0x626C, + /// <summary>Macedonian</summary> + Macedonian = 0x6B6D, + /// <summary>Malagasy</summary> + Malagasy = 0x676D, + /// <summary>Malay</summary> + Malay = 0x736D, + /// <summary>Malayalam</summary> + Malayalam = 0x6C6D, + /// <summary>Maltese</summary> + Maltese = 0x746D, + /// <summary>Maori</summary> + Maori = 0x696D, + /// <summary>Marathi</summary> + Marathi = 0x726D, + /// <summary>Mongolian</summary> + Mongolian = 0x6E6D, + /// <summary>Myanmar</summary> + Myanmar = 0x796D, + /// <summary>Nepali</summary> + Nepali = 0x656E, + /// <summary>Norwegian</summary> + Norwegian = 0x6F6E, + /// <summary>Nynorsk</summary> + Nynorsk = 0x6E6E, + /// <summary>Occitan</summary> + Occitan = 0x636F, + /// <summary>Pashto</summary> + Pashto = 0x7370, + /// <summary>Persian</summary> + Persian = 0x6166, + /// <summary>Polish</summary> + Polish = 0x6C70, + 
/// <summary>Portuguese</summary> + Portuguese = 0x7470, + /// <summary>Punjabi</summary> + Punjabi = 0x6170, + /// <summary>Romanian</summary> + Romanian = 0x6F72, + /// <summary>Russian</summary> + Russian = 0x7572, + /// <summary>Sanskrit</summary> + Sanskrit = 0x6173, + /// <summary>Serbian</summary> + Serbian = 0x7273, + /// <summary>Shona</summary> + Shona = 0x6E73, + /// <summary>Sindhi</summary> + Sindhi = 0x6473, + /// <summary>Sinhala</summary> + Sinhala = 0x6973, + /// <summary>Slovak</summary> + Slovak = 0x6B73, + /// <summary>Slovenian</summary> + Slovenian = 0x6C73, + /// <summary>Somali</summary> + Somali = 0x6F73, + /// <summary>Spanish</summary> + Spanish = 0x7365, + /// <summary>Sundanese</summary> + Sundanese = 0x7573, + /// <summary>Swahili</summary> + Swahili = 0x7773, + /// <summary>Swedish</summary> + Swedish = 0x7673, + /// <summary>Tagalog</summary> + Tagalog = 0x6C74, + /// <summary>Tajik</summary> + Tajik = 0x6774, + /// <summary>Tamil</summary> + Tamil = 0x6174, + /// <summary>Tatar</summary> + Tatar = 0x7474, + /// <summary>Telugu</summary> + Telugu = 0x6574, + /// <summary>Thai</summary> + Thai = 0x6874, + /// <summary>Tibetan</summary> + Tibetan = 0x6F62, + /// <summary>Turkish</summary> + Turkish = 0x7274, + /// <summary>Turkmen</summary> + Turkmen = 0x6B74, + /// <summary>Ukrainian</summary> + Ukrainian = 0x6B75, + /// <summary>Urdu</summary> + Urdu = 0x7275, + /// <summary>Uzbek</summary> + Uzbek = 0x7A75, + /// <summary>Vietnamese</summary> + Vietnamese = 0x6976, + /// <summary>Welsh</summary> + Welsh = 0x7963, + /// <summary>Yiddish</summary> + Yiddish = 0x6979, + /// <summary>Yoruba</summary> + Yoruba = 0x6F79, + } +}
// File: WhisperNet/API/eLogLevel.cs
namespace Whisper
{
	/// <summary>Message log level</summary>
	public enum eLogLevel: byte
	{
		/// <summary>Error message</summary>
		Error = 0,
		/// <summary>Warning message</summary>
		Warning = 1,
		/// <summary>Informational message</summary>
		Info = 2,
		/// <summary>Debug message</summary>
		Debug = 3
	}

	/// <summary>A delegate to receive log messages from the library</summary>
	/// <param name="level">Severity of the message</param>
	/// <param name="message">Text of the message</param>
	public delegate void pfnLogMessage( eLogLevel level, string message );

	/// <summary>Log destination flags</summary>
	[Flags]
	public enum eLoggerFlags: byte
	{
		/// <summary>No special flags</summary>
		None = 0,

		/// <summary>In addition to calling the delegate, print messages to standard error</summary>
		UseStandardError = 1,

		/// <summary>Don’t format error codes into messages</summary>
		/// <remarks>It’s recommended to use this flag in .NET.<br/>
		/// The standard library already formats these messages automatically, as needed.</remarks>
		SkipFormatMessage = 2,
	}
}
// File: WhisperNet/API/eModelImplementation.cs
namespace Whisper
{
	/// <summary>Implementation value for the <see cref="Library.loadModel(string, eModelImplementation)" /> factory function</summary>
	/// <remarks>The numbering starts at 1; zero is not a valid implementation value.</remarks>
	public enum eModelImplementation: uint
	{
		/// <summary>GPGPU implementation based on Direct3D 11.0 compute shaders</summary>
		GPU = 1,

		/// <summary>A hybrid implementation which uses DirectCompute for encode, and decodes on CPU</summary>
		/// <remarks>
		/// <para>The build of the native DLL included into this nuget package doesn’t implement this version.<br/>
		/// To enable, edit <c>stdafx.h</c> in Whisper project, change the value of <c>BUILD_HYBRID_VERSION</c> macro from zero to one, and build.</para>
		/// <para>This implementation requires a CPU with AVX1, FMA3, F16C and BMI1 instruction set extensions.</para>
		/// </remarks>
		Hybrid = 2,

		/// <summary>A reference implementation which uses the original GGML CPU-running code.</summary>
		/// <remarks>
		/// <para>The build of the native DLL included into this nuget package doesn’t implement this version either.<br/>
		/// To enable, edit <c>stdafx.h</c> in Whisper project, change the value of <c>BUILD_BOTH_VERSIONS</c> macro from zero to one, and build the project.</para>
		/// <para>This implementation requires a CPU with AVX1, FMA3, and F16C instruction set extensions.</para>
		/// </remarks>
		Reference = 3,
	}
}
// File: WhisperNet/API/eResultFlags.cs
namespace Whisper
{
	/// <summary>Flags for <see cref="Context.results(eResultFlags)" /> method</summary>
	[Flags]
	public enum eResultFlags: uint
	{
		/// <summary>No flags</summary>
		None = 0,

		/// <summary>Return individual tokens in addition to the segments</summary>
		Tokens = 1,

		/// <summary>Return timestamps</summary>
		Timestamps = 2,

		/// <summary>Create a new COM object for the results.</summary>
		/// <remarks>Without this flag, the context returns a pointer to the COM object stored in the context.<br/>
		/// The content of that object is replaced every time you call <see cref="Internal.iContext.getResults(eResultFlags)" /> method.<br/>
		/// The high-level <see cref="Context.results(eResultFlags)" /> wrapper does not accept this flag, and throws <see cref="ArgumentException" /> when it is specified.</remarks>
		NewObject = 0x100,
	}
}
// File: WhisperNet/API/iAudioBuffer.cs
using ComLight;
using System.Runtime.InteropServices;

namespace Whisper
{
	/// <summary>A buffer with a chunk of audio.</summary>
	/// <remarks>Note the interface supports both marshaling directions.<br/>
	/// I have not tested, but you should be able to implement this interface in C#, to supply PCM audio data to the native code.<br/>
	/// NOTE(review): this is a COM interface, the order of the methods presumably defines the native vtable layout — do not reorder them.</remarks>
	[ComInterface( "013583aa-c9eb-42bc-83db-633c2c317051", eMarshalDirection.BothWays )]
	public interface iAudioBuffer: IDisposable
	{
		/// <summary>Count of samples in the buffer</summary>
		int countSamples();

		/// <summary>Unmanaged pointer to the internal buffer containing single-channel FP32 samples.</summary>
		/// <remarks>If you are implementing this interface in C# and your audio data is on the managed heap, use <see cref="GCHandle" /> to make sure it doesn't move.<br/>
		/// Or better yet, move the data to an unmanaged buffer allocated with <see cref="Marshal.AllocHGlobal(int)" /> or <see cref="Marshal.AllocCoTaskMem(int)" /> method.</remarks>
		IntPtr getPcmMono();

		/// <summary>Unmanaged pointer to the internal buffer containing stereo FP32 samples.</summary>
		/// <remarks>When the buffer doesn’t have stereo data, the method returns <see cref="IntPtr.Zero" />.</remarks>
		IntPtr getPcmStereo();

		/// <summary>Start time of the buffer, relative to the start of the media</summary>
		void getTime( out TimeSpan time );
	}
}
// File: WhisperNet/API/iAudioReader.cs
using ComLight;

namespace Whisper
{
	/// <summary>Audio stream reader object</summary>
	/// <remarks>The implementation is forward-only, and these objects are not reusable.<br/>
	/// To read a source file multiple times, dispose and re-create the reader.</remarks>
	[ComInterface( "35b988da-04a6-476a-a193-d8891d5dc390", eMarshalDirection.ToManaged )]
	public interface iAudioReader: IDisposable
	{
		/// <summary>Get duration of the media file</summary>
		[RetValIndex]
		TimeSpan getDuration();
	}

	/// <summary>Audio capture reader object</summary>
	/// <remarks>This interface has no public methods callable from C#.<br/>
	/// It’s only here to pass data between different functions implemented in C++.</remarks>
	[ComInterface( "747752c2-d9fd-40df-8847-583c781bf013", eMarshalDirection.ToManaged )]
	public interface iAudioCapture: IDisposable
	{
	}
}
// File: WhisperNet/API/iMediaFoundation.cs
using ComLight;
using System.Runtime.InteropServices;
using Whisper.Internal;

namespace Whisper
{
	/// <summary>Exposes a small subset of MS Media Foundation framework.</summary>
	/// <remarks>That framework is a part of Windows OS, since Vista.</remarks>
	/// <seealso href="https://learn.microsoft.com/en-us/windows/win32/medfound/microsoft-media-foundation-sdk" />
	[ComInterface( "fb9763a5-d77d-4b6e-aff8-f494813cebd8", eMarshalDirection.ToManaged ), CustomConventions( typeof( NativeLogger ) )]
	public interface iMediaFoundation: IDisposable
	{
		/// <summary>Decode complete audio file into a new memory buffer.</summary>
		/// <remarks>
		/// Under the hood, the method asks MF to resample and convert audio into the suitable type for the Whisper model.<br/>
		/// If the path is a video file, the method will decode the first audio track.
		/// </remarks>
		/// <param name="path">Path of the media file to decode</param>
		/// <param name="stereo">NOTE(review): presumably requests stereo samples in addition to mono — confirm against the native implementation</param>
		/// <returns>A buffer with the decoded audio</returns>
		[RetValIndex( 2 )]
		iAudioBuffer loadAudioFile( [MarshalAs( UnmanagedType.LPWStr )] string path, [MarshalAs( UnmanagedType.U1 )] bool stereo = false );

		/// <summary>Create a reader to stream the audio file from disk</summary>
		/// <remarks>
		/// Under the hood, the method asks MF to resample and convert audio into the suitable type for the Whisper model.<br/>
		/// If the path is a video file, the method will decode the first audio track.
		/// </remarks>
		/// <param name="path">Path of the media file to stream</param>
		/// <param name="stereo">NOTE(review): presumably requests stereo samples in addition to mono — confirm against the native implementation</param>
		/// <returns>A forward-only reader of the audio stream</returns>
		[RetValIndex( 2 )]
		iAudioReader openAudioFile( [MarshalAs( UnmanagedType.LPWStr )] string path, [MarshalAs( UnmanagedType.U1 )] bool stereo = false );

		/// <summary>List capture devices</summary>
		/// <remarks>Don't call this method, use <see cref="ExtensionMethods.listCaptureDevices(iMediaFoundation)" /> instead.</remarks>
		void listCaptureDevices( [MarshalAs( UnmanagedType.FunctionPtr )] pfnFoundCaptureDevices pfn, IntPtr pv );

		/// <summary>Open audio capture device</summary>
		[RetValIndex( 2 )]
		iAudioCapture openCaptureDevice( [MarshalAs( UnmanagedType.LPWStr )] string endpoint, [In] ref sCaptureParams captureParams );
	}
}
// File: WhisperNet/API/iModel.cs
using ComLight;
using System.ComponentModel;

namespace Whisper
{
	/// <summary>A model in VRAM, loaded from GGML file.</summary>
	/// <remarks>This object doesn't keep any mutable state, and can be safely used from multiple threads concurrently</remarks>
	[ComInterface( "abefb4c9-e8d8-46a3-8747-5afbadef1adb", eMarshalDirection.ToManaged ), CustomConventions( typeof( Internal.NativeLogger ) )]
	public interface iModel: IDisposable
	{
		/// <summary>Create a context to transcribe audio with this model</summary>
		/// <remarks>Don't call this method, use <see cref="ExtensionMethods.createContext(iModel)" /> instead.</remarks>
		[RetValIndex, EditorBrowsable( EditorBrowsableState.Never )]
		Internal.iContext createContextInternal();

		/// <summary>True if this model is multi-lingual</summary>
		bool isMultilingual();

		/// <summary>Retrieve integer IDs of the special tokens defined by the model</summary>
		[RetValIndex]
		SpecialTokens getSpecialTokens();

		/// <summary>Try to resolve integer token ID into string.</summary>
		/// <remarks>Don't call this method, use <see cref="ExtensionMethods.stringFromToken(iModel, int)" /> instead.<br/>
		/// The returned pointer is decoded by that extension method as a UTF-8 string.</remarks>
		IntPtr stringFromTokenInternal( int id );
	}
}
// File: WhisperNet/API/sCaptureParams.cs
namespace Whisper
{
	/// <summary>Flags for the audio capture</summary>
	[Flags]
	public enum eCaptureFlags: uint
	{
		/// <summary>No special flags</summary>
		None = 0,
		/// <summary>When the capture device supports stereo, keep stereo PCM samples in addition to mono</summary>
		Stereo = 1,
	}

	/// <summary>Parameters for audio capture</summary>
	/// <remarks>The structure is passed by reference to <see cref="iMediaFoundation.openCaptureDevice(string, ref sCaptureParams)" />;
	/// NOTE(review): the field order presumably has to match the native declaration — do not reorder.</remarks>
	public struct sCaptureParams
	{
		/// <summary>Minimum transcribe duration in seconds</summary>
		public float minDuration;
		/// <summary>Maximum transcribe duration in seconds</summary>
		public float maxDuration;
		/// <summary>NOTE(review): undocumented; judging by the default value below this is a duration in seconds (0.25 = 250 ms) related to the silence at the start of a recording — TODO confirm the exact meaning</summary>
		public float dropStartSilence;
		/// <summary>NOTE(review): undocumented; judging by the default value below this is a pause duration in seconds (0.333 = 333 ms) — TODO confirm the exact meaning</summary>
		public float pauseDuration;
		/// <summary>Flags for the audio capture</summary>
		public eCaptureFlags flags;

		/// <summary>Initialize the structure with some reasonable default values</summary>
		/// <remarks>Only <c>new sCaptureParams()</c> runs this constructor; <c>default( sCaptureParams )</c> produces an all-zero structure instead.</remarks>
		public sCaptureParams()
		{
			minDuration = 7.0f; // 7 seconds
			maxDuration = 11.0f; // 11 seconds
			dropStartSilence = 0.25f; // 250 ms
			pauseDuration = 0.333f; // 333 ms
			flags = eCaptureFlags.None;
		}
	}
}
// File: WhisperNet/Callbacks.cs
using Whisper.Internal;

namespace Whisper
{
	/// <summary>Derive from this abstract class to get notified by the native code while it transcribes the audio</summary>
	public abstract class Callbacks
	{
		// Status codes returned to the native code
		const int S_OK = 0;
		const int S_FALSE = 1;

		/// <summary>Invoked before each run of the encoder.</summary>
		/// <remarks>Return false to abort the processing.</remarks>
		protected virtual bool onEncoderBegin( Context sender ) => true;

		/// <summary>Invoked for each new segment</summary>
		protected virtual void onNewSegment( Context sender, int countNew ) { }

		/// <summary>Adapter for the native "encoder begin" callback: translates the boolean result into S_OK / S_FALSE, and exceptions into their HRESULT codes</summary>
		internal int encoderBegin( Context sender )
		{
			bool proceed;
			try
			{
				proceed = onEncoderBegin( sender );
			}
			catch( Exception error )
			{
				// Keep the exception, NativeLogger needs it to re-throw on the managed side
				NativeLogger.captureException( error );
				return error.HResult;
			}

			if( proceed )
				return S_OK;
			return S_FALSE;
		}

		/// <summary>Adapter for the native "new segment" callback: returns S_OK, or the HRESULT code of the exception thrown by the user's code</summary>
		internal int newSegment( Context sender, int countNew )
		{
			try
			{
				onNewSegment( sender, countNew );
			}
			catch( Exception error )
			{
				// Keep the exception, NativeLogger needs it to re-throw on the managed side
				NativeLogger.captureException( error );
				return error.HResult;
			}
			return S_OK;
		}
	}
}
// File: WhisperNet/CaptureCallbacks.cs
using Whisper.Internal;

namespace Whisper
{
	/// <summary>Derive from this abstract class to supply callbacks for the audio capture method</summary>
	public abstract class CaptureCallbacks
	{
		// Status codes returned to the native code
		const int S_OK = 0;
		const int S_FALSE = 1;

		/// <summary>Override to implement cancellation; the default implementation never cancels</summary>
		protected virtual bool shouldCancel( Context sender ) => false;

		/// <summary>Override to receive notifications about changes of the capture status</summary>
		protected virtual void captureStatusChanged( Context sender, eCaptureStatus status ) { }

		/// <summary>Create a delegate for the native code which calls <see cref="shouldCancel(Context)" /></summary>
		/// <remarks>The delegate returns S_OK when the method returned true, S_FALSE when it returned false, or the HRESULT code of the exception.</remarks>
		internal pfnShouldCancel cancel( Context sender )
		{
			return ( IntPtr pv ) =>
			{
				bool cancelled;
				try
				{
					cancelled = shouldCancel( sender );
				}
				catch( Exception error )
				{
					NativeLogger.captureException( error );
					return error.HResult;
				}

				if( cancelled )
					return S_OK;
				return S_FALSE;
			};
		}

		/// <summary>Create a delegate for the native code which calls <see cref="captureStatusChanged(Context, eCaptureStatus)" /></summary>
		/// <remarks>The delegate returns S_OK, or the HRESULT code of the exception.</remarks>
		internal pfnCaptureStatus status( Context sender )
		{
			return ( IntPtr pv, eCaptureStatus newStatus ) =>
			{
				try
				{
					captureStatusChanged( sender, newStatus );
				}
				catch( Exception error )
				{
					NativeLogger.captureException( error );
					return error.HResult;
				}
				return S_OK;
			};
		}
	}
}
// File: WhisperNet/Context.cs
using System.Diagnostics;
using Whisper.Internal;
using Whisper.Internals;

namespace Whisper
{
	/// <summary>Stateful context, contains methods to transcribe audio</summary>
	/// <remarks>The object keeps per-run state (parameters, callbacks, progress sink) in its fields,
	/// so a single instance should not be used to run several transcriptions concurrently.</remarks>
	public sealed class Context: IDisposable
	{
		iContext context;
		// Caching the results object here saves time spent in ComLight library creating these callable proxies over and over again for the same underlying C++ object
		readonly iTranscribeResult transcribeResult;
		sFullParams fullParams;
		sProgressSink progressSink;
		bool disposed = false;
		// Cached delegates, to avoid allocating new ones on every run
		readonly Action<object> pfnBuffer, pfnStream;

		/// <summary>Wrap the native context, cache its results object, and load the default parameters for the greedy sampling strategy</summary>
		internal Context( Internal.iContext context )
		{
			this.context = context;
			transcribeResult = context.getResults( eResultFlags.None );
			fullParams = context.fullDefaultParams( eSamplingStrategy.Greedy );
			pfnBuffer = processBuffer;
			pfnStream = processStream;
			progressSink = default;
		}

		/// <summary>Release the native context; safe to call more than once</summary>
		void IDisposable.Dispose()
		{
			if( disposed )
				return;
			disposed = true;
			context?.Dispose();
			GC.SuppressFinalize( this );
		}

		/// <summary>Adjustable parameters</summary>
		public ref Parameters parameters => ref fullParams.publicParams;

		void processBuffer( object buffer )
		{
			context.runFull( ref fullParams, (iAudioBuffer)buffer );
		}
		void processStream( object reader )
		{
			context.runStreamed( ref fullParams, ref progressSink, (iAudioReader)reader );
		}

		/// <summary>When the callbacks object is supplied, wire its methods into the native parameters structure</summary>
		void setCallbacks( Callbacks? callbacks )
		{
			if( null == callbacks )
				return;

			// TODO [very low, performance]: the following code creates 2 new GC-allocated objects on each call.
			// Possible to optimize by caching these function pointers in static readonly fields, and use another [ThreadStatic] field for the callbacks object
			fullParams.newSegmentCallback = delegate ( IntPtr ctx, int countNew, IntPtr userData )
			{
				return callbacks.newSegment( this, countNew );
			};

			fullParams.encoderBeginCallback = delegate ( IntPtr ctx, IntPtr userData )
			{
				return callbacks.encoderBegin( this );
			};
		}

		/// <summary>Clear the per-run fields of the native parameters structure: the callbacks, and the prompt tokens</summary>
		void resetRunState()
		{
			// Reset these delegates.
			// Otherwise, this class will retain the callbacks object preventing it from being garbage collected.
			fullParams.newSegmentCallback = null;
			fullParams.encoderBeginCallback = null;

			// The prompt pointer is only valid while the span is pinned, never keep it around
			fullParams.prompt_tokens = IntPtr.Zero;
			fullParams.prompt_n_tokens = 0;
		}

		/// <summary>Shared implementation of the runFull overloads: set up callbacks and the prompt, call the supplied delegate, then clean up</summary>
		void runImpl( object source, Callbacks? callbacks, ReadOnlySpan<int> promptTokens, Action<object> pfn )
		{
			setCallbacks( callbacks );

			try
			{
				if( promptTokens.IsEmpty )
				{
					pfn( source );
					return;
				}
				unsafe
				{
					// Pin the tokens for the duration of the native call
					fixed( int* tokens = promptTokens )
					{
						fullParams.prompt_tokens = (IntPtr)tokens;
						fullParams.prompt_n_tokens = promptTokens.Length;
						pfn( source );
					}
				}
			}
			finally
			{
				resetRunState();
			}
		}

		/// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary>
		public void runFull( iAudioBuffer buffer, Callbacks? callbacks, ReadOnlySpan<int> promptTokens )
		{
			runImpl( buffer, callbacks, promptTokens, pfnBuffer );
		}
		/// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary>
		public void runFull( iAudioBuffer buffer, Callbacks? callbacks = null ) =>
			runFull( buffer, callbacks, ReadOnlySpan<int>.Empty );
		/// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary>
		public void runFull( iAudioBuffer buffer, Callbacks? callbacks, int[]? promptTokens ) =>
			runFull( buffer, callbacks, promptTokens ?? ReadOnlySpan<int>.Empty );

		/// <summary>Run the entire model, streaming audio from the provided reader object</summary>
		/// <param name="reader">Source of the audio</param>
		/// <param name="callbacks">Optional callbacks to receive notifications while transcribing</param>
		/// <param name="pfnProgress">Optional delegate to receive progress values</param>
		/// <param name="promptTokens">Optional prompt tokens, pass an empty span when not needed</param>
		public void runFull( iAudioReader reader, Callbacks? callbacks, Action<double>? pfnProgress, ReadOnlySpan<int> promptTokens )
		{
			if( null != pfnProgress )
			{
				progressSink.pfn = delegate ( double value, IntPtr ctx, IntPtr pv )
				{
					try
					{
						pfnProgress.Invoke( value );
						return 0;
					}
					catch( Exception ex )
					{
						// Same as the other callbacks: keep the exception, so NativeLogger can re-throw
						// the original one instead of a generic exception made from the HRESULT code
						NativeLogger.captureException( ex );
						return ex.HResult;
					}
				};
			}
			try
			{
				runImpl( reader, callbacks, promptTokens, pfnStream );
			}
			finally
			{
				// Don't retain the progress delegate after the run
				progressSink.pfn = null;
			}
		}

		/// <summary>Run the entire model, streaming audio from the provided reader object</summary>
		public void runFull( iAudioReader reader, Action<double>? pfnProgress = null, Callbacks? callbacks = null ) =>
			runFull( reader, callbacks, pfnProgress, ReadOnlySpan<int>.Empty );

		/// <summary>Run the entire model, streaming audio from the provided reader object</summary>
		public void runFull( iAudioReader reader, Callbacks? callbacks, Action<double>? pfnProgress, int[]? promptTokens ) =>
			runFull( reader, callbacks, pfnProgress, promptTokens ?? ReadOnlySpan<int>.Empty );

		/// <summary>Get text results out of the context</summary>
		/// <exception cref="ArgumentException">The flags include <see cref="eResultFlags.NewObject" />, which this method does not support</exception>
		public TranscribeResult results( eResultFlags flags = eResultFlags.None )
		{
			if( flags.HasFlag( eResultFlags.NewObject ) )
				throw new ArgumentException( "eResultFlags.NewObject flag is not supported by this method", nameof( flags ) );

			iTranscribeResult res = context.getResults( flags );
			// Without the NewObject flag, the native context is expected to return the same cached object
			Debug.Assert( ReferenceEquals( res, transcribeResult ) );
			return new TranscribeResult( res );
		}

		/// <summary>Print timing data</summary>
		public void timingsPrint() => context.timingsPrint();

		/// <summary>Reset timing data</summary>
		public void timingsReset() => context.timingsReset();

		/// <summary>Continuously process audio from microphone or a similar capture device</summary>
		/// <remarks>It’s recommended to call this method on a background thread.</remarks>
		public void runCapture( iAudioCapture capture, Callbacks? callbacks, CaptureCallbacks? captureCallbacks )
		{
			setCallbacks( callbacks );

			try
			{
				sCaptureCallbacks cc = default;
				if( captureCallbacks != null )
				{
					cc.shouldCancel = captureCallbacks.cancel( this );
					cc.captureStatus = captureCallbacks.status( this );
				}
				context.runCapture( ref fullParams, ref cc, capture );
			}
			finally
			{
				resetRunState();
			}
		}
	}
}
// File: WhisperNet/ExtensionMethods.cs
using System.Runtime.InteropServices;
using Whisper.Internal;

namespace Whisper
{
	/// <summary>Extension methods for the COM interfaces and related types of this library</summary>
	public static class ExtensionMethods
	{
		/// <summary>Create a context to transcribe audio with this model</summary>
		public static Context createContext( this iModel model ) =>
			new Context( model.createContextInternal() );

		/// <summary>Convert language into a short ID string, like <c>"en"</c></summary>
		/// <remarks>The enum value is written into a small zero-terminated buffer as 4 raw bytes, and the buffer is decoded as a C string.</remarks>
		public static string getCode( this eLanguage lang )
		{
			unsafe
			{
				// 4 bytes for the value, plus 1 more for the terminating zero
				sbyte* buffer = stackalloc sbyte[ 5 ];
				buffer[ 4 ] = 0;
				*(uint*)buffer = (uint)lang;
				return new string( buffer );
			}
		}

		/// <summary>Resolve integer token ID into string.</summary>
		/// <remarks>If the token ID was not found in the model, the method returns null without raising exceptions.</remarks>
		public static string? stringFromToken( this iModel model, int idToken )
		{
			IntPtr utf8 = model.stringFromTokenInternal( idToken );
			return Marshal.PtrToStringUTF8( utf8 );
		}

		/// <summary>List capture devices</summary>
		/// <returns>An array with the devices, or null when the callback received no devices</returns>
		public static CaptureDeviceId[]? listCaptureDevices( this iMediaFoundation mf )
		{
			const int S_OK = 0;
			const int S_FALSE = 1;
			List<CaptureDeviceId>? devices = null;

			pfnFoundCaptureDevices callback = ( int len, sCaptureDevice[]? arr, IntPtr pv ) =>
			{
				try
				{
					if( arr == null || 0 == len )
						return S_FALSE;

					devices = new List<CaptureDeviceId>( len );
					for( int i = 0; i < arr.Length; i++ )
						devices.Add( new CaptureDeviceId( arr[ i ] ) );
					return S_OK;
				}
				catch( Exception error )
				{
					NativeLogger.captureException( error );
					return error.HResult;
				}
			};

			mf.listCaptureDevices( callback, IntPtr.Zero );

			if( null == devices )
				return null;
			return devices.ToArray();
		}

		/// <summary>Open audio capture device</summary>
		/// <param name="mf">Media Foundation object</param>
		/// <param name="id">The device to open</param>
		/// <param name="cp">Capture parameters; when null, the method uses the defaults from the parameterless constructor of the structure</param>
		public static iAudioCapture openCaptureDevice( this iMediaFoundation mf, in CaptureDeviceId id, sCaptureParams? cp = null )
		{
			sCaptureParams captureParams = cp.HasValue ? cp.Value : new sCaptureParams();
			return mf.openCaptureDevice( id.endpoint, ref captureParams );
		}
	}
}
\ No newline at end of file diff --git a/WhisperNet/Internal/AssemblyInfo.cs b/WhisperNet/Internal/AssemblyInfo.cs new file mode 100644 index 0000000..29ec638 --- /dev/null +++ b/WhisperNet/Internal/AssemblyInfo.cs @@ -0,0 +1,8 @@ +using System.Reflection; +using System.Runtime.InteropServices; +[assembly: AssemblyTitle( "WhisperNet" )] +[assembly: AssemblyCopyright( "Copyright © const.me, 2022" )] +[assembly: ComVisible( false )] +[assembly: Guid( "ced6cdb7-e040-4398-bae8-3417e5fa35f1" )] +[assembly: AssemblyVersion( "1.0.0.0" )] +[assembly: AssemblyDescription( "DirectCompute port of whisper.cpp library, C# bindings" )]
// File: WhisperNet/Internal/NativeLogger.cs
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;

namespace Whisper.Internal
{
	/// <summary>Utility class to supply logging function pointer to the C++ library,<br/>
	/// and provide custom calling conventions to ComLight runtime to convert error messages printed in C++ into .NET exception messages</summary>
	public static class NativeLogger
	{
		/// <summary>The method does nothing; presumably it is called to force the static constructor of this class to run</summary>
		internal static void startup() { }

		/// <summary>Install the log sink into the native library, with the warning log level</summary>
		static NativeLogger()
		{
			sink = logSink;
			sLoggerSetup setup = default;
			setup.sink = sink;
			setup.level = eLogLevel.Warning;
			Library.setupLogger( ref setup );
		}

		/// <summary>Change log level and flags, and set or clear the delegate which receives the messages</summary>
		internal static void setup( eLogLevel lvl, eLoggerFlags flags, pfnLogMessage? pfn )
		{
			logMessage = pfn;

			sLoggerSetup setup = default;
			setup.sink = sink;
			setup.level = lvl;
			setup.flags = flags;
			Library.setupLogger( ref setup );
		}

		// This field is here to protect the function pointer from being collected by the GC
		static readonly pfnLoggerSink sink;

		/// <summary>The function called by the native library for each log message</summary>
		static void logSink( IntPtr context, eLogLevel lvl, string message )
		{
			// The native code may log from a thread which never called prologue(), that's why the state is created on demand
			if( lvl == eLogLevel.Error )
				currentState.setText( message );
			logMessage?.Invoke( lvl, message );
		}

		/// <summary>Per-thread error information collected between prologue and epilogue</summary>
		sealed class ThreadState
		{
			string? errorText = null;
			ExceptionDispatchInfo? dispatchInfo = null;

			public void setText( string text ) => errorText = text;
			public void capture( Exception ex ) => dispatchInfo = ExceptionDispatchInfo.Capture( ex );

			public void clear()
			{
				errorText = null;
				dispatchInfo = null;
			}

			/// <summary>Move the content into the output arguments, and clear this object</summary>
			public void Deconstruct( out string? text, out ExceptionDispatchInfo? edi )
			{
				text = errorText;
				edi = dispatchInfo;
				errorText = null;
				dispatchInfo = null;
			}
		}

		// No initializer on purpose: an initializer of a [ThreadStatic] field only runs once,
		// on the thread which executes the static constructor. On all other threads the field starts as null.
		// https://stackoverflow.com/a/2043505/126995
		[ThreadStatic]
		static ThreadState? state;

		/// <summary>State object of the current thread, created on first use</summary>
		static ThreadState currentState
		{
			[MethodImpl( MethodImplOptions.AggressiveInlining )]
			get
			{
				ThreadState? ts = state;
				if( null != ts )
					return ts;
				return createState();
			}
		}

		/// <summary>Remember the exception thrown by a managed callback, to re-throw it later from the epilogue</summary>
		internal static void captureException( Exception ex ) =>
			currentState.capture( ex );

		static pfnLogMessage? logMessage = null;

		/// <summary>Called internally by ComLight runtime</summary>
		[MethodImpl( MethodImplOptions.AggressiveInlining )]
		public static void prologue()
		{
			ThreadState? ts = state;
			if( null != ts )
				ts.clear();
			else
				createState();
		}

		/// <summary>Create the state object for the current thread</summary>
		[MethodImpl( MethodImplOptions.NoInlining )]
		static ThreadState createState()
		{
			ThreadState ts = new ThreadState();
			state = ts;
			return ts;
		}

		/// <summary>Epilogue implementation for unsuccessful status codes</summary>
		[MethodImpl( MethodImplOptions.NoInlining )]
		static void throwException( int hr )
		{
			// Move state from the thread local object into local variables, and clear that object
			(string? text, ExceptionDispatchInfo? edi) = currentState;

			if( null != edi && edi.SourceException.HResult == hr )
			{
				// The error comes from a callback, and we have original context of that exception.
				// Re-throw the original exception.
				// This uses the original error message, and even correctly deals with the stack trace.
				edi.Throw();
			}

			if( null != text )
			{
				// C++ code has printed an error on the current thread, between prologue and epilogue.
				// Use that text for the exception message.
				Exception? ex = Marshal.GetExceptionForHR( hr );
				throw new ApplicationException( text, ex );
			}

			// We don’t have any additional info about the exception.
			// Throw an exception from just the HRESULT code.
			Marshal.ThrowExceptionForHR( hr );
		}

		/// <summary>Called internally by ComLight runtime</summary>
		/// <remarks>Does nothing for non-negative codes, throws an exception for negative ones.</remarks>
		[MethodImpl( MethodImplOptions.AggressiveInlining )]
		public static void throwForHR( int hr )
		{
			if( hr >= 0 )
				return; // SUCCEEDED
			throwException( hr );
		}

		/// <summary>Called internally by ComLight runtime</summary>
		/// <remarks>For non-negative codes returns true when the code is zero, otherwise false. Throws an exception for negative codes.</remarks>
		[MethodImpl( MethodImplOptions.AggressiveInlining )]
		public static bool throwAndReturnBool( int hr )
		{
			if( hr >= 0 )
				return 0 == hr;
			throwException( hr );
			return false;
		}
	}
}
\ No newline at end of file diff --git a/WhisperNet/Internal/iContext.cs b/WhisperNet/Internal/iContext.cs new file mode 100644 index 0000000..6adf8c5 --- /dev/null +++ b/WhisperNet/Internal/iContext.cs @@ -0,0 +1,37 @@ +using ComLight; +using System.Runtime.InteropServices; +using Whisper.Internals; + +namespace Whisper.Internal +{ + /// <summary>Stateful context, contains methods to transcribe audio</summary> + [ComInterface( "b9956374-3b18-4943-90f2-2ab18a404537", eMarshalDirection.ToManaged ), CustomConventions( typeof( NativeLogger ) )] + public interface iContext: IDisposable + { + /// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary> + void runFull( [In] ref sFullParams @params, iAudioBuffer buffer ); + + /// <summary>Run the entire model, streaming audio from the provided reader object</summary> + void runStreamed( [In] ref sFullParams @params, [In] ref sProgressSink progressSink, iAudioReader reader ); + + /// <summary>Continuously process audio from microphone or a similar capture device</summary> + void runCapture( [In] ref sFullParams @params, [In] ref sCaptureCallbacks callbacks, iAudioCapture reader ); + + /// <summary>Get text results out of the context</summary> + [RetValIndex( 1 )] + iTranscribeResult getResults( eResultFlags flags ); + + /// <summary>Get the model which was used to create this context</summary> + [RetValIndex] + iModel getModel(); + + /// <summary>Get the default parameters of the model, for the specified sampling strategy</summary> + [RetValIndex( 1 )] + sFullParams fullDefaultParams( eSamplingStrategy strategy ); + + /// <summary>Print timing data</summary> + void timingsPrint(); + /// <summary>Reset timing data</summary> + void timingsReset(); + } +}
\ No newline at end of file diff --git a/WhisperNet/Internal/iTranscribeResult.cs b/WhisperNet/Internal/iTranscribeResult.cs new file mode 100644 index 0000000..cbf49dd --- /dev/null +++ b/WhisperNet/Internal/iTranscribeResult.cs @@ -0,0 +1,122 @@ +#pragma warning disable CS0649 // Field is never assigned to +using ComLight; +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace Whisper.Internal +{ + /// <summary>Size of the buffers owned by the <see cref="iTranscribeResult" /> object</summary> + public readonly struct sTranscribeLength + { + /// <summary>Count of segments</summary> + public readonly int countSegments; + /// <summary>Total count of tokens, for all segments combined</summary> + public readonly int countTokens; + } + + /// <summary>Output data from the model</summary> + [ComInterface( "2871a73f-5ce3-48f8-8779-6582ee11935e", eMarshalDirection.ToManaged ), CustomConventions( typeof( NativeLogger ) )] + public interface iTranscribeResult + { + /// <summary>Get size of the buffers</summary> + [RetValIndex, EditorBrowsable( EditorBrowsableState.Never )] + public sTranscribeLength getSize(); + + /// <summary>Pointer to segment data, a vector of <see cref="sSegment" /> structures</summary> + [EditorBrowsable( EditorBrowsableState.Never )] + public IntPtr getSegments(); + + /// <summary>Pointer to tokens data, a vector of <see cref="sToken" /> structures</summary> + [EditorBrowsable( EditorBrowsableState.Never )] + public IntPtr getTokens(); + } +} + +namespace Whisper +{ + /// <summary>Start and end times of a segment or token</summary> + /// <remarks>The times are relative to the start of the media</remarks> + public readonly struct sTimeInterval + { + /// <summary>Start time</summary> + public readonly TimeSpan begin; + /// <summary>End time</summary> + public readonly TimeSpan end; + } + + /// <summary>Segment data</summary> + public readonly struct sSegment + { + internal readonly IntPtr m_text; + /// <summary>Segment 
text</summary> + public string? text => Marshal.PtrToStringUTF8( m_text ); + /// <summary>Start and end times of the segment</summary> + public readonly sTimeInterval time; + /// <summary>Index of the first token of the segment, and count of tokens in the segment; <see cref="TranscribeResult.getTokens" /> uses them to slice the tokens of the result</summary> + public readonly int firstToken, countTokens; + } + + /// <summary>Token flags</summary> + [Flags] + public enum eTokenFlags: uint + { + /// <summary>The token is special</summary> + Special = 1, + } + + /// <summary>Token data</summary> + public readonly struct sToken + { + internal readonly IntPtr m_text; + /// <summary>Token text</summary> + public string? text => Marshal.PtrToStringUTF8( m_text ); + /// <summary>Start and end times of the token</summary> + public readonly sTimeInterval time; + /// <summary>Probability of the token</summary> + public readonly float probability; + /// <summary>Probability of the timestamp token</summary> + public readonly float probabilityTimestamp; + /// <summary>Sum of probabilities of all timestamp tokens</summary> + public readonly float ptsum; + /// <summary>Voice length of the token</summary> + public readonly float vlen; + /// <summary>Token id</summary> + public readonly int id; + /// <summary>Token flags</summary> + readonly eTokenFlags flags; + /// <summary>True if the token flags has the specified bit set</summary> + public bool hasFlag( eTokenFlags bit ) => flags.HasFlag( bit ); + } + + /// <summary>Output data from the model</summary> + public readonly ref struct TranscribeResult + { + /// <summary>Segments in the results</summary> + public readonly ReadOnlySpan<sSegment> segments; + /// <summary>Tokens in the results, for all segments</summary> + public readonly ReadOnlySpan<sToken> tokens; + + internal TranscribeResult( Internal.iTranscribeResult i ) + { + Internal.sTranscribeLength len = i.getSize(); + unsafe + { + // This does not copy the buffers to managed memory. 
+ // Instead, the C# spans directly reference the native memory stored in these std::vectors + // NOTE(review): the spans are presumably only valid while the native object behind iTranscribeResult is alive and unchanged; confirm against the C++ side + if( len.countSegments > 0 ) + segments = new ReadOnlySpan<sSegment>( (void*)i.getSegments(), len.countSegments ); + else + segments = ReadOnlySpan<sSegment>.Empty; + + if( len.countTokens > 0 ) + tokens = new ReadOnlySpan<sToken>( (void*)i.getTokens(), len.countTokens ); + else + tokens = ReadOnlySpan<sToken>.Empty; + } + } + + /// <summary>Get tokens for the specified segment</summary> + public ReadOnlySpan<sToken> getTokens( in sSegment seg ) => + tokens.Slice( seg.firstToken, seg.countTokens ); + } +}
\ No newline at end of file diff --git a/WhisperNet/Internal/sCaptureCallbacks.cs b/WhisperNet/Internal/sCaptureCallbacks.cs new file mode 100644 index 0000000..483c2f2 --- /dev/null +++ b/WhisperNet/Internal/sCaptureCallbacks.cs @@ -0,0 +1,23 @@ +using System.Runtime.InteropServices; + +namespace Whisper.Internal +{ + /// <summary>Unmanaged code calls this to check for cancellation</summary> + [UnmanagedFunctionPointer( CallingConvention.StdCall )] + public delegate int pfnShouldCancel( IntPtr pv ); + + /// <summary>Unmanaged code calls this to notify about the status</summary> + [UnmanagedFunctionPointer( CallingConvention.StdCall )] + public delegate int pfnCaptureStatus( IntPtr pv, eCaptureStatus status ); + + /// <summary>Capture callbacks for unmanaged code</summary> + public struct sCaptureCallbacks + { + /// <summary>Cancellation function pointer</summary> + public pfnShouldCancel shouldCancel; + /// <summary>Capture status function pointer</summary> + public pfnCaptureStatus captureStatus; + /// <summary>Context pointer, only needed for C++ compatibility</summary> + public IntPtr pv; + } +}
\ No newline at end of file diff --git a/WhisperNet/Internal/sCaptureDevice.cs b/WhisperNet/Internal/sCaptureDevice.cs new file mode 100644 index 0000000..e2d524d --- /dev/null +++ b/WhisperNet/Internal/sCaptureDevice.cs @@ -0,0 +1,22 @@ +#pragma warning disable CS0649 // Field is never assigned to +using System.Runtime.InteropServices; + +namespace Whisper.Internal +{ + /// <summary>Identifiers for an audio capture device</summary> + public struct sCaptureDevice + { + readonly IntPtr m_displayName; + /// <summary>The display name is suitable for showing to the user, but might not be unique.</summary> + public string? displayName => Marshal.PtrToStringUni( m_displayName ); + + readonly IntPtr m_endpoint; + /// <summary>Endpoint ID for an audio capture device.<br/> + /// It uniquely identifies the device on the system, but is not a readable string.</summary> + public string? endpoint => Marshal.PtrToStringUni( m_endpoint ); + } + + /// <summary>Function pointer to consume a list of audio capture device IDs</summary> + /// <param name="len">Count of elements in <paramref name="arr" />; the marshaller uses it as the length of the array</param> + /// <param name="arr">The devices</param> + [UnmanagedFunctionPointer( CallingConvention.StdCall )] + public delegate int pfnFoundCaptureDevices( int len, [In, MarshalAs( UnmanagedType.LPArray, SizeParamIndex = 0 )] sCaptureDevice[]? arr, IntPtr pv ); +}
\ No newline at end of file diff --git a/WhisperNet/Internal/sFullParams.cs b/WhisperNet/Internal/sFullParams.cs new file mode 100644 index 0000000..7347afe --- /dev/null +++ b/WhisperNet/Internal/sFullParams.cs @@ -0,0 +1,40 @@ +#pragma warning disable CS0649 // Field is never assigned to + +// Missing XML comment for publicly visible type or member +// TODO: remove this line and document them. +#pragma warning disable CS1591 + +using System.Runtime.InteropServices; + +namespace Whisper.Internals +{ + /// <summary>This callback is called on each new segment</summary> + [UnmanagedFunctionPointer( CallingConvention.Cdecl )] + delegate int pfnNewSegment( IntPtr ctx, int countNew, IntPtr userData ); + + /// <summary>The callback is called before every encoder run. If it returns S_FALSE, the processing is aborted.</summary> + [UnmanagedFunctionPointer( CallingConvention.Cdecl )] + delegate int pfnEncoderBegin( IntPtr ctx, IntPtr userData ); + + /// <summary>Transcribe parameters</summary> + public struct sFullParams + { + internal Parameters publicParams; + // The rest of these parameters are not exposed to the user-friendly public API of this DLL + + internal IntPtr prompt_tokens; + internal int prompt_n_tokens; + + /// <summary>This callback is called on each new segment</summary> + [MarshalAs( UnmanagedType.FunctionPtr )] + internal pfnNewSegment? newSegmentCallback; + /// <summary>Parameter for the above, not needed in C#</summary> + internal IntPtr newSegmentCallbackData; + + /// <summary>The callback is called before every encoder run. If it returns S_FALSE, the processing is aborted</summary> + [MarshalAs( UnmanagedType.FunctionPtr )] + internal pfnEncoderBegin? encoderBeginCallback; + /// <summary>Parameter for the above, not needed in C#</summary> + internal IntPtr encoderBeginCallbackData; + } +}
\ No newline at end of file diff --git a/WhisperNet/Internal/sLoadModelCallbacks.cs b/WhisperNet/Internal/sLoadModelCallbacks.cs new file mode 100644 index 0000000..07f5199 --- /dev/null +++ b/WhisperNet/Internal/sLoadModelCallbacks.cs @@ -0,0 +1,64 @@ +using System.Runtime.InteropServices; + +namespace Whisper.Internal +{ + /// <summary>Function pointer to report model loading progress</summary> + [UnmanagedFunctionPointer( CallingConvention.StdCall )] + delegate int pfnLoadProgress( double progress, IntPtr pv ); + + /// <summary>Function pointer to implement cooperative cancellation</summary> + [UnmanagedFunctionPointer( CallingConvention.StdCall )] + delegate int pfnCancel( IntPtr pv ); + + /// <summary>Callback functions for loading models</summary> + public struct sLoadModelCallbacks + { + /// <summary>Function pointer to report model loading progress</summary> + [MarshalAs( UnmanagedType.FunctionPtr )] + pfnLoadProgress? progress; + + /// <summary>Function pointer to implement cooperative cancellation</summary> + [MarshalAs( UnmanagedType.FunctionPtr )] + pfnCancel? cancel; + + // Not needed in C#, delegates can capture things + IntPtr pv; + + /// <summary>Wrap idiomatic C# things into these low-level C callbacks</summary> + /// <param name="cancelToken">Unless this is <see cref="CancellationToken.None" />, the cancel callback is set; it returns S_FALSE once cancellation was requested, S_OK otherwise</param> + /// <param name="pfnProgress">Optional delegate to receive the progress value. An exception thrown by the delegate is passed to <see cref="NativeLogger.captureException" />, and the HRESULT of that exception is returned to the caller of the callback</param> + internal sLoadModelCallbacks( CancellationToken cancelToken, Action<double>? pfnProgress ) + { + if( cancelToken != CancellationToken.None ) + { + cancel = delegate ( IntPtr pv ) + { + if( cancelToken.IsCancellationRequested ) + return 1; // S_FALSE + return 0; // S_OK + }; + } + else + cancel = null; + + if( null != pfnProgress ) + { + progress = delegate ( double val, IntPtr pv ) + { + try + { + pfnProgress( val ); + return 0; // S_OK + } + catch( Exception ex ) + { + NativeLogger.captureException( ex ); + return ex.HResult; + } + }; + } + else + progress = null; + + pv = IntPtr.Zero; + } + } +}
\ No newline at end of file diff --git a/WhisperNet/Internal/sLoggerSetup.cs b/WhisperNet/Internal/sLoggerSetup.cs new file mode 100644 index 0000000..ed1baa4 --- /dev/null +++ b/WhisperNet/Internal/sLoggerSetup.cs @@ -0,0 +1,16 @@ +using System.Runtime.InteropServices; + +namespace Whisper.Internal +{ + /// <summary>Function pointer which receives log messages; the message text is marshalled as a UTF-8 string</summary> + [UnmanagedFunctionPointer( CallingConvention.StdCall )] + delegate void pfnLoggerSink( IntPtr context, eLogLevel lvl, [MarshalAs( UnmanagedType.LPUTF8Str )] string message ); + + /// <summary>Argument of <see cref="Library.setupLogger" />: the sink function pointer, the log level and the logger flags</summary> + struct sLoggerSetup + { + [MarshalAs( UnmanagedType.FunctionPtr )] + public pfnLoggerSink sink; + IntPtr context; + public eLogLevel level; + public eLoggerFlags flags; + } +}
\ No newline at end of file diff --git a/WhisperNet/Internal/sProgressSink.cs b/WhisperNet/Internal/sProgressSink.cs new file mode 100644 index 0000000..9155677 --- /dev/null +++ b/WhisperNet/Internal/sProgressSink.cs @@ -0,0 +1,20 @@ +#pragma warning disable CS0649 // Field is never assigned to +using System.Runtime.InteropServices; + +namespace Whisper.Internal +{ + /// <summary>A callback to get notified about the progress</summary> + /// <param name="pv">The value of the <see cref="sProgressSink.pv" /> field</param> + [UnmanagedFunctionPointer( CallingConvention.StdCall )] + delegate int pfnReportProgress( double value, IntPtr context, IntPtr pv ); + + /// <summary>C structure with a progress reporting function pointer</summary> + public struct sProgressSink + { + /// <summary>A callback to get notified about the progress</summary> + [MarshalAs( UnmanagedType.FunctionPtr )] + internal pfnReportProgress? pfn; + + /// <summary>Last parameter to the callback</summary> + internal IntPtr pv; + } +}
\ No newline at end of file diff --git a/WhisperNet/Library.cs b/WhisperNet/Library.cs new file mode 100644 index 0000000..72ecb6e --- /dev/null +++ b/WhisperNet/Library.cs @@ -0,0 +1,110 @@ +using ComLight; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics.X86; +using Whisper.Internal; + +namespace Whisper +{ + /// <summary>Factory methods implemented by the C++ DLL</summary> + public static class Library + { + static Library() + { + if( Environment.OSVersion.Platform != PlatformID.Win32NT ) + throw new ApplicationException( "This library requires Windows OS" ); + if( !Environment.Is64BitProcess ) + throw new ApplicationException( "This library only works in 64-bit processes" ); + if( RuntimeInformation.ProcessArchitecture != Architecture.X64 ) + throw new ApplicationException( "This library requires a processor with AMD64 instruction set" ); + if( !Sse41.IsSupported ) + throw new ApplicationException( "This library requires a CPU with SSE 4.1 support" ); + NativeLogger.startup(); + } + + const string dll = "Whisper.dll"; + + [DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = false )] + internal static extern void setupLogger( [In] ref sLoggerSetup setup ); + + [DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = true )] + static extern int loadModel( [MarshalAs( UnmanagedType.LPWStr )] string path, eModelImplementation impl, + [In] ref sLoadModelCallbacks callbacks, + [MarshalAs( UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof( Marshaler<iModel> ) )] out iModel model ); + + /// <summary>Load Whisper model from GGML file on disk</summary> + /// <remarks>Models are large, depending on user’s disk speed this might take a while, and this function blocks the calling thread.<br/> + /// Consider <see cref="loadModelAsync" /> instead.</remarks> + /// <seealso href="https://huggingface.co/datasets/ggerganov/whisper.cpp" /> + public static iModel loadModel( string path, 
eModelImplementation impl = eModelImplementation.GPU ) + { + iModel model; + sLoadModelCallbacks callbacks = default; + NativeLogger.prologue(); + int hr = loadModel( path, impl, ref callbacks, out model ); + NativeLogger.throwForHR( hr ); + return model; + } + + /// <summary>Load Whisper model on a background thread, with optional progress reporting and cancellation</summary> + /// <remarks>The work is queued to the thread pool. Any exception thrown there, including the one made from a failed status code of the native function, is delivered through the returned task.</remarks> + public static Task<iModel> loadModelAsync( string path, CancellationToken cancelToken, Action<double>? pfnProgress = null, eModelImplementation impl = eModelImplementation.GPU ) + { + TaskCompletionSource<iModel> tcs = new TaskCompletionSource<iModel>(); + + WaitCallback wcb = delegate ( object? state ) + { + try + { + sLoadModelCallbacks callbacks = new sLoadModelCallbacks( cancelToken, pfnProgress ); + + iModel model; + NativeLogger.prologue(); + int hr = loadModel( path, impl, ref callbacks, out model ); + NativeLogger.throwForHR( hr ); + + tcs.SetResult( model ); + } + catch( Exception ex ) + { + tcs.SetException( ex ); + } + }; + + ThreadPool.QueueUserWorkItem( wcb ); + return tcs.Task; + } + + [DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = true )] + static extern int initMediaFoundation( [MarshalAs( UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof( Marshaler<iMediaFoundation> ) )] out iMediaFoundation mf ); + + /// <summary>Initialize Media Foundation runtime</summary> + public static iMediaFoundation initMediaFoundation() + { + iMediaFoundation mf; + NativeLogger.prologue(); + int hr = initMediaFoundation( out mf ); + NativeLogger.throwForHR( hr ); + return mf; + } + + // The .NET runtime uses UTF-16 for the strings, so we only need the Unicode version of this function. + // The native DLL exports both Unicode and ASCII versions. 
+ [DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = true )] + static extern uint findLanguageKeyW( [MarshalAs( UnmanagedType.LPWStr )] string lang ); + + /// <summary>Try to resolve language code string like <c>"en"</c>, <c>"pl"</c> or <c>"uk"</c> into the strongly-typed enum.</summary> + /// <remarks>The function is case-sensitive, <c>"EN"</c> or <c>"UK"</c> will fail to resolve.</remarks> + /// <returns>The language, or null when the native function returned <c>uint.MaxValue</c> for the string</returns> + public static eLanguage? languageFromCode( string lang ) + { + uint key = findLanguageKeyW( lang ); + if( key != uint.MaxValue ) + return (eLanguage)key; + return null; + } + + /// <summary>Set up delegate to receive log messages from the C++ library</summary> + public static void setLogSink( eLogLevel lvl, eLoggerFlags flags = eLoggerFlags.SkipFormatMessage, pfnLogMessage? pfn = null ) + { + NativeLogger.setup( lvl, flags, pfn ); + } + } +}
\ No newline at end of file diff --git a/WhisperNet/Readme.txt b/WhisperNet/Readme.txt new file mode 100644 index 0000000..50b3bb3 --- /dev/null +++ b/WhisperNet/Readme.txt @@ -0,0 +1 @@ +This project builds a .NET DLL which wraps Whisper.dll into an idiomatic C# API.
\ No newline at end of file diff --git a/WhisperNet/WhisperNet.csproj b/WhisperNet/WhisperNet.csproj new file mode 100644 index 0000000..105aa44 --- /dev/null +++ b/WhisperNet/WhisperNet.csproj @@ -0,0 +1,24 @@ +<Project Sdk="Microsoft.NET.Sdk"> + <PropertyGroup> + <TargetFramework>net6.0-windows</TargetFramework> + <ImplicitUsings>enable</ImplicitUsings> + <Nullable>enable</Nullable> + <CheckForOverflowUnderflow>true</CheckForOverflowUnderflow> + <AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath> + <GenerateDocumentationFile>True</GenerateDocumentationFile> + <AllowUnsafeBlocks>True</AllowUnsafeBlocks> + <RootNamespace>Whisper</RootNamespace> + <GenerateAssemblyInfo>false</GenerateAssemblyInfo> + <PlatformTarget>x64</PlatformTarget> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)'=='Release'"> + <GeneratePackageOnBuild>True</GeneratePackageOnBuild> + <NuspecFile>WhisperNet.nuspec</NuspecFile> + </PropertyGroup> + <ItemGroup> <!-- This item has no Condition: the native DLL is taken from the x64 Release output for every configuration of this project --> + <Content Include="..\x64\Release\Whisper.dll" Link="Whisper.dll" /> + </ItemGroup> + <ItemGroup> + <PackageReference Include="ComLightInterop" Version="1.3.7" /> + </ItemGroup> +</Project>
\ No newline at end of file diff --git a/WhisperNet/WhisperNet.nuspec b/WhisperNet/WhisperNet.nuspec new file mode 100644 index 0000000..d0a61f7 --- /dev/null +++ b/WhisperNet/WhisperNet.nuspec @@ -0,0 +1,27 @@ +<?xml version="1.0" encoding="utf-8"?> +<package xmlns="http://schemas.microsoft.com/packaging/2013/05/nuspec.xsd"> + <metadata> + <id>WhisperNet</id> + <version>1.0</version> + <authors>Konstantin, const.me</authors> + <license type="expression">MPL-2.0</license> + <projectUrl>https://github.com/Const-me/Whisper</projectUrl> + <description>High-performance GPGPU inference of OpenAI's Whisper automatic speech recognition (ASR) model</description> + <releaseNotes>Initial public version</releaseNotes> + <copyright>Copyright © const.me, 2022-2023</copyright> + <tags>whisper, gpgpu, speech recognition</tags> + <repository type="git" url="https://github.com/Const-me/Whisper.git" /> + <dependencies> + <group targetFramework="net6.0"> + <dependency id="ComLightInterop" version="1.3.7" /> <!-- Same version as the ComLightInterop PackageReference in WhisperNet.csproj; keep the two in sync --> + </group> + </dependencies> + </metadata> + <files> + <!-- Managed DLL with XML documentation --> + <file src="bin/Release/WhisperNet.dll" target="lib/net6.0/" /> + <file src="bin/Release/WhisperNet.xml" target="lib/net6.0/" /> + <!-- The C++ DLL --> + <file src="../x64/Release/Whisper.dll" target="runtimes/win-x64/native/" /> + </files> +</package>
\ No newline at end of file |
