summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-16 14:52:43 +0100
committerKonstantin <const@const.me>2023-01-16 14:52:43 +0100
commit8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree714dc6fc9a1672d5fd7f89676b97e10959662abc
parent990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
-rw-r--r--.gitignore17
-rw-r--r--ComLightLib/ComLightLib.vcxproj116
-rw-r--r--ComLightLib/ComLightLib.vcxproj.filters28
-rw-r--r--ComLightLib/Exception.hpp20
-rw-r--r--ComLightLib/Readme.txt3
-rw-r--r--ComLightLib/client/CComPtr.hpp110
-rw-r--r--ComLightLib/comLightClient.h23
-rw-r--r--ComLightLib/comLightCommon.h11
-rw-r--r--ComLightLib/comLightServer.h15
-rw-r--r--ComLightLib/hresult.h26
-rw-r--r--ComLightLib/pal/guiddef.h21
-rw-r--r--ComLightLib/pal/hresult.h101
-rw-r--r--ComLightLib/server/Object.hpp139
-rw-r--r--ComLightLib/server/ObjectRoot.hpp51
-rw-r--r--ComLightLib/server/RefCounter.hpp38
-rw-r--r--ComLightLib/server/freeThreadedMarshaller.cpp17
-rw-r--r--ComLightLib/server/freeThreadedMarshaller.h29
-rw-r--r--ComLightLib/server/interfaceMap.h31
-rw-r--r--ComLightLib/streams.h61
-rw-r--r--ComLightLib/unknwn.h36
-rw-r--r--ComLightLib/utils/guid_parse.hpp103
-rw-r--r--ComLightLib/utils/typeTraits.hpp43
-rw-r--r--ComputeShaders/ComputeShaders.cpp3
-rw-r--r--ComputeShaders/ComputeShaders.vcxproj221
-rw-r--r--ComputeShaders/ComputeShaders.vcxproj.filters66
-rw-r--r--ComputeShaders/Readme.txt11
-rw-r--r--ComputeShaders/add.hlsl6
-rw-r--r--ComputeShaders/addInPlace.hlsl39
-rw-r--r--ComputeShaders/addRepeat.hlsl70
-rw-r--r--ComputeShaders/addRepeat64.hlsl2
-rw-r--r--ComputeShaders/addRepeatGelu.hlsl88
-rw-r--r--ComputeShaders/addRepeatGelu64.hlsl2
-rw-r--r--ComputeShaders/addRepeatScale.hlsl73
-rw-r--r--ComputeShaders/addRows.hlsl46
-rw-r--r--ComputeShaders/componentwiseBinaryOp.hlsli43
-rw-r--r--ComputeShaders/convolutionMain.hlsl76
-rw-r--r--ComputeShaders/convolutionMain2.hlsl60
-rw-r--r--ComputeShaders/convolutionMain2Fixed.hlsl119
-rw-r--r--ComputeShaders/convolutionPrep1.hlsl39
-rw-r--r--ComputeShaders/convolutionPrep2.hlsl43
-rw-r--r--ComputeShaders/copyConvert.hlsl50
-rw-r--r--ComputeShaders/copyTranspose.hlsl56
-rw-r--r--ComputeShaders/diagMaskInf.hlsl30
-rw-r--r--ComputeShaders/flashAttention.hlsl170
-rw-r--r--ComputeShaders/flashAttentionCommon.hlsli67
-rw-r--r--ComputeShaders/flashAttentionCompat1.hlsl125
-rw-r--r--ComputeShaders/flashAttentionCompat2.hlsl114
-rw-r--r--ComputeShaders/flashAttentionCompat3.hlsl118
-rw-r--r--ComputeShaders/fmaRepeat1.hlsl77
-rw-r--r--ComputeShaders/fmaRepeat164.hlsl2
-rw-r--r--ComputeShaders/fmaRepeat2.hlsl45
-rw-r--r--ComputeShaders/fp64Utils.hlsli28
-rw-r--r--ComputeShaders/groupReduce.hlsli139
-rw-r--r--ComputeShaders/groupReduce64.hlsli46
-rw-r--r--ComputeShaders/matReshapePanels.hlsl105
-rw-r--r--ComputeShaders/miscUtils.hlsli84
-rw-r--r--ComputeShaders/mulMatByRow.hlsl49
-rw-r--r--ComputeShaders/mulMatByRow64.hlsl90
-rw-r--r--ComputeShaders/mulMatByRowTiled.hlsl120
-rw-r--r--ComputeShaders/mulMatByRowTiled64.hlsl4
-rw-r--r--ComputeShaders/mulMatByRowTiledEx.hlsl156
-rw-r--r--ComputeShaders/mulMatByScalar.hlsl41
-rw-r--r--ComputeShaders/mulMatDotMain.hlsl95
-rw-r--r--ComputeShaders/mulMatDotReshape.hlsl33
-rw-r--r--ComputeShaders/mulMatMadMain.hlsl154
-rw-r--r--ComputeShaders/mulMatTiled.hlsl236
-rw-r--r--ComputeShaders/mulMatTiled64.hlsl3
-rw-r--r--ComputeShaders/mulMatTiledEx.hlsl194
-rw-r--r--ComputeShaders/norm.hlsl86
-rw-r--r--ComputeShaders/normCompat.hlsl82
-rw-r--r--ComputeShaders/normFixed.hlsl124
-rw-r--r--ComputeShaders/normFixed64.hlsl2
-rw-r--r--ComputeShaders/repeatUtils.hlsli21
-rw-r--r--ComputeShaders/scaleInPlace.hlsl23
-rw-r--r--ComputeShaders/softMax.hlsl71
-rw-r--r--ComputeShaders/softMax64.hlsl71
-rw-r--r--ComputeShaders/softMaxCompat.hlsl62
-rw-r--r--ComputeShaders/softMaxFixed.hlsl79
-rw-r--r--ComputeShaders/zeroMemory.hlsl27
-rw-r--r--Examples/MicrophoneCS/CaptureThread.cs61
-rw-r--r--Examples/MicrophoneCS/CommandLineArgs.cs145
-rw-r--r--Examples/MicrophoneCS/MicrophoneCS.cs56
-rw-r--r--Examples/MicrophoneCS/MicrophoneCS.csproj27
-rw-r--r--Examples/MicrophoneCS/TranscribeCallbacks.cs114
-rw-r--r--Examples/OldMain/OldMain.vcxproj101
-rw-r--r--Examples/OldMain/OldMain.vcxproj.filters14
-rw-r--r--Examples/OldMain/dr_wav.h6434
-rw-r--r--Examples/OldMain/main.cpp684
-rw-r--r--Examples/TranscribeCS/AnsiCodes.cs68
-rw-r--r--Examples/TranscribeCS/CommandLineArgs.cs155
-rw-r--r--Examples/TranscribeCS/Transcribe.cs114
-rw-r--r--Examples/TranscribeCS/TranscribeCS.cs102
-rw-r--r--Examples/TranscribeCS/TranscribeCS.csproj19
-rw-r--r--Examples/WhisperDesktop/AppState.cpp192
-rw-r--r--Examples/WhisperDesktop/AppState.h51
-rw-r--r--Examples/WhisperDesktop/CaptureDlg.cpp505
-rw-r--r--Examples/WhisperDesktop/CaptureDlg.h143
-rw-r--r--Examples/WhisperDesktop/CircleIndicator.cpp118
-rw-r--r--Examples/WhisperDesktop/CircleIndicator.h36
-rw-r--r--Examples/WhisperDesktop/LoadModelDlg.cpp206
-rw-r--r--Examples/WhisperDesktop/LoadModelDlg.h69
-rw-r--r--Examples/WhisperDesktop/Resource.h61
-rw-r--r--Examples/WhisperDesktop/TranscribeDlg.cpp493
-rw-r--r--Examples/WhisperDesktop/TranscribeDlg.h124
-rw-r--r--Examples/WhisperDesktop/Utils/DebugConsole.cpp289
-rw-r--r--Examples/WhisperDesktop/Utils/DebugConsole.h64
-rw-r--r--Examples/WhisperDesktop/Utils/LanguageDropdown.cpp87
-rw-r--r--Examples/WhisperDesktop/Utils/LanguageDropdown.h26
-rw-r--r--Examples/WhisperDesktop/Utils/PendingState.cpp40
-rw-r--r--Examples/WhisperDesktop/Utils/PendingState.h12
-rw-r--r--Examples/WhisperDesktop/Utils/TranslateCheckbox.cpp25
-rw-r--r--Examples/WhisperDesktop/Utils/TranslateCheckbox.h18
-rw-r--r--Examples/WhisperDesktop/Utils/WTL/atlapp.h1225
-rw-r--r--Examples/WhisperDesktop/Utils/WTL/atlcrack.h2480
-rw-r--r--Examples/WhisperDesktop/Utils/WTL/atlctrls.h9764
-rw-r--r--Examples/WhisperDesktop/Utils/WTL/atlddx.h667
-rw-r--r--Examples/WhisperDesktop/Utils/WTL/atlgdi.h3445
-rw-r--r--Examples/WhisperDesktop/Utils/WTL/atlres.h259
-rw-r--r--Examples/WhisperDesktop/Utils/WTL/atluser.h1231
-rw-r--r--Examples/WhisperDesktop/Utils/WTL/atlwinx.h623
-rw-r--r--Examples/WhisperDesktop/Utils/logger.cpp71
-rw-r--r--Examples/WhisperDesktop/Utils/logger.h36
-rw-r--r--Examples/WhisperDesktop/Utils/miscUtils.cpp254
-rw-r--r--Examples/WhisperDesktop/Utils/miscUtils.h72
-rw-r--r--Examples/WhisperDesktop/WhisperDesktop.cpp63
-rw-r--r--Examples/WhisperDesktop/WhisperDesktop.manifest16
-rw-r--r--Examples/WhisperDesktop/WhisperDesktop.rcbin0 -> 16564 bytes
-rw-r--r--Examples/WhisperDesktop/WhisperDesktop.vcxproj151
-rw-r--r--Examples/WhisperDesktop/WhisperDesktop.vcxproj.filters142
-rw-r--r--Examples/WhisperDesktop/framework.h22
-rw-r--r--Examples/WhisperDesktop/stdafx.cpp1
-rw-r--r--Examples/WhisperDesktop/stdafx.h8
-rw-r--r--Examples/WhisperDesktop/sunflower.icobin0 -> 102989 bytes
-rw-r--r--Examples/WhisperDesktop/targetver.h6
-rw-r--r--Examples/main/main.cpp315
-rw-r--r--Examples/main/main.vcxproj93
-rw-r--r--Examples/main/main.vcxproj.filters12
-rw-r--r--Examples/main/miscUtils.cpp48
-rw-r--r--Examples/main/miscUtils.h9
-rw-r--r--Examples/main/params.cpp101
-rw-r--r--Examples/main/params.h38
-rw-r--r--Tools/CompressShaders/Cabinet.cs60
-rw-r--r--Tools/CompressShaders/CompressShaders.cs244
-rw-r--r--Tools/CompressShaders/CompressShaders.csproj10
-rw-r--r--Tools/CompressShaders/DetectFp64.cs43
-rw-r--r--Tools/CompressShaders/LanguageCodes.cs103
-rw-r--r--Tools/CompressShaders/Readme.txt10
-rw-r--r--Tools/CompressShaders/ShaderNames.cs27
-rw-r--r--Tools/compareTraces/CommandLineArgs.cpp51
-rw-r--r--Tools/compareTraces/CommandLineArgs.h9
-rw-r--r--Tools/compareTraces/Readme.txt9
-rw-r--r--Tools/compareTraces/TraceReader.cpp46
-rw-r--r--Tools/compareTraces/TraceReader.h35
-rw-r--r--Tools/compareTraces/compare.cpp364
-rw-r--r--Tools/compareTraces/compare.h4
-rw-r--r--Tools/compareTraces/compareTraces.cpp16
-rw-r--r--Tools/compareTraces/compareTraces.vcxproj103
-rw-r--r--Tools/compareTraces/compareTraces.vcxproj.filters20
-rw-r--r--Tools/compareTraces/stdafx.cpp30
-rw-r--r--Tools/compareTraces/stdafx.h40
-rw-r--r--Tools/compareTraces/testUtils.cpp224
-rw-r--r--Whisper/API/MfStructs.h51
-rw-r--r--Whisper/API/Readme.txt15
-rw-r--r--Whisper/API/SpecialTokens.h25
-rw-r--r--Whisper/API/TranscribeStructs.h127
-rw-r--r--Whisper/API/iContext.cl.h66
-rw-r--r--Whisper/API/iContext.h61
-rw-r--r--Whisper/API/iMediaFoundation.cl.h48
-rw-r--r--Whisper/API/iMediaFoundation.h39
-rw-r--r--Whisper/API/iTranscribeResult.cl.h15
-rw-r--r--Whisper/API/iTranscribeResult.h12
-rw-r--r--Whisper/API/loggerApi.h35
-rw-r--r--Whisper/API/sFullParams.h136
-rw-r--r--Whisper/API/sLanguageList.h18
-rw-r--r--Whisper/API/sLoadModelCallbacks.h14
-rw-r--r--Whisper/API/whisperComLight.h4
-rw-r--r--Whisper/API/whisperWindows.h4
-rw-r--r--Whisper/CPU/BufferAllocator.cpp145
-rw-r--r--Whisper/CPU/BufferAllocator.h64
-rw-r--r--Whisper/CPU/DecoderTensors.cpp68
-rw-r--r--Whisper/CPU/DecoderTensors.h131
-rw-r--r--Whisper/CPU/HybridLoader.cpp140
-rw-r--r--Whisper/CPU/HybridLoader.h37
-rw-r--r--Whisper/CPU/KvTensors.h36
-rw-r--r--Whisper/CPU/KvTensorsCpu.cpp19
-rw-r--r--Whisper/CPU/LargeBuffer.cpp34
-rw-r--r--Whisper/CPU/LargeBuffer.h44
-rw-r--r--Whisper/CPU/MlContext.h71
-rw-r--r--Whisper/CPU/MlContextCpu.cpp597
-rw-r--r--Whisper/CPU/ParallelForRunner.cpp149
-rw-r--r--Whisper/CPU/ParallelForRunner.h52
-rw-r--r--Whisper/CPU/Readme.txt1
-rw-r--r--Whisper/CPU/Tensor.h139
-rw-r--r--Whisper/CPU/TensorCpu.cpp401
-rw-r--r--Whisper/CPU/mulMat.cpp54
-rw-r--r--Whisper/CPU/mulMat.h17
-rw-r--r--Whisper/CPU/mulMat.kernel.hpp742
-rw-r--r--Whisper/CPU/mulMatImpl.avx2.cpp362
-rw-r--r--Whisper/CPU/mulMatImpl.cpp213
-rw-r--r--Whisper/CPU/mulMatImpl.h106
-rw-r--r--Whisper/CPU/mulMatImpl.panel.cpp274
-rw-r--r--Whisper/CPU/mulMatUtils.hpp301
-rw-r--r--Whisper/CPU/simdUtils.cpp738
-rw-r--r--Whisper/CPU/simdUtils.h82
-rw-r--r--Whisper/D3D/Binder.cpp63
-rw-r--r--Whisper/D3D/Binder.h21
-rw-r--r--Whisper/D3D/MappedResource.cpp33
-rw-r--r--Whisper/D3D/MappedResource.h22
-rw-r--r--Whisper/D3D/RenderDoc/renderDoc.cpp72
-rw-r--r--Whisper/D3D/RenderDoc/renderDoc.h15
-rw-r--r--Whisper/D3D/RenderDoc/renderdoc_app.h724
-rw-r--r--Whisper/D3D/createBuffer.cpp51
-rw-r--r--Whisper/D3D/createBuffer.h8
-rw-r--r--Whisper/D3D/device.cpp120
-rw-r--r--Whisper/D3D/device.h66
-rw-r--r--Whisper/D3D/downloadBuffer.cpp72
-rw-r--r--Whisper/D3D/downloadBuffer.h9
-rw-r--r--Whisper/D3D/enums.cpp9
-rw-r--r--Whisper/D3D/enums.h34
-rw-r--r--Whisper/D3D/shaderNames.cpp53
-rw-r--r--Whisper/D3D/shaderNames.h50
-rw-r--r--Whisper/D3D/shaders.cpp104
-rw-r--r--Whisper/D3D/shaders.h7
-rw-r--r--Whisper/D3D/startup.cpp17
-rw-r--r--Whisper/D3D/startup.h11
-rw-r--r--Whisper/DllMain.cpp27
-rw-r--r--Whisper/Hybrid/HybridContext.cpp349
-rw-r--r--Whisper/Hybrid/HybridContext.h52
-rw-r--r--Whisper/Hybrid/KeyValueDownloader.cpp32
-rw-r--r--Whisper/Hybrid/KeyValueDownloader.h63
-rw-r--r--Whisper/Hybrid/Readme.txt1
-rw-r--r--Whisper/MF/AudioBuffer.cpp93
-rw-r--r--Whisper/MF/AudioBuffer.h41
-rw-r--r--Whisper/MF/AudioCapture.cpp167
-rw-r--r--Whisper/MF/AudioCapture.h12
-rw-r--r--Whisper/MF/MediaFoundation.cpp109
-rw-r--r--Whisper/MF/PcmReader.cpp274
-rw-r--r--Whisper/MF/PcmReader.h63
-rw-r--r--Whisper/MF/loadAudioFile.cpp151
-rw-r--r--Whisper/MF/loadAudioFile.h7
-rw-r--r--Whisper/MF/mfStartup.cpp128
-rw-r--r--Whisper/MF/mfStartup.h15
-rw-r--r--Whisper/MF/mfUtils.cpp69
-rw-r--r--Whisper/MF/mfUtils.h15
-rw-r--r--Whisper/ML/ConstantBuffer.cpp63
-rw-r--r--Whisper/ML/ConstantBuffer.h25
-rw-r--r--Whisper/ML/Context.ops.cpp280
-rw-r--r--Whisper/ML/LookupTables.cpp54
-rw-r--r--Whisper/ML/LookupTables.h22
-rw-r--r--Whisper/ML/LookupTablesData.cpp40
-rw-r--r--Whisper/ML/LookupTablesData.h14
-rw-r--r--Whisper/ML/MlContext.cpp744
-rw-r--r--Whisper/ML/MlContext.dbg.cpp59
-rw-r--r--Whisper/ML/MlContext.h111
-rw-r--r--Whisper/ML/Reshaper.cpp80
-rw-r--r--Whisper/ML/Reshaper.h17
-rw-r--r--Whisper/ML/TempBuffers.cpp88
-rw-r--r--Whisper/ML/TempBuffers.h48
-rw-r--r--Whisper/ML/Tensor.cpp340
-rw-r--r--Whisper/ML/Tensor.h78
-rw-r--r--Whisper/ML/TensorEx.cpp97
-rw-r--r--Whisper/ML/TensorEx.h42
-rw-r--r--Whisper/ML/TensorGpuViews.cpp23
-rw-r--r--Whisper/ML/TensorGpuViews.h32
-rw-r--r--Whisper/ML/TensorShape.cpp72
-rw-r--r--Whisper/ML/TensorShape.h120
-rw-r--r--Whisper/ML/TensorsArena.cpp117
-rw-r--r--Whisper/ML/TensorsArena.h79
-rw-r--r--Whisper/ML/mlStartup.cpp27
-rw-r--r--Whisper/ML/mlStartup.h8
-rw-r--r--Whisper/ML/reshapedMultiply.h10
-rw-r--r--Whisper/ML/tensorOpsTests.cpp183
-rw-r--r--Whisper/ML/tensorOpsTests.h15
-rw-r--r--Whisper/ML/testUtils.cpp334
-rw-r--r--Whisper/ML/testUtils.h62
-rw-r--r--Whisper/ML/testUtilsC.h10
-rw-r--r--Whisper/Readme.txt9
-rw-r--r--Whisper/Resource.rcbin0 -> 5246 bytes
-rw-r--r--Whisper/Utils/CpuProfiler.cpp65
-rw-r--r--Whisper/Utils/CpuProfiler.h26
-rw-r--r--Whisper/Utils/GpuProfiler.cpp374
-rw-r--r--Whisper/Utils/GpuProfiler.h187
-rw-r--r--Whisper/Utils/GpuProfilerSimple.h14
-rw-r--r--Whisper/Utils/Logger.cpp240
-rw-r--r--Whisper/Utils/Logger.h23
-rw-r--r--Whisper/Utils/ProfileCollection.cpp331
-rw-r--r--Whisper/Utils/ProfileCollection.h112
-rw-r--r--Whisper/Utils/ReadStream.h37
-rw-r--r--Whisper/Utils/Trace/TraceStructures.cpp31
-rw-r--r--Whisper/Utils/Trace/TraceStructures.h55
-rw-r--r--Whisper/Utils/Trace/TraceWriter.cpp263
-rw-r--r--Whisper/Utils/Trace/TraceWriter.h70
-rw-r--r--Whisper/Utils/Trace/tracing.cpp60
-rw-r--r--Whisper/Utils/Trace/tracing.h67
-rw-r--r--Whisper/Utils/miscUtils.cpp33
-rw-r--r--Whisper/Utils/miscUtils.h81
-rw-r--r--Whisper/Utils/parallelFor.cpp144
-rw-r--r--Whisper/Utils/parallelFor.h38
-rw-r--r--Whisper/Whisper.vcxproj347
-rw-r--r--Whisper/Whisper.vcxproj.filters214
-rw-r--r--Whisper/Whisper/ContextImpl.capture.cpp418
-rw-r--r--Whisper/Whisper/ContextImpl.cpp528
-rw-r--r--Whisper/Whisper/ContextImpl.h75
-rw-r--r--Whisper/Whisper/ContextImpl.misc.cpp408
-rw-r--r--Whisper/Whisper/DecoderInputBuffers.cpp66
-rw-r--r--Whisper/Whisper/DecoderInputBuffers.h29
-rw-r--r--Whisper/Whisper/DecoderResultBuffer.cpp48
-rw-r--r--Whisper/Whisper/DecoderResultBuffer.h29
-rw-r--r--Whisper/Whisper/KeyValueBuffers.cpp42
-rw-r--r--Whisper/Whisper/KeyValueBuffers.h50
-rw-r--r--Whisper/Whisper/Languages.cpp122
-rw-r--r--Whisper/Whisper/Languages.h12
-rw-r--r--Whisper/Whisper/MelInputTensor.cpp63
-rw-r--r--Whisper/Whisper/MelInputTensor.h22
-rw-r--r--Whisper/Whisper/MelStreamer.cpp493
-rw-r--r--Whisper/Whisper/MelStreamer.h99
-rw-r--r--Whisper/Whisper/ModelBuffers.cpp115
-rw-r--r--Whisper/Whisper/ModelBuffers.h114
-rw-r--r--Whisper/Whisper/ModelImpl.cpp122
-rw-r--r--Whisper/Whisper/ModelImpl.h40
-rw-r--r--Whisper/Whisper/ModelLoader.h29
-rw-r--r--Whisper/Whisper/Spectrogram.cpp124
-rw-r--r--Whisper/Whisper/Spectrogram.h42
-rw-r--r--Whisper/Whisper/TranscribeResult.h43
-rw-r--r--Whisper/Whisper/Vocabulary.cpp129
-rw-r--r--Whisper/Whisper/Vocabulary.h58
-rw-r--r--Whisper/Whisper/WhisperContext.cpp673
-rw-r--r--Whisper/Whisper/WhisperContext.h126
-rw-r--r--Whisper/Whisper/WhisperModel.cpp511
-rw-r--r--Whisper/Whisper/WhisperModel.h54
-rw-r--r--Whisper/Whisper/audioConstants.h14
-rw-r--r--Whisper/Whisper/iSpectrogram.h38
-rw-r--r--Whisper/Whisper/languageCodez.inl100
-rw-r--r--Whisper/Whisper/languageCodez.tsv99
-rw-r--r--Whisper/Whisper/loaderUtils.h24
-rw-r--r--Whisper/Whisper/melSpectrogram.cpp298
-rw-r--r--Whisper/Whisper/melSpectrogram.h34
-rw-r--r--Whisper/Whisper/sEncodeParams.h20
-rw-r--r--Whisper/Whisper/sModelParams.h19
-rw-r--r--Whisper/Whisper/sTokenData.h23
-rw-r--r--Whisper/Whisper/voiceActivityDetection.cpp199
-rw-r--r--Whisper/Whisper/voiceActivityDetection.h54
-rw-r--r--Whisper/misc.natvis50
-rw-r--r--Whisper/modelFactory.cpp19
-rw-r--r--Whisper/modelFactory.h11
-rw-r--r--Whisper/resource.h14
-rw-r--r--Whisper/source.compat/Readme.txt1
-rw-r--r--Whisper/source.compat/convertThings.cpp234
-rw-r--r--Whisper/source.compat/convertThings.h10
-rw-r--r--Whisper/source.compat/ggmlMsvc.c37
-rw-r--r--Whisper/source/LICENSE21
-rw-r--r--Whisper/source/Readme.txt1
-rw-r--r--Whisper/source/ggml.c8336
-rw-r--r--Whisper/source/ggml.h737
-rw-r--r--Whisper/source/whisper.cpp3601
-rw-r--r--Whisper/source/whisper.h330
-rw-r--r--Whisper/stdafx.cpp1
-rw-r--r--Whisper/stdafx.h43
-rw-r--r--Whisper/whisper.def7
-rw-r--r--Whisper/whisperCom.cpp1070
-rw-r--r--WhisperCpp.sln110
-rw-r--r--WhisperNet/API/CaptureDeviceId.cs24
-rw-r--r--WhisperNet/API/Parameters.cs95
-rw-r--r--WhisperNet/API/SpecialTokens.cs23
-rw-r--r--WhisperNet/API/eCaptureStatus.cs19
-rw-r--r--WhisperNet/API/eLanguage.cs206
-rw-r--r--WhisperNet/API/eLogLevel.cs34
-rw-r--r--WhisperNet/API/eModelImplementation.cs25
-rw-r--r--WhisperNet/API/eResultFlags.cs21
-rw-r--r--WhisperNet/API/iAudioBuffer.cs27
-rw-r--r--WhisperNet/API/iAudioReader.cs23
-rw-r--r--WhisperNet/API/iMediaFoundation.cs36
-rw-r--r--WhisperNet/API/iModel.cs27
-rw-r--r--WhisperNet/API/sCaptureParams.cs37
-rw-r--r--WhisperNet/Callbacks.cs44
-rw-r--r--WhisperNet/CaptureCallbacks.cs49
-rw-r--r--WhisperNet/Context.cs201
-rw-r--r--WhisperNet/ExtensionMethods.cs69
-rw-r--r--WhisperNet/Internal/AssemblyInfo.cs8
-rw-r--r--WhisperNet/Internal/NativeLogger.cs138
-rw-r--r--WhisperNet/Internal/iContext.cs37
-rw-r--r--WhisperNet/Internal/iTranscribeResult.cs122
-rw-r--r--WhisperNet/Internal/sCaptureCallbacks.cs23
-rw-r--r--WhisperNet/Internal/sCaptureDevice.cs22
-rw-r--r--WhisperNet/Internal/sFullParams.cs40
-rw-r--r--WhisperNet/Internal/sLoadModelCallbacks.cs64
-rw-r--r--WhisperNet/Internal/sLoggerSetup.cs16
-rw-r--r--WhisperNet/Internal/sProgressSink.cs20
-rw-r--r--WhisperNet/Library.cs110
-rw-r--r--WhisperNet/Readme.txt1
-rw-r--r--WhisperNet/WhisperNet.csproj24
-rw-r--r--WhisperNet/WhisperNet.nuspec27
392 files changed, 75260 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1cbaff1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,17 @@
+.vs/
+ComLightLib/x64/
+Whisper/x64/
+x64/
+Tools/CompressShaders/bin/
+Tools/CompressShaders/obj/
+Whisper/D3D/shaderData-Debug.inl
+Whisper/D3D/shaderData-Release.inl
+WhisperNet/bin/
+WhisperNet/obj/
+Examples/TranscribeCS/bin/
+Examples/TranscribeCS/obj/
+*.aps
+*.json
+*.user
+Examples/MicrophoneCS/obj/
+Examples/MicrophoneCS/bin/ \ No newline at end of file
diff --git a/ComLightLib/ComLightLib.vcxproj b/ComLightLib/ComLightLib.vcxproj
new file mode 100644
index 0000000..ee7f07b
--- /dev/null
+++ b/ComLightLib/ComLightLib.vcxproj
@@ -0,0 +1,116 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" ToolsVersion="15.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="comLightClient.h" />
+ <ClInclude Include="client\CComPtr.hpp" />
+ <ClInclude Include="comLightServer.h" />
+ <ClInclude Include="comLightCommon.h" />
+ <ClInclude Include="server\freeThreadedMarshaller.h" />
+ <ClInclude Include="hresult.h" />
+ <ClInclude Include="server\ObjectRoot.hpp" />
+ <ClInclude Include="pal\guiddef.h" />
+ <ClInclude Include="server\Object.hpp" />
+ <ClInclude Include="server\interfaceMap.h" />
+ <ClInclude Include="server\RefCounter.hpp" />
+ <ClInclude Include="Exception.hpp" />
+ <ClInclude Include="streams.h" />
+ <ClInclude Include="utils\guid_parse.hpp" />
+ <ClInclude Include="pal\hresult.h" />
+ <ClInclude Include="unknwn.h" />
+ <ClInclude Include="utils\typeTraits.hpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClCompile Include="server\freeThreadedMarshaller.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <Text Include="Readme.txt" />
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <VCProjectVersion>15.0</VCProjectVersion>
+ <ProjectGuid>{52F486E7-830C-45D8-BE47-E76B5AAB2772}</ProjectGuid>
+ <Keyword>Win32Proj</Keyword>
+ <RootNamespace>ComLightLib</RootNamespace>
+ <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="Shared">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <LinkIncremental>true</LinkIncremental>
+ <OutDir>$(Platform)\$(Configuration)\</OutDir>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <LinkIncremental>false</LinkIncremental>
+ <OutDir>$(Platform)\$(Configuration)\</OutDir>
+ </PropertyGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <PrecompiledHeader>NotUsing</PrecompiledHeader>
+ <WarningLevel>Level3</WarningLevel>
+ <Optimization>Disabled</Optimization>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>_DEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ </ClCompile>
+ <Link>
+ <SubSystem>Windows</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <PrecompiledHeader>NotUsing</PrecompiledHeader>
+ <WarningLevel>Level3</WarningLevel>
+ <Optimization>MaxSpeed</Optimization>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ </ClCompile>
+ <Link>
+ <SubSystem>Windows</SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/ComLightLib/ComLightLib.vcxproj.filters b/ComLightLib/ComLightLib.vcxproj.filters
new file mode 100644
index 0000000..aa2c81d
--- /dev/null
+++ b/ComLightLib/ComLightLib.vcxproj.filters
@@ -0,0 +1,28 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <ClInclude Include="pal\hresult.h" />
+ <ClInclude Include="pal\guiddef.h" />
+ <ClInclude Include="utils\guid_parse.hpp" />
+ <ClInclude Include="unknwn.h" />
+ <ClInclude Include="comLightClient.h" />
+ <ClInclude Include="comLightServer.h" />
+ <ClInclude Include="client\CComPtr.hpp" />
+ <ClInclude Include="comLightCommon.h" />
+ <ClInclude Include="server\RefCounter.hpp" />
+ <ClInclude Include="server\interfaceMap.h" />
+ <ClInclude Include="server\Object.hpp" />
+ <ClInclude Include="utils\typeTraits.hpp" />
+ <ClInclude Include="server\freeThreadedMarshaller.h" />
+ <ClInclude Include="hresult.h" />
+ <ClInclude Include="server\ObjectRoot.hpp" />
+ <ClInclude Include="Exception.hpp" />
+ <ClInclude Include="streams.h" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClCompile Include="server\freeThreadedMarshaller.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <Text Include="Readme.txt" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/ComLightLib/Exception.hpp b/ComLightLib/Exception.hpp
new file mode 100644
index 0000000..57c1b78
--- /dev/null
+++ b/ComLightLib/Exception.hpp
@@ -0,0 +1,20 @@
+#pragma once
+
+namespace ComLight
+{
+ class Exception : public std::runtime_error
+ {
+ // I don't like C++ exceptions too much, but for some cases they are useful.
+ // You can throw ComLight::Exception from constructor, or from FinalConstruct() method, the library will catch & return the code from the class factory function.
+ // Unfortunately, for interface methods this doesn't work, the C++ parts of the library can't catch them without very complex trickery like code generation.
+ // You can still use this class in methods, but you'll need to catch them manually near the API boundary or the app will crash.
+ // C++ doesn't have an ABI, the framework can't catch C++ exception across the modules.
+ const HRESULT m_code;
+
+ public:
+
+ Exception( HRESULT hr ) : runtime_error( "ComLight HRESULT exception" ), m_code( hr ) { }
+
+ HRESULT code() const { return m_code; }
+ };
+} \ No newline at end of file
diff --git a/ComLightLib/Readme.txt b/ComLightLib/Readme.txt
new file mode 100644
index 0000000..1eabeec
--- /dev/null
+++ b/ComLightLib/Readme.txt
@@ -0,0 +1,3 @@
+Copy-pasted from there:
+https://github.com/Const-me/ComLightInterop/tree/master/ComLightLib
+With only a few minor changes. \ No newline at end of file
diff --git a/ComLightLib/client/CComPtr.hpp b/ComLightLib/client/CComPtr.hpp
new file mode 100644
index 0000000..7786591
--- /dev/null
+++ b/ComLightLib/client/CComPtr.hpp
@@ -0,0 +1,110 @@
+#pragma once
+
+namespace ComLight
+{
+ // COM smart pointer, very comparable to CComPtr from ATL
+ template <class I>
+ class CComPtr
+ {
+ I* p;
+
+ void callAddRef() const
+ {
+ if( nullptr == p )
+ return;
+ p->AddRef();
+ }
+
+ public:
+
+ // Construct with nullptr
+ CComPtr() : p( nullptr ) { }
+
+ // Release the pointer
+ void release()
+ {
+ if( nullptr == p )
+ return;
+ p->Release();
+ p = nullptr;
+ }
+
+ ~CComPtr()
+ {
+ release();
+ }
+
+ // Attach without AddRef()
+ void attach( I* raw )
+ {
+ release();
+ p = raw;
+ }
+
+ // Detach without Release(), set this pointer to nullptr
+ I* detach()
+ {
+ I* const result = p;
+ p = nullptr;
+ return result;
+ }
+
+ // Detach without Release() and place to the specified address, set this pointer to nullptr
+ template<class Other>
+ void detach( Other** pp )
+ {
+ // If the argument points to a non-empty object, release the old instance: would leak memory otherwise.
+ if( nullptr != *pp )
+ ( *pp )->Release();
+ ( *pp ) = detach();
+ }
+
+ // Set and AddRef()
+ void assign( I* raw )
+ {
+ release();
+ attach( raw );
+ callAddRef();
+ }
+
+ void swap( CComPtr<I>& that )
+ {
+ std::swap( p, that.p );
+ }
+
+ // Set and AddRef()
+ CComPtr( I* raw ) : p( raw )
+ {
+ callAddRef();
+ }
+
+ // Set and AddRef()
+ CComPtr( const CComPtr<I>& that ) : CComPtr( that.p ) { }
+ // Move constructor
+ CComPtr( CComPtr<I>&& that ) : p( that.p ) { that.p = nullptr; }
+
+ // Set and AddRef()
+ void operator=( I* raw )
+ {
+ assign( raw );
+ }
+
+ // Set and AddRef()
+ void operator=( const CComPtr<I>& that )
+ {
+ assign( that.p );
+ }
+
+ // Move assignment operator, destroys the other one
+ void operator=( CComPtr<I>&& that )
+ {
+ attach( that.detach() );
+ }
+
+ operator I*( ) const { return p; }
+ I* operator -> () const { return p; }
+ I** operator &() { return &p; }
+
+ operator bool() const { return nullptr != p; }
+ };
+} \ No newline at end of file
diff --git a/ComLightLib/comLightClient.h b/ComLightLib/comLightClient.h
new file mode 100644
index 0000000..3174c92
--- /dev/null
+++ b/ComLightLib/comLightClient.h
@@ -0,0 +1,23 @@
+#pragma once
+#include "comLightCommon.h"
+#include "client/CComPtr.hpp"
+#include "utils/typeTraits.hpp"
+
+namespace ComLight
+{
+ namespace details
+ {
+ template<typename T>
+ inline constexpr void** castDoublePointerToVoid( T** pp )
+ {
+ static_assert( pointersAssignable<IUnknown, T>(), "IID_PPV_ARGS macro should be used with IUnknown interfaces" );
+ return reinterpret_cast<void**>( pp );
+ }
+ }
+}
+
+#ifdef IID_PPV_ARGS
+#undef IID_PPV_ARGS
+#endif
+
+#define IID_PPV_ARGS( pp ) decltype( **pp )::iid, ::ComLight::details::castDoublePointerToVoid( pp ) \ No newline at end of file
diff --git a/ComLightLib/comLightCommon.h b/ComLightLib/comLightCommon.h
new file mode 100644
index 0000000..c571910
--- /dev/null
+++ b/ComLightLib/comLightCommon.h
@@ -0,0 +1,11 @@
+#pragma once
+#include "hresult.h"
+
+#ifdef _MSC_VER
+#include <guiddef.h>
+#else
+#include "pal/guiddef.h"
+using LPCTSTR = const char*;
+#endif
+
+#include "unknwn.h" \ No newline at end of file
diff --git a/ComLightLib/comLightServer.h b/ComLightLib/comLightServer.h
new file mode 100644
index 0000000..8b2e844
--- /dev/null
+++ b/ComLightLib/comLightServer.h
@@ -0,0 +1,15 @@
+#pragma once
+#include "comLightCommon.h"
+#include "client/CComPtr.hpp"
+
+#include "server/ObjectRoot.hpp"
+#include "server/interfaceMap.h"
+#include "server/Object.hpp"
+#include "server/freeThreadedMarshaller.h"
+
+#ifdef _MSC_VER
+// On Windows, it's controlled by library.def module definition file. There's __declspec(dllexport), but it adds underscore, I don't like that.
+#define DLLEXPORT extern "C"
+#else
+#define DLLEXPORT extern "C" __attribute__((visibility("default")))
+#endif \ No newline at end of file
diff --git a/ComLightLib/hresult.h b/ComLightLib/hresult.h
new file mode 100644
index 0000000..dbaee67
--- /dev/null
+++ b/ComLightLib/hresult.h
@@ -0,0 +1,26 @@
+#pragma once
+#include <stdint.h>
+#ifdef _MSC_VER
+#include <winerror.h>
+#include <OleCtl.h>
+#else
+#include "pal/hresult.h"
+#endif
+
+#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; }
+
+#ifndef _MSC_VER
+inline constexpr HRESULT HRESULT_FROM_WIN32( int c )
+{
+ return c < 0 ? c : ( ( 0xFFFF & c ) | 0x80070000 );
+}
+
+constexpr HRESULT OLE_E_BLANK = _HRESULT_TYPEDEF_( 0x80040007 );
+constexpr HRESULT E_BOUNDS = _HRESULT_TYPEDEF_( 0x8000000BL );
+
+constexpr int ERROR_HANDLE_EOF = 38;
+constexpr int ERROR_ALREADY_INITIALIZED = 1247;
+#endif
+
+constexpr HRESULT E_EOF = HRESULT_FROM_WIN32( ERROR_HANDLE_EOF );
+constexpr HRESULT E_ALREADY_INITIALIZED = HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); \ No newline at end of file
diff --git a/ComLightLib/pal/guiddef.h b/ComLightLib/pal/guiddef.h
new file mode 100644
index 0000000..ed8259f
--- /dev/null
+++ b/ComLightLib/pal/guiddef.h
@@ -0,0 +1,21 @@
+#pragma once
+#include <stdint.h>
+#include <array>
+#ifndef GUID_DEFINED
+#define GUID_DEFINED
+#endif
+
+struct GUID
+{
+ uint32_t Data1;
+ uint16_t Data2;
+ uint16_t Data3;
+ std::array<uint8_t, 8> Data4;
+
+ constexpr inline bool operator==( const GUID& that ) const
+ {
+ return Data1 == that.Data1 && Data2 == that.Data2 && Data3 == that.Data3 && Data4 == that.Data4;
+ }
+};
+
+using REFIID = const GUID&; \ No newline at end of file
diff --git a/ComLightLib/pal/hresult.h b/ComLightLib/pal/hresult.h
new file mode 100644
index 0000000..c261458
--- /dev/null
+++ b/ComLightLib/pal/hresult.h
@@ -0,0 +1,101 @@
+#pragma once
+#include <stdint.h>
+using HRESULT = int32_t;
+#define _HRESULT_TYPEDEF_(_sc) ((HRESULT)_sc)
+#define SEVERITY_ERROR 1
+#define FACILITY_CONTROL 10
+
+inline constexpr HRESULT MAKE_SCODE( uint32_t sev, uint32_t fac, uint32_t code )
+{
+ return (HRESULT)( ( (uint32_t)( sev ) << 31 ) | ( (unsigned long)( fac ) << 16 ) | ( (unsigned long)( code ) ) );
+};
+
+// ==== Copy-pasted from coreclr-master\src\pal\inc\rt\palrt.h ====
+#define S_OK _HRESULT_TYPEDEF_(0x00000000L)
+#define S_FALSE _HRESULT_TYPEDEF_(0x00000001L)
+
+#define E_NOTIMPL _HRESULT_TYPEDEF_(0x80004001L)
+#define E_NOINTERFACE _HRESULT_TYPEDEF_(0x80004002L)
+#define E_UNEXPECTED _HRESULT_TYPEDEF_(0x8000FFFFL)
+#define E_OUTOFMEMORY _HRESULT_TYPEDEF_(0x8007000EL)
+#define E_INVALIDARG _HRESULT_TYPEDEF_(0x80070057L)
+#define E_POINTER _HRESULT_TYPEDEF_(0x80004003L)
+#define E_HANDLE _HRESULT_TYPEDEF_(0x80070006L)
+#define E_ABORT _HRESULT_TYPEDEF_(0x80004004L)
+#define E_FAIL _HRESULT_TYPEDEF_(0x80004005L)
+#define E_ACCESSDENIED _HRESULT_TYPEDEF_(0x80070005L)
+#define E_PENDING _HRESULT_TYPEDEF_(0x8000000AL)
+
+#define DISP_E_PARAMNOTFOUND _HRESULT_TYPEDEF_(0x80020004L)
+#define DISP_E_TYPEMISMATCH _HRESULT_TYPEDEF_(0x80020005L)
+#define DISP_E_BADVARTYPE _HRESULT_TYPEDEF_(0x80020008L)
+#define DISP_E_OVERFLOW _HRESULT_TYPEDEF_(0x8002000AL)
+#define DISP_E_DIVBYZERO _HRESULT_TYPEDEF_(0x80020012L)
+
+#define CLASS_E_CLASSNOTAVAILABLE _HRESULT_TYPEDEF_(0x80040111L)
+#define CLASS_E_NOAGGREGATION _HRESULT_TYPEDEF_(0x80040110L)
+
+#define CO_E_CLASSSTRING _HRESULT_TYPEDEF_(0x800401F3L)
+
+#define MK_E_SYNTAX _HRESULT_TYPEDEF_(0x800401E4L)
+
+#define STG_E_INVALIDFUNCTION _HRESULT_TYPEDEF_(0x80030001L)
+#define STG_E_FILENOTFOUND _HRESULT_TYPEDEF_(0x80030002L)
+#define STG_E_PATHNOTFOUND _HRESULT_TYPEDEF_(0x80030003L)
+#define STG_E_WRITEFAULT _HRESULT_TYPEDEF_(0x8003001DL)
+#define STG_E_FILEALREADYEXISTS _HRESULT_TYPEDEF_(0x80030050L)
+#define STG_E_ABNORMALAPIEXIT _HRESULT_TYPEDEF_(0x800300FAL)
+
+#define NTE_BAD_UID _HRESULT_TYPEDEF_(0x80090001L)
+#define NTE_BAD_HASH _HRESULT_TYPEDEF_(0x80090002L)
+#define NTE_BAD_KEY _HRESULT_TYPEDEF_(0x80090003L)
+#define NTE_BAD_LEN _HRESULT_TYPEDEF_(0x80090004L)
+#define NTE_BAD_DATA _HRESULT_TYPEDEF_(0x80090005L)
+#define NTE_BAD_SIGNATURE _HRESULT_TYPEDEF_(0x80090006L)
+#define NTE_BAD_VER _HRESULT_TYPEDEF_(0x80090007L)
+#define NTE_BAD_ALGID _HRESULT_TYPEDEF_(0x80090008L)
+#define NTE_BAD_FLAGS _HRESULT_TYPEDEF_(0x80090009L)
+#define NTE_BAD_TYPE _HRESULT_TYPEDEF_(0x8009000AL)
+#define NTE_BAD_KEY_STATE _HRESULT_TYPEDEF_(0x8009000BL)
+#define NTE_BAD_HASH_STATE _HRESULT_TYPEDEF_(0x8009000CL)
+#define NTE_NO_KEY _HRESULT_TYPEDEF_(0x8009000DL)
+#define NTE_NO_MEMORY _HRESULT_TYPEDEF_(0x8009000EL)
+#define NTE_SIGNATURE_FILE_BAD _HRESULT_TYPEDEF_(0x8009001CL)
+#define NTE_FAIL _HRESULT_TYPEDEF_(0x80090020L)
+
+#define CRYPT_E_HASH_VALUE _HRESULT_TYPEDEF_(0x80091007L)
+
+#define TYPE_E_SIZETOOBIG _HRESULT_TYPEDEF_(0x800288C5L)
+#define TYPE_E_DUPLICATEID _HRESULT_TYPEDEF_(0x800288C6L)
+
+#define STD_CTL_SCODE(n) MAKE_SCODE(SEVERITY_ERROR, FACILITY_CONTROL, n)
+#define CTL_E_OVERFLOW STD_CTL_SCODE(6)
+#define CTL_E_OUTOFMEMORY STD_CTL_SCODE(7)
+#define CTL_E_DIVISIONBYZERO STD_CTL_SCODE(11)
+#define CTL_E_OUTOFSTACKSPACE STD_CTL_SCODE(28)
+#define CTL_E_FILENOTFOUND STD_CTL_SCODE(53)
+#define CTL_E_DEVICEIOERROR STD_CTL_SCODE(57)
+#define CTL_E_PERMISSIONDENIED STD_CTL_SCODE(70)
+#define CTL_E_PATHFILEACCESSERROR STD_CTL_SCODE(75)
+#define CTL_E_PATHNOTFOUND STD_CTL_SCODE(76)
+
+#define INET_E_CANNOT_CONNECT _HRESULT_TYPEDEF_(0x800C0004L)
+#define INET_E_RESOURCE_NOT_FOUND _HRESULT_TYPEDEF_(0x800C0005L)
+#define INET_E_OBJECT_NOT_FOUND _HRESULT_TYPEDEF_(0x800C0006L)
+#define INET_E_DATA_NOT_AVAILABLE _HRESULT_TYPEDEF_(0x800C0007L)
+#define INET_E_DOWNLOAD_FAILURE _HRESULT_TYPEDEF_(0x800C0008L)
+#define INET_E_CONNECTION_TIMEOUT _HRESULT_TYPEDEF_(0x800C000BL)
+#define INET_E_UNKNOWN_PROTOCOL _HRESULT_TYPEDEF_(0x800C000DL)
+
+#define DBG_PRINTEXCEPTION_C _HRESULT_TYPEDEF_(0x40010006L)
+// ==== Done pasting ====
+
+inline constexpr bool SUCCEEDED( HRESULT hr )
+{
+ return hr >= 0;
+}
+
+inline constexpr bool FAILED( HRESULT hr )
+{
+ return hr < 0;
+} \ No newline at end of file
diff --git a/ComLightLib/server/Object.hpp b/ComLightLib/server/Object.hpp
new file mode 100644
index 0000000..d2e3257
--- /dev/null
+++ b/ComLightLib/server/Object.hpp
@@ -0,0 +1,139 @@
+#pragma once
+#include <type_traits>
+#include "../comLightClient.h"
+#include "../utils/typeTraits.hpp"
+#include "../Exception.hpp"
+
+namespace ComLight
+{
+ namespace details
+ {
+ GENERATE_HAS_MEMBER( implQueryInterface );
+ GENERATE_HAS_MEMBER( implAddRef );
+ GENERATE_HAS_MEMBER( implRelease );
+ }
+
+ // Outer class of objects, implements IUnknown methods, also the class factory. The type argument must be your class implementing your interfaces, inherited from ObjectRoot<I>
+ template<class T>
+ class Object : public T
+ {
+ public:
+ Object() = default;
+
+ template<typename ... Args>
+ Object( Args&& ... args ) : T{ std::forward<Args>( args )... } {};
+
+ inline virtual ~Object() override { }
+
+ // Implement IUnknown methods
+ HRESULT COMLIGHTCALL QueryInterface( REFIID riid, void** ppvObject ) override
+ {
+ static_assert( details::has_member_implQueryInterface<T>::value, "Your object class must inherit from ComLight::ObjectRoot" );
+
+ if( nullptr == ppvObject )
+ return E_POINTER;
+
+ if( T::implQueryInterface( riid, ppvObject ) )
+ return S_OK;
+ if( T::queryExtraInterfaces( riid, ppvObject ) )
+ return S_OK;
+
+ if( riid == IUnknown::iid() )
+ {
+ ComLight::IUnknown* unk = T::getUnknown();
+ unk->AddRef();
+ *ppvObject = unk;
+ return S_OK;
+ }
+
+ return E_NOINTERFACE;
+ }
+
+ uint32_t COMLIGHTCALL AddRef() override
+ {
+ static_assert( details::has_member_implAddRef<T>::value, "Your object class must inherit from ComLight::ObjectRoot" );
+ return T::implAddRef();
+ }
+
+ uint32_t COMLIGHTCALL Release() override
+ {
+ static_assert( details::has_member_implRelease<T>::value, "Your object class must inherit from ComLight::ObjectRoot" );
+ const uint32_t ret = T::implRelease();
+ if( 0 == ret )
+ {
+ T::FinalRelease();
+ delete this;
+ }
+ return ret;
+ }
+
+ // Create a new object on the heap, store in smart pointer
+ static inline HRESULT create( CComPtr<Object<T>>& result )
+ {
+ CComPtr<Object<T>> ptr;
+ try
+ {
+ ptr = new Object<T>(); // The RefCounter constructor creates it with ref.counter 0. But then CComPtr constructor calls AddRef so we have RC=1 after this line.
+
+ HRESULT hr = ptr->internalFinalConstruct();
+ if( FAILED( hr ) )
+ return hr;
+
+ hr = ptr->FinalConstruct();
+ if( FAILED( hr ) )
+ return hr;
+
+ ptr.swap( result );
+ return S_OK;
+ }
+ catch( const Exception& ex )
+ {
+ return ex.code();
+ }
+ }
+
+ // Create a new object on the heap, store in smart pointer
+ template<typename ... Args>
+ static inline HRESULT create( CComPtr<Object<T>>& result, Args&& ... args )
+ {
+ CComPtr<Object<T>> ptr;
+ try
+ {
+ ptr = new Object<T>( std::forward<Args>( args )... );
+
+ HRESULT hr = ptr->internalFinalConstruct();
+ if( FAILED( hr ) )
+ return hr;
+
+ hr = ptr->FinalConstruct();
+ if( FAILED( hr ) )
+ return hr;
+
+ ptr.swap( result );
+ return S_OK;
+ }
+ catch( const Exception& ex )
+ {
+ return ex.code();
+ }
+ catch( HRESULT hr )
+ {
+ return hr;
+ }
+ }
+
+ // Create a new object on the heap, return one of it's interfaces. The caller is assumed to take ownership of the new object.
+ template<class I>
+ static inline HRESULT create( I** pp )
+ {
+ if( pp == nullptr )
+ return E_POINTER;
+
+ static_assert( details::pointersAssignable<I, T>(), "Object::create can't cast object to the requested interface" );
+ CComPtr<Object<T>> ptr;
+ CHECK( create( ptr ) );
+ ptr.detach( pp );
+ return S_OK;
+ }
+ };
+} \ No newline at end of file
diff --git a/ComLightLib/server/ObjectRoot.hpp b/ComLightLib/server/ObjectRoot.hpp
new file mode 100644
index 0000000..1cc8f4d
--- /dev/null
+++ b/ComLightLib/server/ObjectRoot.hpp
@@ -0,0 +1,51 @@
+#pragma once
+#include "RefCounter.hpp"
+#include "../comLightCommon.h"
+#include "../utils/typeTraits.hpp"
+
+namespace ComLight
+{
+ // Base class of objects, implements reference counting, also a few lifetime methods.
+ // The template argument is the interface you want clients to get when they ask for IID_IUnknown. By convention, that pointer defines object's identity.
+ template<class I>
+ class ObjectRoot : public RefCounter, public I
+ {
+ protected:
+
+ inline HRESULT internalFinalConstruct()
+ {
+ return S_FALSE;
+ }
+
+ inline HRESULT FinalConstruct()
+ {
+ return S_FALSE;
+ }
+
+ inline void FinalRelease() { }
+
+ IUnknown* getUnknown()
+ {
+ static_assert( details::pointersAssignable<IUnknown, I>(), "The interface doesn't derive from IUnknown" );
+ return static_cast<I*>( this );
+ }
+
+ bool queryExtraInterfaces( REFIID riid, void **ppvObject ) const
+ {
+ return false;
+ }
+
+ // Implement query interface with 2 entries, IUnknown and I.
+ bool implQueryInterface( REFIID riid, void** ppvObject )
+ {
+ if( riid == I::iid() || riid == IUnknown::iid() )
+ {
+ I* const result = this;
+ result->AddRef();
+ *ppvObject = result;
+ return true;
+ }
+ return false;
+ }
+ };
+} \ No newline at end of file
diff --git a/ComLightLib/server/RefCounter.hpp b/ComLightLib/server/RefCounter.hpp
new file mode 100644
index 0000000..9698cc8
--- /dev/null
+++ b/ComLightLib/server/RefCounter.hpp
@@ -0,0 +1,38 @@
+#pragma once
+#include <atomic>
+#include <assert.h>
+#include <limits.h>
+
+namespace ComLight
+{
+ // Very base class of objects, implements reference counting.
+ class RefCounter
+ {
+ std::atomic_uint referenceCounter;
+
+ public:
+
+ RefCounter() : referenceCounter( 0 ) { }
+
+ inline virtual ~RefCounter() { }
+
+ RefCounter( const RefCounter &that ) = delete;
+ RefCounter( RefCounter &&that ) = delete;
+
+ protected:
+
+ uint32_t implAddRef()
+ {
+ return ++referenceCounter;
+ }
+
+ uint32_t implRelease()
+ {
+ // Might be a good idea to use locks, at least in debug builds. They're much slower than atomics, but with locks it's possible to detect when 2 threads call release at the same time, for object with counter = 1.
+ // It's a memory management bug, but it would be nice if debug builds would handle that case gracefully.
+ const uint32_t rc = --referenceCounter;
+ assert( rc != UINT_MAX );
+ return rc;
+ }
+ };
+} \ No newline at end of file
diff --git a/ComLightLib/server/freeThreadedMarshaller.cpp b/ComLightLib/server/freeThreadedMarshaller.cpp
new file mode 100644
index 0000000..fc1ea80
--- /dev/null
+++ b/ComLightLib/server/freeThreadedMarshaller.cpp
@@ -0,0 +1,17 @@
+#include "freeThreadedMarshaller.h"
+#ifdef _MSC_VER
+#include <combaseapi.h>
+
+HRESULT ComLight::details::createFreeThreadedMarshaller( IUnknown* pUnkOuter, IUnknown** ppUnkMarshal )
+{
+ return ::CoCreateFreeThreadedMarshaler( (LPUNKNOWN)pUnkOuter, (LPUNKNOWN *)ppUnkMarshal );
+}
+
+bool ComLight::details::queryMarshallerInterface( REFIID riid, void **ppvObject, IUnknown* marshaller )
+{
+ if( riid != IID_IMarshal || nullptr == marshaller )
+ return false;
+ const HRESULT hr = marshaller->QueryInterface( IID_IMarshal, ppvObject );
+ return SUCCEEDED( hr ) ? true : false;
+}
+#endif \ No newline at end of file
diff --git a/ComLightLib/server/freeThreadedMarshaller.h b/ComLightLib/server/freeThreadedMarshaller.h
new file mode 100644
index 0000000..8ef774e
--- /dev/null
+++ b/ComLightLib/server/freeThreadedMarshaller.h
@@ -0,0 +1,29 @@
+#pragma once
+#ifdef _MSC_VER
+#include "../comLightCommon.h"
+
+namespace ComLight
+{
+ namespace details
+ {
+ HRESULT createFreeThreadedMarshaller( IUnknown* pUnkOuter, IUnknown** ppUnkMarshal );
+ bool queryMarshallerInterface( REFIID riid, void **ppvObject, IUnknown* marshaller );
+ }
+}
+
+#define DECLARE_FREE_THREADED_MARSHALLER() \
+private: \
+ComLight::CComPtr<ComLight::IUnknown> m_freeThreadedMarshaller; \
+protected: \
+HRESULT internalFinalConstruct() \
+{ \
+ return ComLight::details::createFreeThreadedMarshaller( getUnknown(), &m_freeThreadedMarshaller ); \
+} \
+bool queryExtraInterfaces( REFIID riid, void **ppvObject ) const \
+{ \
+ return ComLight::details::queryMarshallerInterface( riid, ppvObject, m_freeThreadedMarshaller ); \
+}
+
+#else
+#define DECLARE_FREE_THREADED_MARSHALLER()
+#endif \ No newline at end of file
diff --git a/ComLightLib/server/interfaceMap.h b/ComLightLib/server/interfaceMap.h
new file mode 100644
index 0000000..605ed33
--- /dev/null
+++ b/ComLightLib/server/interfaceMap.h
@@ -0,0 +1,31 @@
+#pragma once
+#include "../utils/typeTraits.hpp"
+
+// Unlike ATL, the interface map is optional for ComLight.
+// If you won't declare a map, the object will support 2 interfaces: IUnknown, and whatever template argument was passed to ObjectRoot class.
+#define BEGIN_COM_MAP() \
+protected: \
+bool implQueryInterface( REFIID iid, void** ppvObject ) {
+
+#define END_COM_MAP() return false; }
+
+namespace ComLight
+{
+ namespace details
+ {
+ template<typename I, typename C>
+ inline bool tryReturnInterface( REFIID iid, C* pThis, void** ppvResult )
+ {
+ static_assert( pointersAssignable<IUnknown, I>(), "Trying to implement an interface that doesn't derive from IUnknown" );
+ static_assert( pointersAssignable<I, C>(), "Declared support for an interface, but the class doesn't implement it" );
+ if( I::iid() != iid )
+ return false;
+ I* const result = pThis;
+ result->AddRef();
+ *ppvResult = result;
+ return true;
+ }
+ }
+}
+
+#define COM_INTERFACE_ENTRY( I ) if( ComLight::details::tryReturnInterface<I>( iid, this, ppvObject ) ) return true; \ No newline at end of file
diff --git a/ComLightLib/streams.h b/ComLightLib/streams.h
new file mode 100644
index 0000000..87680e1
--- /dev/null
+++ b/ComLightLib/streams.h
@@ -0,0 +1,61 @@
+#pragma once
+#include <vector>
+#include "comLightCommon.h"
+
+// COM interfaces to marshal streams across the interop.
+namespace ComLight
+{
+ enum struct eSeekOrigin : uint8_t
+ {
+ Begin = 0,
+ Current = 1,
+ End = 2
+ };
+
+ namespace details
+ {
+ template<class E>
+ inline size_t sizeofVector( const std::vector<E>& vec )
+ {
+ return sizeof( E ) * vec.size();
+ }
+ }
+
+ // COM interface for readonly stream. You'll get these interfaces what you use [ReadStream] attribute in C#.
+ struct DECLSPEC_NOVTABLE iReadStream : public IUnknown
+ {
+ DEFINE_INTERFACE_ID( "006af6db-734e-4595-8c94-19304b2389ac" );
+
+ virtual HRESULT COMLIGHTCALL read( void* lpBuffer, int nNumberOfBytesToRead, int &lpNumberOfBytesRead ) = 0;
+ virtual HRESULT COMLIGHTCALL seek( int64_t offset, eSeekOrigin origin ) = 0;
+ virtual HRESULT COMLIGHTCALL getPosition( int64_t& position ) = 0;
+ virtual HRESULT COMLIGHTCALL getLength( int64_t& length ) = 0;
+
+ template<class E>
+ inline HRESULT read( std::vector<E>& vec )
+ {
+ const int cb = (int)details::sizeofVector( vec );
+ int cbRead = 0;
+ CHECK( read( vec.data(), cb, cbRead ) );
+ if( cbRead >= cb )
+ return S_OK;
+ return E_EOF;
+ }
+ };
+
+ // COM interface for readonly stream. You'll get these interfaces what you use [WriteStream] attribute in C#.
+ struct DECLSPEC_NOVTABLE iWriteStream : public IUnknown
+ {
+ DEFINE_INTERFACE_ID( "d7c3eb39-9170-43b9-ba98-2ea1f2fed8a8" );
+
+ virtual HRESULT COMLIGHTCALL write( const void* lpBuffer, int nNumberOfBytesToWrite ) = 0;
+ virtual HRESULT COMLIGHTCALL flush() = 0;
+
+ template<class E>
+ inline HRESULT write( const std::vector<E>& vec )
+ {
+ const int cb = (int)details::sizeofVector( vec );
+ return write( vec.data(), cb );
+ }
+ };
+} \ No newline at end of file
diff --git a/ComLightLib/unknwn.h b/ComLightLib/unknwn.h
new file mode 100644
index 0000000..5f38359
--- /dev/null
+++ b/ComLightLib/unknwn.h
@@ -0,0 +1,36 @@
+#pragma once
+#include <type_traits>
+
+// Calling conventions
+#ifdef _MSC_VER
+#define COMLIGHTCALL __stdcall
+#define DECLSPEC_NOVTABLE __declspec(novtable)
+#elif defined(__GNUC__) || defined(__clang__)
+#if defined(__i386__)
+#define COMLIGHTCALL __attribute__((stdcall))
+#else
+#define COMLIGHTCALL
+#endif
+#define DECLSPEC_NOVTABLE
+#else
+#error Unsupported C++ compiler
+#endif
+
+#include "utils/guid_parse.hpp"
+
+#define DEFINE_INTERFACE_ID( guidString ) static constexpr GUID iid() { return ::ComLight::make_guid( guidString ); }
+
+namespace ComLight
+{
+ // This thing is binary compatible with IUnknown from Windows SDK. See DesktopClient demo project, it uses normal COM interop in .NET framework 4.7 to call my implementation.
+ struct DECLSPEC_NOVTABLE IUnknown
+ {
+ DEFINE_INTERFACE_ID( "00000000-0000-0000-c000-000000000046" );
+
+ virtual HRESULT COMLIGHTCALL QueryInterface( REFIID riid, void **ppvObject ) = 0;
+
+ virtual uint32_t COMLIGHTCALL AddRef() = 0;
+
+ virtual uint32_t COMLIGHTCALL Release() = 0;
+ };
+} \ No newline at end of file
diff --git a/ComLightLib/utils/guid_parse.hpp b/ComLightLib/utils/guid_parse.hpp
new file mode 100644
index 0000000..435fcb3
--- /dev/null
+++ b/ComLightLib/utils/guid_parse.hpp
@@ -0,0 +1,103 @@
+// https://github.com/tobias-loew/constexpr-GUID-cpp-11
+
+//-------------------------------------------------------------------------------------------------------
+// constexpr GUID parsing
+// Written by Alexander Bessonov
+// Written by Tobias Loew
+//
+// Licensed under the MIT license.
+//-------------------------------------------------------------------------------------------------------
+
+#pragma once
+#include <stdexcept>
+#include <string>
+#include <cassert>
+#include <cstdint>
+
+#if !defined(GUID_DEFINED)
+#define GUID_DEFINED
+struct GUID {
+ uint32_t Data1;
+ uint16_t Data2;
+ uint16_t Data3;
+ uint8_t Data4[ 8 ];
+};
+#endif
+
+namespace ComLight
+{
+ namespace details
+ {
+ constexpr const size_t short_guid_form_length = 36; // XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
+ constexpr const size_t long_guid_form_length = 38; // {XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}
+
+ constexpr uint8_t parse_hex_digit( const char c )
+ {
+ using namespace std::string_literals;
+ return
+ ( '0' <= c && c <= '9' )
+ ? c - '0'
+ : ( 'a' <= c && c <= 'f' )
+ ? 10 + c - 'a'
+ : ( 'A' <= c && c <= 'F' )
+ ? 10 + c - 'A'
+ :
+ throw std::domain_error{ "invalid character in GUID"s };
+ }
+
+ constexpr uint8_t parse_hex_uint8_t( const char *ptr )
+ {
+ return ( parse_hex_digit( ptr[ 0 ] ) << 4 ) + parse_hex_digit( ptr[ 1 ] );
+ }
+
+ constexpr uint16_t parse_hex_uint16_t( const char *ptr )
+ {
+ return ( parse_hex_uint8_t( ptr ) << 8 ) + parse_hex_uint8_t( ptr + 2 );
+ }
+
+ constexpr uint32_t parse_hex_uint32_t( const char *ptr )
+ {
+ return ( parse_hex_uint16_t( ptr ) << 16 ) + parse_hex_uint16_t( ptr + 4 );
+ }
+
+ constexpr GUID parse_guid( const char *begin )
+ {
+ return GUID{
+ parse_hex_uint32_t( begin ),
+ parse_hex_uint16_t( begin + 8 + 1 ),
+ parse_hex_uint16_t( begin + 8 + 1 + 4 + 1 ),
+ {
+ parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 ),
+ parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 ),
+ parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 ),
+ parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 + 2 ),
+ parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 + 2 + 2 ),
+ parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 + 2 + 2 + 2 ),
+ parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 + 2 + 2 + 2 + 2 ),
+ parse_hex_uint8_t( begin + 8 + 1 + 4 + 1 + 4 + 1 + 2 + 2 + 1 + 2 + 2 + 2 + 2 + 2 )
+ }
+ };
+ }
+
+ constexpr GUID make_guid_helper( const char *str, size_t N )
+ {
+ using namespace std::string_literals;
+ using namespace details;
+
+ return ( !( N == long_guid_form_length || N == short_guid_form_length ) )
+ ? throw std::domain_error{ "String GUID of the form {XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX} or XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX is expected"s }
+ : ( N == long_guid_form_length && ( str[ 0 ] != '{' || str[ long_guid_form_length - 1 ] != '}' ) )
+ ? throw std::domain_error{ "Missing opening or closing brace"s }
+
+ : parse_guid( str + ( N == long_guid_form_length ? 1 : 0 ) );
+ }
+
+
+ template<size_t N>
+ constexpr GUID make_guid( const char( &str )[ N ] )
+ {
+ return make_guid_helper( str, N - 1 );
+ }
+ }
+ using details::make_guid;
+} \ No newline at end of file
diff --git a/ComLightLib/utils/typeTraits.hpp b/ComLightLib/utils/typeTraits.hpp
new file mode 100644
index 0000000..c5ddb84
--- /dev/null
+++ b/ComLightLib/utils/typeTraits.hpp
@@ -0,0 +1,43 @@
+#pragma once
+#include <type_traits>
+
+namespace ComLight
+{
+ namespace details
+ {
+ template<class TResult, class TValue>
+ constexpr bool pointersAssignable()
+ {
+ // See this for why `&` is required: https://stackoverflow.com/a/52429468/126995
+ return std::is_assignable<TResult*&, TValue*>::value;
+ }
+ }
+}
+
+// https://en.wikibooks.org/wiki/More_C++_Idioms/Member_Detector
+#define GENERATE_HAS_MEMBER(member) \
+ \
+template < class T > \
+class HasMember_##member \
+{ \
+private: \
+ using Yes = char[2]; \
+ using No = char[1]; \
+ \
+ struct Fallback { int member; }; \
+ struct Derived : T, Fallback { }; \
+ \
+ template < class U > \
+ static No& test ( decltype(U::member)* ); \
+ template < typename U > \
+ static Yes& test ( U* ); \
+ \
+public: \
+ static constexpr bool RESULT = sizeof(test<Derived>(nullptr)) == sizeof(Yes); \
+}; \
+ \
+template < class T > \
+struct has_member_##member \
+: public std::integral_constant<bool, HasMember_##member<T>::RESULT> \
+{ \
+};
diff --git a/ComputeShaders/ComputeShaders.cpp b/ComputeShaders/ComputeShaders.cpp
new file mode 100644
index 0000000..9c03f27
--- /dev/null
+++ b/ComputeShaders/ComputeShaders.cpp
@@ -0,0 +1,3 @@
+void fnComputeShaders()
+{
+} \ No newline at end of file
diff --git a/ComputeShaders/ComputeShaders.vcxproj b/ComputeShaders/ComputeShaders.vcxproj
new file mode 100644
index 0000000..350d266
--- /dev/null
+++ b/ComputeShaders/ComputeShaders.vcxproj
@@ -0,0 +1,221 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <VCProjectVersion>16.0</VCProjectVersion>
+ <Keyword>Win32Proj</Keyword>
+ <ProjectGuid>{1c39d386-96d0-47a1-bbfa-68bbdb24439c}</ProjectGuid>
+ <RootNamespace>ComputeShaders</RootNamespace>
+ <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="Shared">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <MultiProcFXC>true</MultiProcFXC>
+ <OutDir>$(Platform)\$(Configuration)\</OutDir>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <OutDir>$(Platform)\$(Configuration)\</OutDir>
+ <MultiProcFXC>true</MultiProcFXC>
+ </PropertyGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ </ClCompile>
+ <Link>
+ <SubSystem>
+ </SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ </ClCompile>
+ <Link>
+ <SubSystem>
+ </SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ </ClCompile>
+ <Link>
+ <SubSystem>
+ </SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ <FxCompile>
+ <ShaderModel>5.0</ShaderModel>
+ <ShaderType>Compute</ShaderType>
+ </FxCompile>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NDEBUG;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ </ClCompile>
+ <Link>
+ <SubSystem>
+ </SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ <FxCompile>
+ <ShaderModel>5.0</ShaderModel>
+ <ShaderType>Compute</ShaderType>
+ <DisableOptimizations>true</DisableOptimizations>
+ <EnableDebuggingInformation>true</EnableDebuggingInformation>
+ </FxCompile>
+ </ItemDefinitionGroup>
+ <ItemGroup>
+ <ClCompile Include="ComputeShaders.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <FxCompile Include="add.hlsl" />
+ <FxCompile Include="addInPlace.hlsl" />
+ <FxCompile Include="addRepeat.hlsl" />
+ <FxCompile Include="addRepeat64.hlsl" />
+ <FxCompile Include="addRepeatGelu.hlsl" />
+ <FxCompile Include="addRepeatGelu64.hlsl" />
+ <FxCompile Include="addRepeatScale.hlsl" />
+ <FxCompile Include="addRows.hlsl" />
+ <FxCompile Include="convolutionMain.hlsl" />
+ <FxCompile Include="convolutionMain2.hlsl" />
+ <FxCompile Include="convolutionMain2Fixed.hlsl" />
+ <FxCompile Include="convolutionPrep1.hlsl" />
+ <FxCompile Include="convolutionPrep2.hlsl" />
+ <FxCompile Include="copyConvert.hlsl" />
+ <FxCompile Include="copyTranspose.hlsl" />
+ <FxCompile Include="diagMaskInf.hlsl" />
+ <FxCompile Include="flashAttention.hlsl" />
+ <FxCompile Include="flashAttentionCompat1.hlsl" />
+ <FxCompile Include="flashAttentionCompat2.hlsl" />
+ <FxCompile Include="flashAttentionCompat3.hlsl" />
+ <FxCompile Include="fmaRepeat1.hlsl" />
+ <FxCompile Include="fmaRepeat164.hlsl" />
+ <FxCompile Include="fmaRepeat2.hlsl" />
+ <FxCompile Include="matReshapePanels.hlsl" />
+ <FxCompile Include="mulMatByRow.hlsl" />
+ <FxCompile Include="mulMatByRow64.hlsl" />
+ <FxCompile Include="mulMatByRowTiled.hlsl" />
+ <FxCompile Include="mulMatByRowTiled64.hlsl" />
+ <FxCompile Include="mulMatByRowTiledEx.hlsl" />
+ <FxCompile Include="mulMatByScalar.hlsl" />
+ <FxCompile Include="mulMatDotMain.hlsl" />
+ <FxCompile Include="mulMatDotReshape.hlsl" />
+ <FxCompile Include="mulMatMadMain.hlsl" />
+ <FxCompile Include="mulMatTiled.hlsl" />
+ <FxCompile Include="mulMatTiled64.hlsl" />
+ <FxCompile Include="mulMatTiledEx.hlsl" />
+ <FxCompile Include="norm.hlsl" />
+ <FxCompile Include="normCompat.hlsl" />
+ <FxCompile Include="normFixed.hlsl" />
+ <FxCompile Include="normFixed64.hlsl" />
+ <FxCompile Include="scaleInPlace.hlsl" />
+ <FxCompile Include="softMax.hlsl" />
+ <FxCompile Include="softMax64.hlsl" />
+ <FxCompile Include="softMaxCompat.hlsl" />
+ <FxCompile Include="softMaxFixed.hlsl" />
+ <FxCompile Include="zeroMemory.hlsl" />
+ </ItemGroup>
+ <ItemGroup>
+ <None Include="componentwiseBinaryOp.hlsli" />
+ <None Include="flashAttentionCommon.hlsli" />
+ <None Include="fp64Utils.hlsli" />
+ <None Include="groupReduce.hlsli" />
+ <None Include="groupReduce64.hlsli" />
+ <None Include="miscUtils.hlsli" />
+ <None Include="repeatUtils.hlsli" />
+ </ItemGroup>
+ <ItemGroup>
+ <Text Include="Readme.txt" />
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/ComputeShaders/ComputeShaders.vcxproj.filters b/ComputeShaders/ComputeShaders.vcxproj.filters
new file mode 100644
index 0000000..12f1559
--- /dev/null
+++ b/ComputeShaders/ComputeShaders.vcxproj.filters
@@ -0,0 +1,66 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <ClCompile Include="ComputeShaders.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <FxCompile Include="mulMatDotMain.hlsl" />
+ <FxCompile Include="mulMatDotReshape.hlsl" />
+ <FxCompile Include="convolutionMain.hlsl" />
+ <FxCompile Include="convolutionPrep1.hlsl" />
+ <FxCompile Include="convolutionPrep2.hlsl" />
+ <FxCompile Include="add.hlsl" />
+ <FxCompile Include="flashAttention.hlsl" />
+ <FxCompile Include="convolutionMain2.hlsl" />
+ <FxCompile Include="norm.hlsl" />
+ <FxCompile Include="copyConvert.hlsl" />
+ <FxCompile Include="copyTranspose.hlsl" />
+ <FxCompile Include="normCompat.hlsl" />
+ <FxCompile Include="flashAttentionCompat1.hlsl" />
+ <FxCompile Include="flashAttentionCompat3.hlsl" />
+ <FxCompile Include="flashAttentionCompat2.hlsl" />
+ <FxCompile Include="scaleInPlace.hlsl" />
+ <FxCompile Include="diagMaskInf.hlsl" />
+ <FxCompile Include="softMaxCompat.hlsl" />
+ <FxCompile Include="mulMatMadMain.hlsl" />
+ <FxCompile Include="addRepeat.hlsl" />
+ <FxCompile Include="fmaRepeat1.hlsl" />
+ <FxCompile Include="fmaRepeat2.hlsl" />
+ <FxCompile Include="addInPlace.hlsl" />
+ <FxCompile Include="softMax.hlsl" />
+ <FxCompile Include="addRepeatScale.hlsl" />
+ <FxCompile Include="mulMatByRow.hlsl" />
+ <FxCompile Include="mulMatByScalar.hlsl" />
+ <FxCompile Include="mulMatTiled.hlsl" />
+ <FxCompile Include="mulMatByRow64.hlsl" />
+ <FxCompile Include="softMax64.hlsl" />
+ <FxCompile Include="softMaxFixed.hlsl" />
+ <FxCompile Include="addRepeat64.hlsl" />
+ <FxCompile Include="fmaRepeat164.hlsl" />
+ <FxCompile Include="addRepeatGelu.hlsl" />
+ <FxCompile Include="addRepeatGelu64.hlsl" />
+ <FxCompile Include="normFixed.hlsl" />
+ <FxCompile Include="normFixed64.hlsl" />
+ <FxCompile Include="mulMatByRowTiled.hlsl" />
+ <FxCompile Include="convolutionMain2Fixed.hlsl" />
+ <FxCompile Include="mulMatByRowTiled64.hlsl" />
+ <FxCompile Include="addRows.hlsl" />
+ <FxCompile Include="mulMatTiled64.hlsl" />
+ <FxCompile Include="zeroMemory.hlsl" />
+ <FxCompile Include="mulMatTiledEx.hlsl" />
+ <FxCompile Include="matReshapePanels.hlsl" />
+ <FxCompile Include="mulMatByRowTiledEx.hlsl" />
+ </ItemGroup>
+ <ItemGroup>
+ <None Include="componentwiseBinaryOp.hlsli" />
+ <None Include="miscUtils.hlsli" />
+ <None Include="groupReduce.hlsli" />
+ <None Include="fp64Utils.hlsli" />
+ <None Include="flashAttentionCommon.hlsli" />
+ <None Include="repeatUtils.hlsli" />
+ <None Include="groupReduce64.hlsli" />
+ </ItemGroup>
+ <ItemGroup>
+ <Text Include="Readme.txt" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/ComputeShaders/Readme.txt b/ComputeShaders/Readme.txt
new file mode 100644
index 0000000..18d089e
--- /dev/null
+++ b/ComputeShaders/Readme.txt
@@ -0,0 +1,11 @@
+This project compiles all the compute shaders which implement the model.
+
+Many shaders come in 2 versions, something.hlsl and something64.hlsl
+
+The version with the `64` suffix is used on AMD GPUs, the version without suffix is used on nVidia and Intel GPUs.
+
+Not all of these shaders are actually used for anything.
+Some of them are implementing binary compatibility for the reference CPU version, and not used unless messing with the `constexpr` flags in MlContext C++ class.
+Such shaders often require FP64 support, which is an optional feature in D3D11.
+CompressShaders tool detects such shaders by looking at the SFI0 chunk in the binary, and outputs a bitmap of the FP64 shaders.
+This way, missing FP64 hardware support shouldn’t break the library. \ No newline at end of file
diff --git a/ComputeShaders/add.hlsl b/ComputeShaders/add.hlsl
new file mode 100644
index 0000000..819443e
--- /dev/null
+++ b/ComputeShaders/add.hlsl
@@ -0,0 +1,6 @@
+inline float compute( float a, float b )
+{
+ return a + b;
+}
+
+#include "componentwiseBinaryOp.hlsli" \ No newline at end of file
diff --git a/ComputeShaders/addInPlace.hlsl b/ComputeShaders/addInPlace.hlsl
new file mode 100644
index 0000000..a83c221
--- /dev/null
+++ b/ComputeShaders/addInPlace.hlsl
@@ -0,0 +1,39 @@
+#ifndef THREADS
+#define THREADS 512
+#endif
+
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 size: packoffset( c0 );
+ uint4 strides: packoffset( c1 );
+ uint4 argStrides: packoffset( c3 );
+}
+
+inline uint rowOffset( uint3 idx, uint4 strides )
+{
+ return idx[ 0 ] * strides[ 1 ] + idx[ 1 ] * strides[ 2 ] + idx[ 2 ] * strides[ 3 ];
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint rdi = rowOffset( group, strides );
+ uint rsi = rowOffset( group, argStrides );
+
+ const uint rdiEnd = rdi + size[ 0 ] * strides[ 0 ];
+ rdi += thread * strides[ 0 ];
+ rsi += thread * argStrides[ 0 ];
+
+ const uint rdiInc = THREADS * strides[ 0 ];
+ const uint rsiInc = THREADS * argStrides[ 0 ];
+
+ for( ; rdi < rdiEnd; rdi += rdiInc, rsi += rsiInc )
+ {
+ float f = result[ rdi ];
+ f += arg0[ rsi ];
+ result[ rdi ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/addRepeat.hlsl b/ComputeShaders/addRepeat.hlsl
new file mode 100644
index 0000000..e5cdaa3
--- /dev/null
+++ b/ComputeShaders/addRepeat.hlsl
@@ -0,0 +1,70 @@
+// Compute tensor = tensor + repeat( pattern, tensor ) in 1 shot, without VRAM allocations
+// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
+RWBuffer<float> tensor: register( u0 );
+Buffer<float> pattern: register( t0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 tensorSize: packoffset( c0 );
+ uint4 tensorStrides: packoffset( c1 );
+ uint4 patternSize: packoffset( c2 );
+ uint4 patternStrides: packoffset( c3 );
+}
+
+#ifndef THREADS
+#define THREADS 256
+#endif
+
+#include "repeatUtils.hlsli"
+
+inline void computeSimple( uint idx, float add )
+{
+ float f = tensor[ idx ];
+ f += add;
+ tensor[ idx ] = f;
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
+ uint rsi = rowOffset( group % patternSize.yzw, patternStrides );
+
+ if( patternSize[ 0 ] == 1 )
+ {
+ // The pattern only has 1 column - broadcasting over the row
+ const float p = pattern[ rsi ];
+ ROW_LOOP( it )
+ computeSimple( it.x, p );
+ }
+ else if( patternSize[ 0 ] <= THREADS )
+ {
+ // pattern size doesn't exceed thread group size: load pattern value outside of the loop
+ const uint threadsPerGroup = THREADS - ( THREADS % patternSize[ 0 ] );
+ if( thread >= threadsPerGroup )
+ return;
+
+ const float p = pattern[ rsi + ( thread % patternSize[ 0 ] ) * patternStrides[ 0 ] ];
+ ROW_LOOP_EX( it, threadsPerGroup, tensorStrides )
+ computeSimple( it.x, p );
+ }
+ else
+ {
+ // Pattern rows are larger than the thread group, need to stream from both buffers
+ const uint rsiInc = THREADS * patternStrides[ 0 ];
+ const uint rsiDec = patternSize[ 0 ] * patternStrides[ 0 ];
+ const uint rsiEnd = rsi + rsiDec;
+ rsi += thread * patternStrides[ 0 ];
+
+ ROW_LOOP( it )
+ {
+ float f = tensor[ it.x ];
+ float p = pattern[ rsi ];
+ rsi += rsiInc;
+ if( rsi >= rsiEnd )
+ rsi -= rsiDec;
+ f += p;
+ tensor[ it.x ] = f;
+ }
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/addRepeat64.hlsl b/ComputeShaders/addRepeat64.hlsl
new file mode 100644
index 0000000..b6c8c19
--- /dev/null
+++ b/ComputeShaders/addRepeat64.hlsl
@@ -0,0 +1,2 @@
+#define THREADS 64
+#include "addRepeat.hlsl" \ No newline at end of file
diff --git a/ComputeShaders/addRepeatGelu.hlsl b/ComputeShaders/addRepeatGelu.hlsl
new file mode 100644
index 0000000..7f63653
--- /dev/null
+++ b/ComputeShaders/addRepeatGelu.hlsl
@@ -0,0 +1,88 @@
+// Compute tensor = GELU( tensor + repeat( pattern, tensor ) ) in 1 shot, without VRAM allocations
+// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
+RWBuffer<float> tensor: register( u0 );
+Buffer<float> pattern: register( t0 );
+Buffer<uint> lookupTable: register( t1 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 tensorSize: packoffset( c0 );
+ uint4 tensorStrides: packoffset( c1 );
+ uint4 patternSize: packoffset( c2 );
+ uint4 patternStrides: packoffset( c3 );
+}
+
+#ifndef THREADS
+#define THREADS 1024
+#endif
+
+#include "repeatUtils.hlsli"
+#include "miscUtils.hlsli"
+
+inline float gelu( float x )
+{
+#if 1
+ const uint index = fp16Rounded( x );
+ const uint res16 = lookupTable[ index ];
+ return f16tof32( res16 );
+#else
+ // This version is much slower, at least on AMD, despite saving these VRAM loads.
+ const float GELU_COEF_A = 0.044715;
+ const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+ return 0.5 * x * ( 1.0 + tanh( SQRT_2_OVER_PI * x * ( 1.0 + GELU_COEF_A * x * x ) ) );
+#endif
+}
+
+inline void computeSimple( uint idx, float add )
+{
+ float f = tensor[ idx ];
+ f += add;
+ f = gelu( f );
+ tensor[ idx ] = f;
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
+ uint rsi = rowOffset( group % patternSize.yzw, patternStrides );
+
+ if( patternSize[ 0 ] == 1 )
+ {
+ // The pattern only has 1 column - broadcasting over the row
+ const float p = pattern[ rsi ];
+ ROW_LOOP( it )
+ computeSimple( it.x, p );
+ }
+ else if( patternSize[ 0 ] <= THREADS )
+ {
+ // pattern size doesn't exceed thread group size: load pattern value outside of the loop
+ const uint threadsPerGroup = THREADS - ( THREADS % patternSize[ 0 ] );
+ if( thread >= threadsPerGroup )
+ return;
+
+ const float p = pattern[ rsi + ( thread % patternSize[ 0 ] ) * patternStrides[ 0 ] ];
+ ROW_LOOP_EX( it, threadsPerGroup, tensorStrides )
+ computeSimple( it.x, p );
+ }
+ else
+ {
+ // Pattern rows are larger than the thread group, need to stream from both buffers
+ const uint rsiInc = THREADS * patternStrides[ 0 ];
+ const uint rsiDec = patternSize[ 0 ] * patternStrides[ 0 ];
+ const uint rsiEnd = rsi + rsiDec;
+ rsi += thread * patternStrides[ 0 ];
+
+ ROW_LOOP( it )
+ {
+ float f = tensor[ it.x ];
+ float p = pattern[ rsi ];
+ rsi += rsiInc;
+ if( rsi >= rsiEnd )
+ rsi -= rsiDec;
+ f += p;
+ f = gelu( f );
+ tensor[ it.x ] = f;
+ }
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/addRepeatGelu64.hlsl b/ComputeShaders/addRepeatGelu64.hlsl
new file mode 100644
index 0000000..3d9b2e8
--- /dev/null
+++ b/ComputeShaders/addRepeatGelu64.hlsl
@@ -0,0 +1,2 @@
+#define THREADS 64
+#include "addRepeatGelu.hlsl" \ No newline at end of file
diff --git a/ComputeShaders/addRepeatScale.hlsl b/ComputeShaders/addRepeatScale.hlsl
new file mode 100644
index 0000000..8c24088
--- /dev/null
+++ b/ComputeShaders/addRepeatScale.hlsl
@@ -0,0 +1,73 @@
+// Compute tensor = ( tensor + repeat( pattern, tensor ) ) * scale in 1 shot, without VRAM allocations
+// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
+RWBuffer<float> tensor: register( u0 );
+Buffer<float> pattern: register( t0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 tensorSize: packoffset( c0 );
+ uint4 tensorStrides: packoffset( c1 );
+ uint4 patternSize: packoffset( c2 );
+ uint4 patternStrides: packoffset( c3 );
+ float scalingMul : packoffset( c4.x );
+}
+
+#ifndef THREADS
+#define THREADS 512
+#endif
+
+#include "repeatUtils.hlsli"
+
+inline void computeSimple( uint idx, float add )
+{
+ float f = tensor[ idx ];
+ f += add;
+ f *= scalingMul;
+ tensor[ idx ] = f;
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
+ uint rsi = rowOffset( group % patternSize.yzw, patternStrides );
+
+ if( patternSize[ 0 ] == 1 )
+ {
+ // The pattern only has 1 column - broadcasting over the row
+ const float p = pattern[ rsi ];
+ ROW_LOOP( it )
+ computeSimple( it.x, p );
+ }
+ else if( patternSize[ 0 ] <= THREADS )
+ {
+ // pattern size doesn't exceed thread group size: load pattern value outside of the loop
+ const uint threadsPerGroup = THREADS - ( THREADS % patternSize[ 0 ] );
+ if( thread >= threadsPerGroup )
+ return;
+
+ const float p = pattern[ rsi + ( thread % patternSize[ 0 ] ) * patternStrides[ 0 ] ];
+ ROW_LOOP_EX( it, threadsPerGroup, tensorStrides )
+ computeSimple( it.x, p );
+ }
+ else
+ {
+ // Pattern rows are larger than the thread group, need to stream from both buffers
+ const uint rsiInc = THREADS * patternStrides[ 0 ];
+ const uint rsiDec = patternSize[ 0 ] * patternStrides[ 0 ];
+ const uint rsiEnd = rsi + rsiDec;
+ rsi += thread * patternStrides[ 0 ];
+
+ ROW_LOOP( it )
+ {
+ float f = tensor[ it.x ];
+ float p = pattern[ rsi ];
+ rsi += rsiInc;
+ if( rsi >= rsiEnd )
+ rsi -= rsiDec;
+ f += p;
+ f *= scalingMul;
+ tensor[ it.x ] = f;
+ }
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/addRows.hlsl b/ComputeShaders/addRows.hlsl
new file mode 100644
index 0000000..21e5e73
--- /dev/null
+++ b/ComputeShaders/addRows.hlsl
@@ -0,0 +1,46 @@
+#ifndef THREADS
+#define THREADS 256
+#endif
+
+// dec.tokenEmbedding tensor
+Buffer<float> tokenEmbedding: register( t0 );
+// dec.positionalEmbedding tensor
+Buffer<float> positionalEmbedding: register( t1 );
+// R32_UINT buffer with the input tokens
+Buffer<uint> embd: register( t2 );
+// Output tensor
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint rowLength: packoffset( c0.x );
+ uint pastTokensCount: packoffset( c0.y );
+ uint outputRowStride: packoffset( c0.z );
+ uint2 embStrides: packoffset( c1.x );
+ uint2 posStrides: packoffset( c1.z );
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint row = group.x;
+ const uint rowTok = embd[ row ];
+ const uint rowPos = row + pastTokensCount;
+
+ uint rdi = row * outputRowStride;
+ const uint rdiEnd = rdi + rowLength;
+ rdi += thread;
+
+ uint rsiTok = rowTok * embStrides.y;
+ rsiTok += thread * embStrides.x;
+
+ uint rsiPos = rowPos * posStrides.y;
+ rsiPos += thread * posStrides.x;
+
+ for( ; rdi < rdiEnd; rdi += THREADS, rsiTok += THREADS * embStrides.x, rsiPos += THREADS * posStrides.x )
+ {
+ float a = tokenEmbedding[ rsiTok ];
+ float b = positionalEmbedding[ rsiPos ];
+ result[ rdi ] = a + b;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/componentwiseBinaryOp.hlsli b/ComputeShaders/componentwiseBinaryOp.hlsli
new file mode 100644
index 0000000..0a523ca
--- /dev/null
+++ b/ComputeShaders/componentwiseBinaryOp.hlsli
@@ -0,0 +1,43 @@
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ uint4 src1_elements: packoffset( c2 );
+ uint4 src1_strides: packoffset( c3 );
+ uint4 result_elements: packoffset( c4 );
+ uint4 result_strides: packoffset( c5 );
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint j = group.x;
+ const uint nb1 = result_strides[ 1 ];
+ const uint nb01 = src0_strides[ 1 ];
+
+ const uint nb10 = src1_strides[ 0 ];
+ const uint nb11 = src1_strides[ 1 ];
+ const uint nc = src0_elements[ 0 ];
+
+ uint rsi0 = j * nb01;
+ uint rsi1 = j * nb11;
+ uint rdi = j * nb1;
+ const uint rsi0End = rsi0 + nc;
+
+ rsi0 += thread;
+ rsi1 += thread * nb10;
+ rdi += thread;
+
+ const uint rsi1Inc = 32 * nb10;
+ for( ; rsi0 < rsi0End; rsi0 += 32, rsi1 += rsi1Inc, rdi += 32 )
+ {
+ const float a = arg0[ rsi0 ];
+ const float b = arg1[ rsi1 ];
+ const float res = compute( a, b );
+ result[ rdi ] = res;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/convolutionMain.hlsl b/ComputeShaders/convolutionMain.hlsl
new file mode 100644
index 0000000..eee3ebb
--- /dev/null
+++ b/ComputeShaders/convolutionMain.hlsl
@@ -0,0 +1,76 @@
+// ggml_compute_forward_conv_1d_1s_f16_f32, GGML_TASK_COMPUTE implementation
+// Dispatch [ ne10, ne02, 1 ] thread groups
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ uint4 src1_elements: packoffset( c2 );
+ uint4 result_elements: packoffset( c4 );
+ uint4 result_strides: packoffset( c5 );
+}
+
+#include "groupReduce.hlsli"
+
+inline void computeDotProduct( uint s0, uint s1, uint len, uint thread, inout float acc )
+{
+ float curr = 0;
+ const uint completeVectors = len / 32;
+ uint i;
+ for( i = 0; i < completeVectors; i++, s0 += 32, s1 += 32 )
+ curr = mad( arg0[ s0 + thread ], arg1[ s1 + thread ], curr );
+
+ horizontalSumCompatNew( thread, curr );
+
+ if( 0 == thread )
+ {
+ const uint rem = len % 32;
+ if( 0 != rem )
+ {
+ double f64 = curr;
+ for( i = 0; i < rem; i++ )
+ {
+ precise float a = arg0[ s0 + i ];
+ precise float b = arg1[ s1 + i ];
+ precise float prod = a * b;
+ f64 += prod;
+ }
+ curr = (float)f64;
+ }
+ acc += curr;
+ }
+}
+
+#include "miscUtils.hlsli"
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint i1 = group.y;
+ const uint i0 = group.x;
+
+ const uint ne00 = src0_elements[ 0 ];
+ const uint nk = ne00;
+ const int nh = (int)( nk / 2 );
+
+ const uint ne01 = src0_elements[ 1 ];
+ const int ew0 = roundUp32( ne01 );
+
+ float res = 0;
+ for( int k = -nh; k <= nh; k++ )
+ {
+ const uint source0 = i1 * ew0 * ne00 + uint( nh + k ) * ew0;
+ const uint source1 = uint( i0 + nh + k ) * ew0;
+ computeDotProduct( source0, source1, ew0, thread, res );
+ }
+
+ if( 0 != thread )
+ return;
+
+ const uint nb1 = result_strides[ 1 ];
+ const uint rdi = i1 * nb1 + i0;
+ result[ rdi ] = res;
+} \ No newline at end of file
diff --git a/ComputeShaders/convolutionMain2.hlsl b/ComputeShaders/convolutionMain2.hlsl
new file mode 100644
index 0000000..73c9da1
--- /dev/null
+++ b/ComputeShaders/convolutionMain2.hlsl
@@ -0,0 +1,60 @@
+// ggml_compute_forward_conv_1d_2s_f16_f32, GGML_TASK_COMPUTE implementation
+// Dispatch [ ne10 / 2, ne02, 1 ] thread groups
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ uint4 src1_elements: packoffset( c2 );
+ uint4 result_elements: packoffset( c4 );
+ uint4 result_strides: packoffset( c5 );
+}
+
+#include "groupReduce.hlsli"
+
+inline void computeDotProduct( uint s0, uint s1, uint len, uint thread, inout float acc )
+{
+ float curr = 0;
+ const uint s0End = s0 + len;
+ s0 += thread;
+ s1 += thread;
+ for( ; s0 < s0End; s0 += 32, s1 += 32 )
+ curr = mad( arg0[ s0 ], arg1[ s1 ], curr );
+
+ horizontalSumCompatNew( thread, curr );
+ if( 0 == thread )
+ acc += curr;
+}
+
+#include "miscUtils.hlsli"
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint ne00 = src0_elements[ 0 ];
+ const uint ne01 = src0_elements[ 1 ];
+ const int ew0 = roundUp32( ne01 );
+
+ float res = 0;
+ uint s0 = group.y * ew0 * ne00;
+ uint s1 = group.x * 2 * ew0;
+ // The original implementation did following:
+ // int nh = (int)( nk / 2 );
+ // for( int k = -nh; k <= nh; k++ )
+ // What we doing instead:
+ // for( uint len = ( nk / 2 ) * 2 + 1, i = 0; i < len; i++ )
+ // len = ( nk / 2 ) * 2 + 1 is equal to ( nk | 1 )
+ const uint s0End = s0 + ( ne00 | 1u ) * ew0;
+ for( ; s0 < s0End; s0 += ew0, s1 += ew0 )
+ computeDotProduct( s0, s1, ew0, thread, res );
+
+ if( 0 != thread )
+ return;
+
+ const uint nb1 = result_strides[ 1 ];
+ const uint rdi = group.y * nb1 + group.x;
+ result[ rdi ] = res;
+} \ No newline at end of file
diff --git a/ComputeShaders/convolutionMain2Fixed.hlsl b/ComputeShaders/convolutionMain2Fixed.hlsl
new file mode 100644
index 0000000..d21dcbb
--- /dev/null
+++ b/ComputeShaders/convolutionMain2Fixed.hlsl
@@ -0,0 +1,119 @@
+// Optimized version of convolutionMain2.hlsl for kernel size = 3
+// Dispatch [ ( ( ne10 / 2 ) + TILE_Y - 1 ) / TILE_Y, ne02, 1 ] thread groups of this shader
+#ifndef TILE_Y
+static const uint TILE_Y = 8;
+#endif
+#ifndef THREADS
+static const uint THREADS = 64;
+#endif
+
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ uint4 src1_elements: packoffset( c2 );
+ uint4 result_elements: packoffset( c4 );
+ uint4 result_strides: packoffset( c5 );
+}
+
+// The accumulators we're after
+groupshared float resTemp[ TILE_Y ][ THREADS ];
+
+// Multiply + accumulate the specified row
+inline void accumulate( float a0, float a1, const uint resultRow, const uint thread )
+{
+ float acc = resTemp[ resultRow ][ thread ];
+ acc = mad( a0, a1, acc );
+ resTemp[ resultRow ][ thread ] = acc;
+}
+
+inline void convolutionTile( const uint s0, uint s1, const uint thread, const uint stride, const uint height )
+{
+ // Load 3 rows from arg0
+ const float3 a0 = float3( arg0[ s0 ], arg0[ s0 + stride ], arg0[ s0 + stride * 2 ] );
+
+ // Row 0
+ float a1 = arg1[ s1 ];
+ accumulate( a0[ 0 ], a1, 0, thread );
+ s1 += stride;
+
+ for( uint i = 1; i < height; i++ )
+ {
+ // Row i*2-1
+ // Even-indexed rows only contribute to a single output rows, after muiltiplied by kernel row #1
+ a1 = arg1[ s1 ];
+ accumulate( a0[ 1 ], a1, i - 1, thread );
+ s1 += stride;
+
+ // Row i*2, contributes to 2 output rows corresponding to kernel rows #0 and #2
+ a1 = arg1[ s1 ];
+ accumulate( a0[ 2 ], a1, i - 1, thread );
+ accumulate( a0[ 0 ], a1, i, thread );
+ s1 += stride;
+ }
+
+ // Row height*2 - 1
+ a1 = arg1[ s1 ];
+ accumulate( a0[ 1 ], a1, height - 1, thread );
+ s1 += stride;
+
+ // Row height*2
+ a1 = arg1[ s1 ];
+ accumulate( a0[ 2 ], a1, height - 1, thread );
+}
+
+#include "miscUtils.hlsli"
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint i;
+ // Zero out the accumulators
+ for( i = 0; i < TILE_Y; i++ )
+ resTemp[ i ][ thread ] = 0.0;
+ GroupMemoryBarrierWithGroupSync();
+
+ const uint i1 = group.y;
+ const uint i0 = group.x * TILE_Y * 2;
+ const uint height = min( TILE_Y, ( src1_elements.x / 2 ) - group.x * TILE_Y );
+
+ const uint ne00 = src0_elements[ 0 ];
+ const uint ne01 = src0_elements[ 1 ];
+ const int ew0 = roundUp32( ne01 );
+
+ uint s0 = i1 * ew0 * ne00;
+ const uint s0End = s0 + ew0;
+ uint s1 = i0 * ew0;
+ s0 += thread;
+ s1 += thread;
+ for( ; s0 < s0End; s0 += THREADS, s1 += THREADS )
+ convolutionTile( s0, s1, thread, ew0, height );
+
+ GroupMemoryBarrierWithGroupSync();
+
+ // Now we need horizontal sums of these shared accumulators, i.e. reduce [height][THREADS] shared array into [height][1] column
+ for( i = THREADS / 2; i > 0; i /= 2 )
+ {
+ if( thread < i )
+ {
+ for( uint j = 0; j < height; j++ )
+ {
+ float sum = resTemp[ j ][ thread ];
+ sum += resTemp[ j ][ thread + i ];
+ resTemp[ j ][ thread ] = sum;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ // And finally, store that column to global memory
+ if( thread >= height )
+ return;
+ const uint nb1 = result_strides[ 1 ];
+ const uint rdi = i1 * nb1 + group.x * TILE_Y + thread;
+ result[ rdi ] = resTemp[ thread ][ 0 ];
+} \ No newline at end of file
diff --git a/ComputeShaders/convolutionPrep1.hlsl b/ComputeShaders/convolutionPrep1.hlsl
new file mode 100644
index 0000000..528ff16
--- /dev/null
+++ b/ComputeShaders/convolutionPrep1.hlsl
@@ -0,0 +1,39 @@
+// ggml_compute_forward_conv_1d_1s_f16_f32, prepare kernel data (src0)
+// Dispatch [ ne01, ne02, 1 ] thread groups
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+}
+
+inline uint roundUp32( uint x )
+{
+ return ( x + 31 ) & ( ~31u );
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint nb01 = src0_strides[ 1 ];
+ const uint nb02 = src0_strides[ 2 ];
+
+ const uint ne00 = src0_elements[ 0 ];
+ const uint ne01 = src0_elements[ 1 ];
+ const uint ew0 = roundUp32( ne01 );
+
+ const uint i02 = group.y;
+ const uint i01 = group.x;
+
+ uint rsi = i02 * nb02 + i01 * nb01;
+ const uint rsiEnd = rsi + ne00;
+ uint rdi = i02 * ew0 * ne00 + i01;
+ rsi += thread;
+ rdi += thread * ew0;
+ const uint rdiInc = 32 * ew0;
+
+ for( ; rsi < rsiEnd; rsi += 32, rdi += rdiInc )
+ result[ rdi ] = arg0[ rsi ];
+} \ No newline at end of file
diff --git a/ComputeShaders/convolutionPrep2.hlsl b/ComputeShaders/convolutionPrep2.hlsl
new file mode 100644
index 0000000..a7e7172
--- /dev/null
+++ b/ComputeShaders/convolutionPrep2.hlsl
@@ -0,0 +1,43 @@
+// ggml_compute_forward_conv_1d_1s_f16_f32, prepare source data (src1)
+// Dispatch [ ne11, 1, 1 ] thread groups
+Buffer<float> arg1: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src1_elements: packoffset( c2 );
+ uint4 src1_strides: packoffset( c3 );
+}
+
+#include "miscUtils.hlsli"
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint i11 = group.x;
+
+ const uint ne00 = src0_elements[ 0 ];
+ const uint ne01 = src0_elements[ 1 ];
+ const uint ne10 = src1_elements[ 0 ];
+ const uint nb11 = src1_strides[ 1 ];
+
+ const uint nk = ne00;
+ const uint nh = nk / 2;
+ const int ew0 = roundUp32( ne01 );
+
+ uint rsi = i11 * nb11;
+ uint rdi = nh * ew0 + i11;
+ const uint rdiInc = ew0 * 32;
+ const uint rsiEnd = rsi + ne10;
+
+ rsi += thread;
+ rdi += thread * ew0;
+
+ for( ; rsi < rsiEnd; rsi += 32, rdi += rdiInc )
+ {
+ float f = arg1[ rsi ];
+ f = adjustFp16( f );
+ result[ rdi ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/copyConvert.hlsl b/ComputeShaders/copyConvert.hlsl
new file mode 100644
index 0000000..399147f
--- /dev/null
+++ b/ComputeShaders/copyConvert.hlsl
@@ -0,0 +1,50 @@
+// ggml_compute_forward_dup_f32 when we only need to convert types, but not reshape the tensor
+// Dispatch [ ne01, ne02, ne03 ] thread groups of this shader
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ bool downcastFp32 : packoffset( c2.x );
+}
+
+#include "miscUtils.hlsli"
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint nb00 = src0_strides[ 0 ];
+ const uint nb01 = src0_strides[ 1 ];
+ const uint nb02 = src0_strides[ 2 ];
+ const uint nb03 = src0_strides[ 3 ];
+
+ const uint ne00 = src0_elements[ 0 ];
+ const uint ne01 = src0_elements[ 1 ];
+ const uint ne02 = src0_elements[ 2 ];
+ const uint ne03 = src0_elements[ 3 ];
+
+ const uint i01 = group.x;
+ const uint i02 = group.y;
+ const uint i03 = group.z;
+
+ const uint rs = ne00 * nb00;
+ //const uint id = i01 + i02 * ne02 + i03 * ne01 * ne02;
+ const uint id = ( i03 * ne01 + i02 ) * ne02 + i01;
+
+ uint rsi = i01 * nb01 + i02 * nb02 + i03 * nb03;
+ uint rdi = id * rs;
+
+ const uint rsiEnd = rsi + rs;
+ rsi += thread;
+ rdi += thread;
+ for( ; rsi < rsiEnd; rsi += 32, rdi += 32 )
+ {
+ float f = arg0[ rsi ];
+ [branch]
+ if( downcastFp32 )
+ f = adjustFp16( f );
+ result[ rdi ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/copyTranspose.hlsl b/ComputeShaders/copyTranspose.hlsl
new file mode 100644
index 0000000..fc3e82f
--- /dev/null
+++ b/ComputeShaders/copyTranspose.hlsl
@@ -0,0 +1,56 @@
+// ggml_compute_forward_dup_f32 when we actually need to reshape the tensor
+// Dispatch [ ne01, ne02, ne03 ] thread groups of this shader
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ bool downcastFp32 : packoffset( c2.x );
+}
+
+#include "miscUtils.hlsli"
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint nb00 = src0_strides[ 0 ];
+ const uint nb01 = src0_strides[ 1 ];
+ const uint nb02 = src0_strides[ 2 ];
+ const uint nb03 = src0_strides[ 3 ];
+
+ const uint ne00 = src0_elements[ 0 ];
+ const uint ne01 = src0_elements[ 1 ];
+ const uint ne02 = src0_elements[ 2 ];
+ const uint ne03 = src0_elements[ 3 ];
+
+ const uint i01 = group.x;
+ const uint i02 = group.y;
+ const uint i03 = group.z;
+
+ // We need following integer: i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02
+ // We want to minimize count of integer multiplications
+ // Also, DXBC assembly features `imad` instruction which computes a*b+c for integers, the actual hardware hopefully has an equivalent
+ // i03*ne00*ne01*ne02 + i02*ne00*ne01 + i01*ne00
+ // ( i03*ne01*ne02 + i02*ne01 + i01 ) * ne00
+ // ( ( i03*ne02 + i02) * ne01 + i01 ) * ne00
+ uint rdi = ( ( i03 * ne02 + i02 ) * ne01 + i01 ) * ne00;
+
+ const uint rdiEnd = rdi + ne00;
+
+ uint rsi = i01 * nb01 + i02 * nb02 + i03 * nb03;
+ const uint rsiInc = 32 * nb00;
+
+ rdi += thread;
+ rsi += thread * nb00;
+
+ for( ; rdi < rdiEnd; rdi += 32, rsi += rsiInc )
+ {
+ float f = arg0[ rsi ];
+ [branch]
+ if( downcastFp32 )
+ f = adjustFp16( f );
+ result[ rdi ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/diagMaskInf.hlsl b/ComputeShaders/diagMaskInf.hlsl
new file mode 100644
index 0000000..18e3938
--- /dev/null
+++ b/ComputeShaders/diagMaskInf.hlsl
@@ -0,0 +1,30 @@
+// ggml_compute_forward_diag_mask_inf_f32
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 elements: packoffset( c0 );
+ uint4 strides: packoffset( c1 );
+ uint n_past : packoffset( c2.x );
+}
+
+static const float negativeInfinity = asfloat( 0xff800000 );
+
+[numthreads( 32, 1, 1 )]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint k = group.y;
+ const uint j = group.x;
+
+ // Start of the row
+ uint rdi = k * strides[ 2 ] + j * strides[ 1 ];
+ // End of the row
+ const uint rdiEnd = rdi + elements[ 0 ] * strides[ 0 ];
+ // First index to write in this thread
+ rdi += ( n_past + j + thread + 1 ) * strides[ 0 ];
+ // Index increment
+ const uint rdiInc = 32 * strides[ 0 ];
+
+ for( ; rdi < rdiEnd; rdi += rdiInc )
+ result[ rdi ] = negativeInfinity;
+} \ No newline at end of file
diff --git a/ComputeShaders/flashAttention.hlsl b/ComputeShaders/flashAttention.hlsl
new file mode 100644
index 0000000..65212d0
--- /dev/null
+++ b/ComputeShaders/flashAttention.hlsl
@@ -0,0 +1,170 @@
+// Ported from ggml_compute_forward_flash_attn_f16
+// Dispatch with [ neq1*neq2*neq3, 1, 1 ] thread groups
+
+#include "flashAttentionCommon.hlsli"
+Buffer<uint> lookupTable: register( t3 );
+#include "groupReduce.hlsli"
+
+inline void computeDotProduct( Buffer<float> buff0, Buffer<float> buff1, uint s0, uint s1, const uint len, const uint thread, inout float acc )
+{
+ acc = 0;
+ const uint s0End = s0 + len;
+ s0 += thread;
+ s1 += thread;
+ for( ; s0 < s0End; s0 += 32, s1 += 32 )
+ acc = mad( buff0[ s0 ], buff1[ s1 ], acc );
+
+ horizontalSum( thread, acc );
+}
+
+inline void computeDotProduct( Buffer<float> buff0, RWBuffer<float> buff1, uint s0, uint s1, const uint len, const uint thread, inout float acc )
+{
+ acc = 0;
+ const uint s0End = s0 + len;
+ s0 += thread;
+ s1 += thread;
+ for( ; s0 < s0End; s0 += 32, s1 += 32 )
+ acc = mad( buff0[ s0 ], buff1[ s1 ], acc );
+
+ horizontalSum( thread, acc );
+}
+
+void scaleTempVector( uint i, const uint length, const uint thread, const float multiplier, bool round )
+{
+ const uint end = i + length;
+ for( i += thread; i < end; i += 32 )
+ {
+ float f = temp[ i ];
+ f *= multiplier;
+ if( round )
+ f = roundToFp16( f );
+ temp[ i ] = f;
+ }
+}
+
+#include "miscUtils.hlsli"
+
+// Transform temp[ i ] = exp( temp[ i ] - tempMax ), and return the sum of these values
+inline float applySoftMax( uint i, const uint length, const uint thread, const float tempMax )
+{
+ // Transform the values, and compute per-thread sum
+ const uint end = i + length;
+ float sum = 0;
+ for( i += thread; i < end; i += 32 )
+ {
+ float f = temp[ i ];
+ [branch]
+ if( f != negativeInfinity )
+ {
+ f -= tempMax;
+ const uint index = fp16Rounded( f );
+ const uint res16 = lookupTable[ index ];
+ f = f16tof32( res16 );
+ }
+ else
+ f = 0;
+
+ temp[ i ] = f;
+ sum += f;
+ }
+
+ // Reduce per-thread sum to the global one, over all threads of the group
+ horizontalSumBroadcast( thread, sum );
+ return sum;
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint neq0 = q_elements[ 0 ];
+ const uint neq1 = q_elements[ 1 ];
+ const uint neq2 = q_elements[ 2 ];
+ const uint neq3 = q_elements[ 3 ];
+
+ const uint nek0 = k_elements[ 0 ];
+ const uint nek1 = k_elements[ 1 ];
+
+ const uint nev1 = v_elements[ 1 ];
+
+ const uint ne0 = res_elements[ 0 ];
+ const uint ne1 = res_elements[ 1 ];
+
+ const uint nbk0 = k_strides[ 0 ];
+ const uint nbk1 = k_strides[ 1 ];
+ const uint nbk2 = k_strides[ 2 ];
+ const uint nbk3 = k_strides[ 3 ];
+
+ const uint nbq0 = q_strides[ 0 ];
+ const uint nbq1 = q_strides[ 1 ];
+ const uint nbq2 = q_strides[ 2 ];
+ const uint nbq3 = q_strides[ 3 ];
+
+ const uint nbv0 = v_strides[ 0 ];
+ const uint nbv1 = v_strides[ 1 ];
+ const uint nbv2 = v_strides[ 2 ];
+ const uint nbv3 = v_strides[ 3 ];
+
+ const uint nb0 = res_strides[ 0 ];
+ const uint nb1 = res_strides[ 1 ];
+ const uint nb2 = res_strides[ 2 ];
+ const uint nb3 = res_strides[ 3 ];
+
+ const uint D = neq0;
+ const uint N = neq1;
+ const uint P = nek1 - N;
+ const uint M = nek1;
+
+ const uint ir = group.x;
+ const uint iq3 = ir / ( neq2 * neq1 );
+ const uint iq2 = ( ir - iq3 * neq2 * neq1 ) / neq1;
+ const uint iq1 = ( ir - iq3 * neq2 * neq1 - iq2 * neq1 );
+
+ const uint tempIndex = ir * tempBufferStride;
+
+ uint ic;
+ float tvm = negativeInfinity;
+ const uint s1 = iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3;
+ uint s0 = iq2 * nbk2 + iq3 * nbk3;
+ for( ic = 0; ic < nek1; ic++, s0 += nbk1 )
+ {
+ if( masked )
+ {
+ if( ic > P + iq1 )
+ {
+ if( 0 == thread )
+ temp[ tempIndex + ic ] = negativeInfinity;
+ continue;
+ }
+ }
+
+ float dp;
+ computeDotProduct( k, q, s0, s1, neq0, thread, dp );
+ if( 0 == thread )
+ {
+ dp *= scale;
+ temp[ tempIndex + ic ] = dp;
+ tvm = max( tvm, dp );
+ }
+ }
+
+ if( 0 == thread )
+ sharedAccumulators[ 0 ] = tvm;
+ GroupMemoryBarrierWithGroupSync();
+ tvm = sharedAccumulators[ 0 ];
+
+ // Softmax
+ {
+ float sum = applySoftMax( tempIndex, M, thread, tvm );
+ scaleTempVector( tempIndex, M, thread, 1.0 / sum, true );
+ }
+
+ s0 = iq2 * nbv2 + iq3 * nbv3;
+ uint rdi = iq1 * nb1 + iq2 * nb2 + iq3 * nb3;
+ for( ic = 0; ic < nev1; ic++, s0 += nbv1, rdi += nb0 )
+ {
+ float dp;
+ computeDotProduct( v, temp, s0, tempIndex, nek1, thread, dp );
+ if( 0 == thread )
+ result[ rdi ] = dp;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/flashAttentionCommon.hlsli b/ComputeShaders/flashAttentionCommon.hlsli
new file mode 100644
index 0000000..68ed30b
--- /dev/null
+++ b/ComputeShaders/flashAttentionCommon.hlsli
@@ -0,0 +1,67 @@
+// Ported from ggml_compute_forward_flash_attn_f16
+// Dispatch with [ neq1*neq2*neq3, 1, 1 ] thread groups
+Buffer<float> q: register( t0 );
+Buffer<float> k: register( t1 );
+Buffer<float> v: register( t2 );
+
+RWBuffer<float> result: register( u0 );
+// This temporary buffer should fit tempBufferStride * neq1 * neq2 * neq3 elements, FP32 precision
+RWBuffer<float> temp: register( u1 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 q_elements: packoffset( c0 );
+ uint4 q_strides: packoffset( c1 );
+ uint4 k_elements: packoffset( c2 );
+ uint4 k_strides: packoffset( c3 );
+ uint4 v_elements: packoffset( c4 );
+ uint4 v_strides: packoffset( c5 );
+ uint4 res_elements: packoffset( c6 );
+ uint4 res_strides: packoffset( c7 );
+
+ bool masked : packoffset( c8.x );
+ // 1.0 / sqrt( (double) D )
+ float scale : packoffset( c8.y );
+ // This number is required to be >= nek1, and ideally rounded up to either 32 (L2 line) or 128 (L1 line) bytes
+ uint tempBufferStride: packoffset( c8.z );
+}
+
+static const float negativeInfinity = asfloat( 0xff800000 );
+
+// Convert FP32 number to FP16 using rounding to nearest, then upcast back to FP32
+inline float roundToFp16( const float src )
+{
+ const uint trunc16 = f32tof16( src );
+ const float trunc32 = f16tof32( trunc16 );
+
+ const uint truncExp = ( trunc16 >> 10 ) & 0x1F;
+ if( truncExp != 0x1F )
+ {
+ const uint next16 = trunc16 + 1;
+ const float next32 = f16tof32( next16 );
+
+ const float errTrunc = abs( src - trunc32 );
+ const float errNext = abs( src - next32 );
+
+ if( errTrunc < errNext )
+ {
+ // Truncated was closer to the source
+ return trunc32;
+ }
+ else if( errTrunc > errNext )
+ {
+ // Truncated + 1 was closer to the source
+ return next32;
+ }
+ else
+ {
+ // Exactly half, doing banker's rounding to nearest even
+ return ( 0 == ( trunc16 & 1 ) ) ? trunc32 : next32;
+ }
+ }
+ else
+ {
+ // INF or NAN
+ return trunc32;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/flashAttentionCompat1.hlsl b/ComputeShaders/flashAttentionCompat1.hlsl
new file mode 100644
index 0000000..f1f2ddb
--- /dev/null
+++ b/ComputeShaders/flashAttentionCompat1.hlsl
@@ -0,0 +1,125 @@
+// Dispatch with [ neq1*neq2*neq3, 1, 1 ] thread groups
+#include "flashAttentionCommon.hlsli"
+#include "groupReduce.hlsli"
+
+inline void computeDotProduct( Buffer<float> buff0, Buffer<float> buff1, uint s0, uint s1, const uint len, const uint thread, inout float acc )
+{
+ acc = 0;
+ /*
+ const uint s0End = s0 + len;
+ s0 += thread;
+ s1 += thread;
+ for( ; s0 < s0End; s0 += 32, s1 += 32 )
+ acc = mad( buff0[ s0 ], buff1[ s1 ], acc );
+ horizontalSumCompatNew( thread, acc );
+ */
+
+ const uint completeVectors = len / 32;
+ uint i;
+ for( i = 0; i < completeVectors; i++, s0 += 32, s1 += 32 )
+ acc = mad( buff0[ s0 + thread ], buff1[ s1 + thread ], acc );
+
+ horizontalSumCompatNew( thread, acc );
+
+ if( 0 == thread )
+ {
+ const uint rem = len % 32;
+ for( i = 0; i < rem; i++ )
+ {
+ precise float a = buff0[ s0 + i ];
+ precise float b = buff1[ s1 + i ];
+ precise float prod = a * b;
+ acc += prod;
+ }
+ }
+}
+
+void scaleTempVector( uint i, const uint length, const uint thread, const float multiplier )
+{
+ const uint end = i + length;
+ for( i += thread; i < end; i += 32 )
+ {
+ float f = temp[ i ];
+ f *= multiplier;
+ temp[ i ] = f;
+ }
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint neq0 = q_elements[ 0 ];
+ const uint neq1 = q_elements[ 1 ];
+ const uint neq2 = q_elements[ 2 ];
+ const uint neq3 = q_elements[ 3 ];
+
+ const uint nek0 = k_elements[ 0 ];
+ const uint nek1 = k_elements[ 1 ];
+
+ const uint nev1 = v_elements[ 1 ];
+
+ const uint ne0 = res_elements[ 0 ];
+ const uint ne1 = res_elements[ 1 ];
+
+ const uint nbk0 = k_strides[ 0 ];
+ const uint nbk1 = k_strides[ 1 ];
+ const uint nbk2 = k_strides[ 2 ];
+ const uint nbk3 = k_strides[ 3 ];
+
+ const uint nbq0 = q_strides[ 0 ];
+ const uint nbq1 = q_strides[ 1 ];
+ const uint nbq2 = q_strides[ 2 ];
+ const uint nbq3 = q_strides[ 3 ];
+
+ const uint nbv0 = v_strides[ 0 ];
+ const uint nbv1 = v_strides[ 1 ];
+ const uint nbv2 = v_strides[ 2 ];
+ const uint nbv3 = v_strides[ 3 ];
+
+ const uint nb0 = res_strides[ 0 ];
+ const uint nb1 = res_strides[ 1 ];
+ const uint nb2 = res_strides[ 2 ];
+ const uint nb3 = res_strides[ 3 ];
+
+ const uint D = neq0;
+ const uint N = neq1;
+ const uint P = nek1 - N;
+ // const uint M = P + N;
+ const uint M = nek1;
+
+ const uint ir = group.x;
+ const uint iq3 = ir / ( neq2 * neq1 );
+ const uint iq2 = ( ir - iq3 * neq2 * neq1 ) / neq1;
+ const uint iq1 = ( ir - iq3 * neq2 * neq1 - iq2 * neq1 );
+
+ const uint tempIndex = ir * tempBufferStride;
+
+ uint ic;
+ for( ic = 0; ic < nek1; ic++ )
+ {
+ // k indices
+ const uint ik3 = iq3;
+ const uint ik2 = iq2;
+ const uint ik1 = ic;
+
+ // S indices
+ const uint i1 = ik1;
+
+ if( masked )
+ {
+ if( ic > P + iq1 )
+ {
+ if( 0 == thread )
+ temp[ tempIndex + ic ] = negativeInfinity;
+ continue;
+ }
+ }
+
+ const uint s0 = ik1 * nbk1 + ik2 * nbk2 + ik3 * nbk3;
+ const uint s1 = iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3;
+ float dp;
+ computeDotProduct( k, q, s0, s1, neq0, thread, dp );
+ if( 0 == thread )
+ temp[ tempIndex + ic ] = dp * scale;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/flashAttentionCompat2.hlsl b/ComputeShaders/flashAttentionCompat2.hlsl
new file mode 100644
index 0000000..73f5fce
--- /dev/null
+++ b/ComputeShaders/flashAttentionCompat2.hlsl
@@ -0,0 +1,114 @@
+// Dispatch with [ ( neq1*neq2*neq3 + 31 ) / 32, 1, 1 ] thread groups
+#include "flashAttentionCommon.hlsli"
+Buffer<uint> lookupTable: register( t3 );
+
+void scaleTempVector( uint i, const uint length, const float multiplier )
+{
+ const uint end = i + length;
+ for( ; i < end; i++ )
+ {
+ float f = temp[ i ];
+ f *= multiplier;
+ // Rounding in this shader causes numerical errors on my GeForce 1080 Ti GPU, driver 527.56
+ // f = roundToFp16( f );
+ temp[ i ] = f;
+ }
+}
+
+inline float computeTempVectorMax( uint i, const uint length )
+{
+ // Compute per-thread maximum
+ const uint end = i + length;
+ float ax = negativeInfinity;
+ for( ; i < end; i++ )
+ ax = max( ax, temp[ i ] );
+ return ax;
+}
+
+#include "miscUtils.hlsli"
+#include "fp64Utils.hlsli"
+
+// Transform temp[ i ] = exp( temp[ i ] - tempMax ), and return the sum of these values
+inline double applySoftMax( uint i, const uint length, const float tempMax )
+{
+ // Transform the values, and compute per-thread sum
+ const uint end = i + length;
+ double sum = 0;
+ for( ; i < end; i++ )
+ {
+ float f = temp[ i ];
+ [branch]
+ if( f != negativeInfinity )
+ {
+ f -= tempMax;
+ const uint index = fp16Rounded( f );
+ const uint res16 = lookupTable[ index ];
+ f = f16tof32( res16 );
+ sum += f;
+ }
+ else
+ f = 0;
+
+ temp[ i ] = f;
+ }
+ return sum;
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 dtid: SV_DispatchThreadID )
+{
+ const uint neq0 = q_elements[ 0 ];
+ const uint neq1 = q_elements[ 1 ];
+ const uint neq2 = q_elements[ 2 ];
+ const uint neq3 = q_elements[ 3 ];
+
+ const uint nek0 = k_elements[ 0 ];
+ const uint nek1 = k_elements[ 1 ];
+
+ const uint nev1 = v_elements[ 1 ];
+
+ const uint ne0 = res_elements[ 0 ];
+ const uint ne1 = res_elements[ 1 ];
+
+ const uint nbk0 = k_strides[ 0 ];
+ const uint nbk1 = k_strides[ 1 ];
+ const uint nbk2 = k_strides[ 2 ];
+ const uint nbk3 = k_strides[ 3 ];
+
+ const uint nbq0 = q_strides[ 0 ];
+ const uint nbq1 = q_strides[ 1 ];
+ const uint nbq2 = q_strides[ 2 ];
+ const uint nbq3 = q_strides[ 3 ];
+
+ const uint nbv0 = v_strides[ 0 ];
+ const uint nbv1 = v_strides[ 1 ];
+ const uint nbv2 = v_strides[ 2 ];
+ const uint nbv3 = v_strides[ 3 ];
+
+ const uint nb0 = res_strides[ 0 ];
+ const uint nb1 = res_strides[ 1 ];
+ const uint nb2 = res_strides[ 2 ];
+ const uint nb3 = res_strides[ 3 ];
+
+ const uint D = neq0;
+ const uint N = neq1;
+ const uint P = nek1 - N;
+ // const uint M = P + N;
+ const uint M = nek1;
+
+ const uint ir = dtid.x;
+ if( ir >= neq1 * neq2 * neq3 )
+ return;
+
+ const uint iq3 = ir / ( neq2 * neq1 );
+ const uint iq2 = ( ir - iq3 * neq2 * neq1 ) / neq1;
+ const uint iq1 = ( ir - iq3 * neq2 * neq1 - iq2 * neq1 );
+
+ const uint tempIndex = ir * tempBufferStride;
+
+ // Softmax
+ float tvm = computeTempVectorMax( tempIndex, M );
+ double sum = applySoftMax( tempIndex, M, tvm );
+
+ scaleTempVector( tempIndex, M, (float)( 1.0 / sum ) );
+} \ No newline at end of file
diff --git a/ComputeShaders/flashAttentionCompat3.hlsl b/ComputeShaders/flashAttentionCompat3.hlsl
new file mode 100644
index 0000000..e0a4061
--- /dev/null
+++ b/ComputeShaders/flashAttentionCompat3.hlsl
@@ -0,0 +1,118 @@
+// Dispatch with [ neq1*neq2*neq3, 1, 1 ] thread groups
+#include "flashAttentionCommon.hlsli"
+#include "groupReduce.hlsli"
+#include "miscUtils.hlsli"
+
+inline void roundTempVector( uint i, const uint len, const uint thread )
+{
+ const uint iEnd = i + len;
+ for( i += thread; i < iEnd; i += 32 )
+ {
+ float f = temp[ i ];
+ f = roundToFp16( f );
+ temp[ i ] = f;
+ }
+}
+
+inline void computeDotProduct( Buffer<float> buff0, RWBuffer<float> buff1, uint s0, uint s1, const uint len, const uint thread, inout float acc )
+{
+ acc = 0;
+/* const uint s0End = s0 + len;
+ s0 += thread;
+ s1 += thread;
+ for( ; s0 < s0End; s0 += 32, s1 += 32 )
+ acc = mad( buff0[ s0 ], buff1[ s1 ], acc );
+
+ horizontalSumCompatNew( thread, acc ); */
+ const uint completeVectors = len / 32;
+ uint i;
+ for( i = 0; i < completeVectors; i++, s0 += 32, s1 += 32 )
+ acc = mad( buff0[ s0 + thread ], buff1[ s1 + thread ], acc );
+
+ horizontalSumCompatNew( thread, acc );
+
+ if( 0 == thread )
+ {
+ const uint rem = len % 32;
+ if( 0 != rem )
+ {
+ double f64 = acc;
+ for( i = 0; i < rem; i++ )
+ {
+ precise float a = buff0[ s0 + i ];
+ precise float b = buff1[ s1 + i ];
+ precise float prod = a * b;
+ f64 += prod;
+ }
+ acc = (float)f64;
+ }
+ }
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint neq0 = q_elements[ 0 ];
+ const uint neq1 = q_elements[ 1 ];
+ const uint neq2 = q_elements[ 2 ];
+ const uint neq3 = q_elements[ 3 ];
+
+ const uint nek0 = k_elements[ 0 ];
+ const uint nek1 = k_elements[ 1 ];
+
+ const uint nev1 = v_elements[ 1 ];
+
+ const uint ne0 = res_elements[ 0 ];
+ const uint ne1 = res_elements[ 1 ];
+
+ const uint nbk0 = k_strides[ 0 ];
+ const uint nbk1 = k_strides[ 1 ];
+ const uint nbk2 = k_strides[ 2 ];
+ const uint nbk3 = k_strides[ 3 ];
+
+ const uint nbq0 = q_strides[ 0 ];
+ const uint nbq1 = q_strides[ 1 ];
+ const uint nbq2 = q_strides[ 2 ];
+ const uint nbq3 = q_strides[ 3 ];
+
+ const uint nbv0 = v_strides[ 0 ];
+ const uint nbv1 = v_strides[ 1 ];
+ const uint nbv2 = v_strides[ 2 ];
+ const uint nbv3 = v_strides[ 3 ];
+
+ const uint nb0 = res_strides[ 0 ];
+ const uint nb1 = res_strides[ 1 ];
+ const uint nb2 = res_strides[ 2 ];
+ const uint nb3 = res_strides[ 3 ];
+
+ const uint D = neq0;
+ const uint N = neq1;
+ const uint P = nek1 - N;
+ // const uint M = P + N;
+ const uint M = nek1;
+
+ const uint ir = group.x;
+ const uint iq3 = ir / ( neq2 * neq1 );
+ const uint iq2 = ( ir - iq3 * neq2 * neq1 ) / neq1;
+ const uint iq1 = ( ir - iq3 * neq2 * neq1 - iq2 * neq1 );
+
+ const uint tempIndex = ir * tempBufferStride;
+
+ roundTempVector( tempIndex, nek1, thread );
+ AllMemoryBarrierWithGroupSync();
+
+ uint rdi = iq1 * nb1 + iq2 * nb2 + iq3 * nb3;
+ for( uint ic = 0; ic < nev1; ic++, rdi += nb0 )
+ {
+ // dst indices
+ const uint i1 = iq1;
+ const uint i2 = iq2;
+ const uint i3 = iq3;
+
+ const uint s0 = ic * nbv1 + i2 * nbv2 + i3 * nbv3;
+ float dp;
+ computeDotProduct( v, temp, s0, tempIndex, nek1, thread, dp );
+ if( 0 == thread )
+ result[ rdi ] = dp;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/fmaRepeat1.hlsl b/ComputeShaders/fmaRepeat1.hlsl
new file mode 100644
index 0000000..3db3827
--- /dev/null
+++ b/ComputeShaders/fmaRepeat1.hlsl
@@ -0,0 +1,77 @@
+// Implementation of fmaRepeat() when both source arguments have same size and strides
+// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
+RWBuffer<float> tensor: register( u0 );
+Buffer<float> patternMul: register( t0 );
+Buffer<float> patternAdd: register( t1 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 tensorSize: packoffset( c0 );
+ uint4 tensorStrides: packoffset( c1 );
+ uint4 patternSize: packoffset( c2 );
+ uint4 patternStrides: packoffset( c3 );
+}
+
+#ifndef THREADS
+#define THREADS 512
+#endif
+
+#include "repeatUtils.hlsli"
+
+inline void computeSimple( uint idx, float mul, float add )
+{
+ precise float f = tensor[ idx ];
+ f *= mul;
+ f += add;
+ tensor[ idx ] = f;
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
+ uint rsi = rowOffset( group % patternSize.yzw, patternStrides );
+
+ if( patternSize[ 0 ] == 1 )
+ {
+ // The pattern only has 1 column - broadcasting over the row
+ const float pMul = patternMul[ rsi ];
+ const float pAdd = patternAdd[ rsi ];
+ ROW_LOOP( it )
+ computeSimple( it.x, pMul, pAdd );
+ }
+ else if( patternSize[ 0 ] <= THREADS )
+ {
+ // pattern size doesn't exceed thread group size: load pattern value outside of the loop
+ const uint threadsPerGroup = THREADS - ( THREADS % patternSize[ 0 ] );
+ if( thread >= threadsPerGroup )
+ return;
+
+ rsi += ( thread % patternSize[ 0 ] ) * patternStrides[ 0 ];
+ const float pMul = patternMul[ rsi ];
+ const float pAdd = patternAdd[ rsi ];
+ ROW_LOOP_EX( it, threadsPerGroup, tensorStrides )
+ computeSimple( it.x, pMul, pAdd );
+ }
+ else
+ {
+ // Pattern rows are larger than the thread group, need to stream from both buffers
+ const uint rsiInc = THREADS * patternStrides[ 0 ];
+ const uint rsiDec = patternSize[ 0 ] * patternStrides[ 0 ];
+ const uint rsiEnd = rsi + rsiDec;
+ rsi += thread * patternStrides[ 0 ];
+
+ ROW_LOOP( it )
+ {
+ precise float f = tensor[ it.x ];
+ float mul = patternMul[ rsi ];
+ float add = patternAdd[ rsi ];
+ rsi += rsiInc;
+ if( rsi >= rsiEnd )
+ rsi -= rsiDec;
+ f *= mul;
+ f += add;
+ tensor[ it.x ] = f;
+ }
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/fmaRepeat164.hlsl b/ComputeShaders/fmaRepeat164.hlsl
new file mode 100644
index 0000000..99813f6
--- /dev/null
+++ b/ComputeShaders/fmaRepeat164.hlsl
@@ -0,0 +1,2 @@
+#define THREADS 64
+#include "fmaRepeat1.hlsl" \ No newline at end of file
diff --git a/ComputeShaders/fmaRepeat2.hlsl b/ComputeShaders/fmaRepeat2.hlsl
new file mode 100644
index 0000000..edadf0a
--- /dev/null
+++ b/ComputeShaders/fmaRepeat2.hlsl
@@ -0,0 +1,45 @@
+// Implementation of fmaRepeat() when source arguments have different shape or VRAM layout
+// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
+RWBuffer<float> tensor: register( u0 );
+Buffer<float> patternMul: register( t0 );
+Buffer<float> patternAdd: register( t1 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 tensorSize: packoffset( c0 );
+ uint4 tensorStrides: packoffset( c1 );
+ uint4 patternSizeMul: packoffset( c2 );
+ uint4 patternStridesMul: packoffset( c3 );
+ uint4 patternSizeAdd: packoffset( c4 );
+ uint4 patternStridesAdd: packoffset( c5 );
+}
+
+#ifndef THREADS
+#define THREADS 32
+#endif
+
+#include "repeatUtils.hlsli"
+
+inline float loadPattern( Buffer<float> buffer, uint rowStart, uint i, uint4 size, uint4 stride )
+{
+ i %= size.x;
+ return buffer[ i * stride.x + rowStart ];
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
+ const uint rsiMul = rowOffset( group % patternSizeMul.yzw, patternStridesMul );
+ const uint rsiAdd = rowOffset( group % patternSizeAdd.yzw, patternStridesAdd );
+
+ for( uint i = thread; it.x < it.z; it.x += it.y, i++ )
+ {
+ precise float f = tensor[ it.x ];
+ float mul = loadPattern( patternMul, rsiMul, i, patternSizeMul, patternStridesMul );
+ float add = loadPattern( patternAdd, rsiAdd, i, patternSizeAdd, patternStridesAdd );
+ f *= mul;
+ f += add;
+ tensor[ it.x ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/fp64Utils.hlsli b/ComputeShaders/fp64Utils.hlsli
new file mode 100644
index 0000000..9782718
--- /dev/null
+++ b/ComputeShaders/fp64Utils.hlsli
@@ -0,0 +1,28 @@
+// TODO: compile another version of these shader, and use it on GPUs with ExtendedDoublesShaderInstructions flag, will become slightly faster
+// https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ns-d3d11-d3d11_feature_data_d3d11_options
+#ifndef ExtendedDoublesShaderInstructions
+#define ExtendedDoublesShaderInstructions 0
+#endif
+
+// Compute num/den in FP64 precision
+inline double div64( double num, double den )
+{
+#if ExtendedDoublesShaderInstructions
+ return num / den;
+#else
+ // https://en.wikipedia.org/wiki/Division_algorithm#Newton%E2%80%93Raphson_division
+ double x = 1.0f / (float)den;
+ x += x * ( 1.0 - den * x );
+ x += x * ( 1.0 - den * x );
+ return num * x;
+#endif
+}
+
+// Compute sqrt(x) in FP64 precision
+inline double sqrt64( double x )
+{
+ double root = sqrt( (float)x );
+ root = 0.5 * ( root + div64( x, root ) );
+ root = 0.5 * ( root + div64( x, root ) );
+ return root;
+} \ No newline at end of file
diff --git a/ComputeShaders/groupReduce.hlsli b/ComputeShaders/groupReduce.hlsli
new file mode 100644
index 0000000..1ffe1d8
--- /dev/null
+++ b/ComputeShaders/groupReduce.hlsli
@@ -0,0 +1,139 @@
+groupshared float sharedAccumulators[ 32 ];
+
+// Compute horisontal sum of the numbers. The result is only correct on the thread #0 of the group.
+void horizontalSum( const uint thread, inout float sum )
+{
+ sharedAccumulators[ thread ] = sum;
+ for( uint i = 16; i > 1; i /= 2 )
+ {
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < i )
+ {
+ sum += sharedAccumulators[ thread + i ];
+ sharedAccumulators[ thread ] = sum;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ if( 0 == thread )
+ sum += sharedAccumulators[ 1 ];
+}
+
+// Compute horisontal sum of the numbers, and broadcast to all threads of the group.
+void horizontalSumBroadcast( const uint thread, inout float sum )
+{
+ horizontalSum( thread, sum );
+ if( 0 == thread )
+ sharedAccumulators[ 0 ] = sum;
+ GroupMemoryBarrierWithGroupSync();
+ sum = sharedAccumulators[ 0 ];
+}
+
+// Compute horisontal sum of the numbers, in the order equal to the CPU-running dot product implementation.
+// The result is only correct on the thread #0 of the group.
+void horizontalSumCompat( const uint thread, inout float sum )
+{
+ sharedAccumulators[ thread ] = sum;
+ GroupMemoryBarrierWithGroupSync();
+
+ if( 0 == ( thread & 8 ) )
+ {
+ // This runs on threads [ 0 .. 7 ] and [ 16 .. 23 ]
+ // sum01 = _mm256_add_ps( sum0, sum1 );
+ // sum23 = _mm256_add_ps( sum2, sum3 );
+ sum += sharedAccumulators[ thread + 8 ];
+ sharedAccumulators[ thread ] = sum;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < 8 )
+ {
+ // This runs on threads [ 0 .. 7 ]
+ // sum0123 = _mm256_add_ps( sum01, sum23 );
+ sum += sharedAccumulators[ thread + 16 ];
+ sharedAccumulators[ thread ] = sum;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < 4 )
+ {
+ // const __m128 r4 = _mm_add_ps( _mm256_castps256_ps128( sum0123 ), _mm256_extractf128_ps( sum0123, 1 ) );
+ sum += sharedAccumulators[ thread + 4 ];
+ sharedAccumulators[ thread ] = sum;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < 2 )
+ {
+ // const __m128 r2 = _mm_add_ps( r4, _mm_movehl_ps( r4, r4 ) );
+ sum += sharedAccumulators[ thread + 2 ];
+ sharedAccumulators[ thread ] = sum;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ if( 0 == thread )
+ {
+ // const __m128 r1 = _mm_add_ss( r2, _mm_movehdup_ps( r2 ) );
+ sum += sharedAccumulators[ 1 ];
+ }
+}
+
+// Compute horisontal sum of the numbers, in yet another creative summation order recently implemented in the upstream
+void horizontalSumCompatNew( const uint thread, inout float sum )
+{
+ // GGML_F32x8_REDUCE
+ sharedAccumulators[ thread ] = sum;
+ GroupMemoryBarrierWithGroupSync();
+
+ if( 0 == ( thread & 8 ) )
+ {
+ // Runs on threads [ 0 .. 7 ] and [ 16 .. 23 ]
+ sum += sharedAccumulators[ thread | 8 ];
+ sharedAccumulators[ thread ] = sum;
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ if( thread < 8 )
+ {
+ // Runs on threads [ 0 .. 7 ]
+ sum += sharedAccumulators[ thread | 0x10 ];
+ sharedAccumulators[ thread ] = sum;
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ if( thread < 4 )
+ {
+ // Runs on threads [ 0 .. 3 ]
+ sum += sharedAccumulators[ thread | 4 ];
+ sharedAccumulators[ thread ] = sum;
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ if( thread < 4 && 0 == ( thread & 1 ) )
+ {
+ // Runs on threads [ 0, 2 ]
+ sum += sharedAccumulators[ thread | 1 ];
+ sharedAccumulators[ thread ] = sum;
+ }
+ GroupMemoryBarrierWithGroupSync();
+
+ if( 0 == thread )
+ sum += sharedAccumulators[ 2 ];
+}
+
+
+// Compute horizontal maximum of the numbers, and broadcast to all threads of the group.
+void horizontalMaxBroadcast( const uint thread, inout float ax )
+{
+ sharedAccumulators[ thread ] = ax;
+ for( uint i = 16; i > 0; i /= 2 )
+ {
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < i )
+ {
+ ax = max( ax, sharedAccumulators[ thread + i ] );
+ sharedAccumulators[ thread ] = ax;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ ax = sharedAccumulators[ 0 ];
+} \ No newline at end of file
diff --git a/ComputeShaders/groupReduce64.hlsli b/ComputeShaders/groupReduce64.hlsli
new file mode 100644
index 0000000..7094d03
--- /dev/null
+++ b/ComputeShaders/groupReduce64.hlsli
@@ -0,0 +1,46 @@
+groupshared float sharedAccumulators[ 64 ];
+
+// Compute horisontal sum of the numbers. The result is only correct on the thread #0 of the group.
+void horizontalSum( const uint thread, inout float sum )
+{
+ sharedAccumulators[ thread ] = sum;
+ for( uint i = 32; i > 1; i /= 2 )
+ {
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < i )
+ {
+ sum += sharedAccumulators[ thread + i ];
+ sharedAccumulators[ thread ] = sum;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ if( 0 == thread )
+ sum += sharedAccumulators[ 1 ];
+}
+
+// Compute horisontal sum of the numbers, and broadcast to all threads of the group.
+void horizontalSumBroadcast( const uint thread, inout float sum )
+{
+ horizontalSum( thread, sum );
+ if( 0 == thread )
+ sharedAccumulators[ 0 ] = sum;
+ GroupMemoryBarrierWithGroupSync();
+ sum = sharedAccumulators[ 0 ];
+}
+
+// Compute horizontal maximum of the numbers, and broadcast to all threads of the group.
+void horizontalMaxBroadcast( const uint thread, inout float ax )
+{
+ sharedAccumulators[ thread ] = ax;
+ for( uint i = 32; i > 0; i /= 2 )
+ {
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < i )
+ {
+ ax = max( ax, sharedAccumulators[ thread + i ] );
+ sharedAccumulators[ thread ] = ax;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ ax = sharedAccumulators[ 0 ];
+} \ No newline at end of file
diff --git a/ComputeShaders/matReshapePanels.hlsl b/ComputeShaders/matReshapePanels.hlsl
new file mode 100644
index 0000000..f26f246
--- /dev/null
+++ b/ComputeShaders/matReshapePanels.hlsl
@@ -0,0 +1,105 @@
+// This shader reshapes a matrix into the shape expected by mulMatTiledEx.hlsl and mulMatByRowTiledEx.hlsl compute shaders
+// It's called in runtime, also while loading models from disk.
+// So far, it's only used when running on AMD GPUs.
+#ifndef TILE_SIZE
+static const uint TILE_SIZE = 32;
+#endif
+
+// Input tensor
+Buffer<float> source: register( t0 );
+// Output tensor
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 arg0Size: packoffset( c0 );
+ uint4 arg0Strides: packoffset( c1 );
+ // Count of elements per panel
+ uint panelSize : packoffset( c2.y );
+ // Layer strides of the output matrix
+ uint2 layerStrides: packoffset( c2.z );
+}
+
+inline uint hadd( uint2 v2 ) { return v2.x + v2.y; }
+
+groupshared float tileBuffer[ TILE_SIZE ][ TILE_SIZE ];
+
+[ numthreads( TILE_SIZE, 1, 1 ) ]
+void main( const uint3 group: SV_GroupID, const uint thread : SV_GroupIndex )
+{
+ uint rdi = hadd( group.yz * layerStrides );
+ rdi += group.x * panelSize;
+ rdi += thread;
+
+ uint rsi = hadd( group.yz * arg0Strides.zw );
+ const uint baseY = group.x * TILE_SIZE;
+ const uint dispatchThread = baseY + thread;
+ // Reshaping into a column major horizontal panel, height = TILE_SIZE, width = width of the source matrix
+ uint width = arg0Size.x;
+ // Usually TILE_SIZE; can be less for the last panel on the matrix when we need to generate zeros instead of loading these numbers
+ const uint height = min( TILE_SIZE, arg0Size.y - baseY );
+
+ if( arg0Strides.x == 1 )
+ {
+ // The input matrix is row major, can improve performance with coalesced loads and group shared buffer.
+ rsi += baseY * arg0Strides.y;
+
+ const uint widthCompleteTiles = width / TILE_SIZE;
+
+ if( height < TILE_SIZE )
+ {
+ // This thread group was dispatched for the last panel of the matrix, it doesn't have enough rows
+ // Write zeros to the corresponding elements of the groupshared buffer
+ for( uint j = height; j < TILE_SIZE; j++ )
+ tileBuffer[ thread ][ j ] = 0.0;
+ }
+
+ for( uint i = 0; i < widthCompleteTiles; i++, rsi += TILE_SIZE )
+ {
+ // Load [ TILE_SIZE ] * [ TILE_SIZE ] block with fully coalesced loads, store to group shared buffer in transposed order
+ uint rsiTile = rsi + thread;
+ uint j;
+ for( j = 0; j < height; j++, rsiTile += arg0Strides.y )
+ {
+ // Each iteration of the loop loads a row of [ TILE_SIZE ] elements from the corresponding row of the source tensor
+ // Fully coalesced load
+ float f = source[ rsiTile ];
+ // Random store but the local memory's fast, this works rather well in practice
+ tileBuffer[ thread ][ j ] = f;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ // Copy from group shared buffer to output tensor
+ for( j = 0; j < TILE_SIZE; j++, rdi += TILE_SIZE )
+ {
+ // Fully coalesced loads and stores
+ float f = tileBuffer[ j ][ thread ];
+ result[ rdi ] = f;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ width %= TILE_SIZE;
+ if( 0 == width )
+ return;
+ rsi += thread * arg0Strides.y;
+ }
+ else
+ rsi += dispatchThread * arg0Strides.y;
+
+ for( uint i = 0; i < width; i++ )
+ {
+ float f;
+ [branch]
+ if( thread < height )
+ f = source[ rsi ];
+ else
+ f = 0.0;
+ rsi += arg0Strides.x;
+
+ result[ rdi ] = f;
+ rdi += TILE_SIZE;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/miscUtils.hlsli b/ComputeShaders/miscUtils.hlsli
new file mode 100644
index 0000000..b957a06
--- /dev/null
+++ b/ComputeShaders/miscUtils.hlsli
@@ -0,0 +1,84 @@
+// When GPUs are converting FP32 to FP16, they always truncate towards 0, documented there:
+// https://learn.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-data-conversion#conververting-from-a-higher-range-representation-to-a-lower-range-representation
+// Whisper code uses _mm_cvtps_ph( x, 0 ), the 0 stands for "Round to nearest even": https://www.felixcloutier.com/x86/vcvtps2ph
+// This function adjusts FP32 value making it so that truncation towards 0 results in the value equal to what CPU is doing
+inline float adjustFp16( const float src )
+{
+ const uint trunc16 = f32tof16( src );
+ const float trunc32 = f16tof32( trunc16 );
+
+ const uint truncExp = ( trunc16 >> 10 ) & 0x1F;
+ if( truncExp != 0x1F )
+ {
+ const uint next16 = trunc16 + 1;
+ const float next32 = f16tof32( next16 );
+
+ const float errTrunc = abs( src - trunc32 );
+ const float errNext = abs( src - next32 );
+
+ if( errTrunc < errNext )
+ {
+ // Truncated was closer to the source
+ return src;
+ }
+ else if( errTrunc > errNext )
+ {
+ // Truncated + 1 was closer to the source
+ return next32;
+ }
+ else
+ {
+ // Exactly half, doing banker's rounding to nearest even
+ return ( 0 == ( trunc16 & 1 ) ) ? src : next32;
+ }
+ }
+ else
+ {
+ // INF or NAN
+ return src;
+ }
+}
+
+// Convert FP32 number to FP16, using rounding to nearest
+inline uint fp16Rounded( const float src )
+{
+ const uint trunc16 = f32tof16( src );
+ const float trunc32 = f16tof32( trunc16 );
+
+ const uint truncExp = ( trunc16 >> 10 ) & 0x1F;
+ if( truncExp != 0x1F )
+ {
+ const uint next16 = trunc16 + 1;
+ const float next32 = f16tof32( next16 );
+
+ const float errTrunc = abs( src - trunc32 );
+ const float errNext = abs( src - next32 );
+
+ if( errTrunc < errNext )
+ {
+ // Truncated was closer to the source
+ return trunc16;
+ }
+ else if( errTrunc > errNext )
+ {
+ // Truncated + 1 was closer to the source
+ return next16;
+ }
+ else
+ {
+ // Exactly half, doing banker's rounding to nearest even
+ return ( 0 == ( trunc16 & 1 ) ) ? trunc16 : next16;
+ }
+ }
+ else
+ {
+ // INF or NAN
+ return trunc16;
+ }
+}
+
+// Round up the number to be a multiple of 32
+inline uint roundUp32( uint x )
+{
+ return ( x + 31 ) & ( ~31u );
+} \ No newline at end of file
diff --git a/ComputeShaders/mulMatByRow.hlsl b/ComputeShaders/mulMatByRow.hlsl
new file mode 100644
index 0000000..565cfdc
--- /dev/null
+++ b/ComputeShaders/mulMatByRow.hlsl
@@ -0,0 +1,49 @@
+// Matrix * row product, like [ E0, E1, E2, E3 ] * [ E0, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
+// Dispatch [ E1, E2, E3 ] groups of this shader
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 arg0Size: packoffset( c0 );
+ uint4 arg0Strides: packoffset( c1 );
+ uint4 arg1Size: packoffset( c2 );
+ uint4 arg1Strides: packoffset( c3 );
+ uint4 resultSize: packoffset( c4 );
+ uint4 resultStrides: packoffset( c5 );
+}
+
+#include "groupReduce.hlsli"
+
+inline uint hadd( uint3 vec )
+{
+ return vec.x + vec.y + vec.z;
+}
+inline uint hadd( uint2 vec )
+{
+ return vec.x + vec.y;
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint s0 = hadd( group * arg0Strides.yzw );
+ uint s1 = hadd( group.yz * arg1Strides.zw );
+ const uint s0End = s0 + arg0Size.x * arg0Strides.x;
+ const uint s0Inc = 32 * arg0Strides.x;
+ const uint s1Inc = 32 * arg1Strides.x;
+
+ s0 += thread * arg0Strides.x;
+ s1 += thread * arg1Strides.x;
+ float dp = 0;
+ for( ; s0 < s0End; s0 += s0Inc, s1 += s1Inc )
+ dp = mad( arg0[ s0 ], arg1[ s1 ], dp );
+
+ horizontalSum( thread, dp );
+ if( 0 != thread )
+ return;
+
+ const uint rdi = group.x + hadd( group.yz * resultStrides.zw );
+ result[ rdi ] = dp;
+} \ No newline at end of file
diff --git a/ComputeShaders/mulMatByRow64.hlsl b/ComputeShaders/mulMatByRow64.hlsl
new file mode 100644
index 0000000..db5f801
--- /dev/null
+++ b/ComputeShaders/mulMatByRow64.hlsl
@@ -0,0 +1,90 @@
+// Matrix * row product, like [ E0, E1, E2, E3 ] * [ E0, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
+// Dispatch [ E1, E2, E3 ] groups of this shader
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 arg0Size: packoffset( c0 );
+ uint4 arg0Strides: packoffset( c1 );
+ uint4 arg1Size: packoffset( c2 );
+ uint4 arg1Strides: packoffset( c3 );
+ uint4 resultSize: packoffset( c4 );
+ uint4 resultStrides: packoffset( c5 );
+}
+
+inline uint hadd( uint3 vec )
+{
+ return vec.x + vec.y + vec.z;
+}
+inline uint hadd( uint2 vec )
+{
+ return vec.x + vec.y;
+}
+
+// No idea why, but that particular configuration appears to be the fastest one on Ryzen 7 5700G iGPU
+// Not by much, though: when trying a few numbers I saw 1.30 - 1.42 seconds for this compute shader
+static const uint THREADS = 64;
+static const uint REDUCTION_BUFFER = 32;
+groupshared float sharedAccumulators[ REDUCTION_BUFFER ];
+
+// Compute horisontal sum of the numbers. The result is only correct on the thread #0 of the group.
+void horizontalSum( const uint thread, inout float sum )
+{
+ if( THREADS > REDUCTION_BUFFER )
+ {
+ for( uint t = REDUCTION_BUFFER; t < THREADS; t += REDUCTION_BUFFER )
+ {
+ // Threads [ t .. t + REDUCTION_BUFFER ] store into the buffer
+ if( thread >= t && thread < t + REDUCTION_BUFFER )
+ sharedAccumulators[ thread - t ] = sum;
+
+ GroupMemoryBarrierWithGroupSync();
+
+ // Threads [ 0 .. REDUCTION_BUFFER ] increment their local sum with the value loaded from the buffer
+ if( thread < REDUCTION_BUFFER )
+ sum += sharedAccumulators[ thread ];
+ }
+ }
+
+ if( thread < REDUCTION_BUFFER )
+ sharedAccumulators[ thread ] = sum;
+
+ for( uint i = REDUCTION_BUFFER / 2; i > 1; i /= 2 )
+ {
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < i )
+ {
+ sum += sharedAccumulators[ thread + i ];
+ sharedAccumulators[ thread ] = sum;
+ }
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ if( 0 == thread )
+ sum += sharedAccumulators[ 1 ];
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint s0 = hadd( group * arg0Strides.yzw );
+ uint s1 = hadd( group.yz * arg1Strides.zw );
+ const uint s0End = s0 + arg0Size.x * arg0Strides.x;
+ const uint s0Inc = THREADS * arg0Strides.x;
+ const uint s1Inc = THREADS * arg1Strides.x;
+
+ s0 += thread * arg0Strides.x;
+ s1 += thread * arg1Strides.x;
+ float dp = 0;
+ for( ; s0 < s0End; s0 += s0Inc, s1 += s1Inc )
+ dp = mad( arg0[ s0 ], arg1[ s1 ], dp );
+
+ horizontalSum( thread, dp );
+ if( 0 != thread )
+ return;
+
+ const uint rdi = group.x + hadd( group.yz * resultStrides.zw );
+ result[ rdi ] = dp;
+} \ No newline at end of file
diff --git a/ComputeShaders/mulMatByRowTiled.hlsl b/ComputeShaders/mulMatByRowTiled.hlsl
new file mode 100644
index 0000000..fea2fcb
--- /dev/null
+++ b/ComputeShaders/mulMatByRowTiled.hlsl
@@ -0,0 +1,120 @@
+// Matrix * row product, like [ E0, E1, E2, E3 ] * [ E0, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
+// Dispatch [ ( E1 + TILE_Y - 1 ) / TILE_Y, E2, E3 ] thread groups of this shader
+
+#ifndef TILE_Y
+static const uint TILE_Y = 64;
+#endif
+#ifndef THREADS_X
+static const uint THREADS_X = 32;
+#endif
+#ifndef THREADS_Y
+static const uint THREADS_Y = 16;
+#endif
+
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 arg0Size: packoffset( c0 );
+ uint4 arg0Strides: packoffset( c1 );
+ uint4 arg1Size: packoffset( c2 );
+ uint4 arg1Strides: packoffset( c3 );
+ uint4 resultSize: packoffset( c4 );
+ uint4 resultStrides: packoffset( c5 );
+}
+
+groupshared float resTemp[ TILE_Y ][ THREADS_X ];
+
+inline uint hadd( uint2 vec )
+{
+ return vec.x + vec.y;
+}
+
+[ numthreads( THREADS_X, THREADS_Y, 1 ) ]
+void main( uint3 group: SV_GroupID, uint3 thread : SV_GroupThreadID, uint threadFlattenned : SV_GroupIndex )
+{
+ uint i;
+ // Zero out the shared buffer
+ for( i = thread.y; i < TILE_Y; i += THREADS_Y )
+ resTemp[ i ][ thread.x ] = 0.0;
+ GroupMemoryBarrierWithGroupSync();
+
+ // Count of rows to compute in this thread group
+ const uint height = min( TILE_Y, arg0Size.y - group.x * TILE_Y );
+
+ uint s0 = hadd( group.yz * arg0Strides.zw ); //< arg0 layer for the thread group
+ s0 += group.x * TILE_Y * arg0Strides.y; //< arg0 first row for the thread group
+ s0 += hadd( arg0Strides.xy * thread.xy ); //< arg0 load index for the thread
+
+ uint s1 = hadd( group.yz * arg1Strides.zw ); //< arg1 layer for the thread group
+ s1 += thread.x * arg1Strides.x; //< arg1 load index for the thread
+
+ const uint completeTiles = arg0Size.x / THREADS_X;
+ // Each iteration of that loop loads THREADS_X elements from arg1,
+ // a block of [ THREADS_X, height ] elements from arg0,
+ // and accumulates these dot products in the shared buffer
+ for( uint t = 0; t < completeTiles; t++, s0 += THREADS_X * arg0Strides.x, s1 += THREADS_X * arg1Strides.x )
+ {
+ // Load THREADS_X elements from arg1
+ const float v1 = arg1[ s1 ];
+
+ uint rsi = s0;
+ for( i = thread.y; i < height; i += THREADS_Y, rsi += arg0Strides.y * THREADS_Y )
+ {
+ // Load THREADS_X elements from arg0
+ const float v0 = arg0[ rsi ];
+ // Multiply and accumulate in the shared buffer
+ float acc = resTemp[ i ][ thread.x ];
+ acc = mad( v0, v1, acc );
+ resTemp[ i ][ thread.x ] = acc;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ const uint rem = arg0Size.x % THREADS_X;
+ if( rem != 0 )
+ {
+ // E0 ain't a multiple of THREADS_X, we have a remainder
+ float v1;
+ if( thread.x < rem )
+ v1 = arg1[ s1 ];
+ else
+ v1 = 0.0;
+
+ for( i = thread.y; i < height; i += THREADS_Y, s0 += arg0Strides.y * THREADS_Y )
+ {
+ if( thread.x >= rem )
+ continue;
+ const float v0 = arg0[ s0 ];
+ float acc = resTemp[ i ][ thread.x ];
+ acc = mad( v0, v1, acc );
+ resTemp[ i ][ thread.x ] = acc;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ // Now we need horizontal sums of these shared accumulators, i.e. reduce [height][THREADS_X] shared array into [height][1] column
+ for( i = THREADS_X / 2; i > 0; i /= 2 )
+ {
+ if( thread.x < i )
+ {
+ for( uint j = thread.y; j < height; j += THREADS_Y )
+ {
+ float sum = resTemp[ j ][ thread.x ];
+ sum += resTemp[ j ][ thread.x + i ];
+ resTemp[ j ][ thread.x ] = sum;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ // And finally, store that column to global memory
+ if( threadFlattenned >= height )
+ return;
+
+ uint rdi = hadd( group.yz * resultStrides.zw ) + group.x * TILE_Y * resultStrides.x;
+ rdi += threadFlattenned * resultStrides.x;
+ result[ rdi ] = resTemp[ threadFlattenned ][ 0 ];
+} \ No newline at end of file
diff --git a/ComputeShaders/mulMatByRowTiled64.hlsl b/ComputeShaders/mulMatByRowTiled64.hlsl
new file mode 100644
index 0000000..6c63f2d
--- /dev/null
+++ b/ComputeShaders/mulMatByRowTiled64.hlsl
@@ -0,0 +1,4 @@
+#define THREADS_Y 32
+#define THREADS_X 32
+#define TILE_Y 128
+#include "mulMatByRowTiled.hlsl" \ No newline at end of file
diff --git a/ComputeShaders/mulMatByRowTiledEx.hlsl b/ComputeShaders/mulMatByRowTiledEx.hlsl
new file mode 100644
index 0000000..d377b8c
--- /dev/null
+++ b/ComputeShaders/mulMatByRowTiledEx.hlsl
@@ -0,0 +1,156 @@
+// matrix*row vector product, needs first argument reshaped into a sequence of horizontal column major panels
+#ifndef TILE_SIZE
+static const uint TILE_SIZE = 32;
+#endif
+#ifndef TILE_HEIGHT
+static const uint TILE_HEIGHT = 32;
+#endif
+#ifndef THREADS_Y
+static const uint THREADS_Y = 16;
+#endif
+
+// First tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
+Buffer<float> arg0: register( t0 );
+// Second tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
+Buffer<float> arg1: register( t1 );
+// FP32 output tensor, row major and continuous
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 arg0Size: packoffset( c0 );
+ uint arg0panel: packoffset( c1.y );
+ uint2 arg0LayerStrides: packoffset( c1.z );
+ // uint4 arg1Size: packoffset( c2 );
+ uint4 arg1Strides: packoffset( c3 );
+ uint4 resultSize: packoffset( c4 );
+ uint4 resultStrides: packoffset( c5 );
+}
+
+groupshared float tileOutput[ THREADS_Y ][ TILE_SIZE ];
+groupshared float tile0[ TILE_HEIGHT ][ TILE_SIZE ];
+groupshared float tile1[ TILE_HEIGHT ];
+
+void multiplyTiles( const uint3 thread )
+{
+ float r = 0.0;
+ for( uint i = thread.y; i < TILE_HEIGHT; i += THREADS_Y )
+ {
+ float a = tile0[ i ][ thread.x ];
+ float b = tile1[ i ];
+ r = mad( a, b, r );
+ }
+ tileOutput[ thread.y ][ thread.x ] += r;
+}
+
+void reduceOutput( const uint3 thread )
+{
+ float curr = 0.0;
+ [branch]
+ if( thread.y < THREADS_Y / 2 )
+ curr = tileOutput[ thread.y ][ thread.x ];
+
+ for( uint i = THREADS_Y / 2; i > 0; i /= 2 )
+ {
+ [branch]
+ if( thread.y < i )
+ {
+ curr += tileOutput[ thread.y + i ][ thread.x ];
+ tileOutput[ thread.y ][ thread.x ] = curr;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+}
+
+void storeTile( const uint threadFlat, const uint4 pos, const uint size )
+{
+ if( threadFlat >= size )
+ return;
+ const uint4 prod4 = pos * resultStrides;
+ const uint2 prod2 = prod4.xy + prod4.zw;
+ uint rdi = prod2.x + prod2.y;
+ result[ rdi + threadFlat ] = tileOutput[ 0 ][ threadFlat ];
+}
+
+[ numthreads( TILE_SIZE, THREADS_Y, 1 ) ]
+void main( const uint3 group: SV_GroupID, const uint3 thread : SV_GroupThreadID, uint threadFlat : SV_GroupIndex )
+{
+ uint i;
+ // Zero all 3 shared buffers
+ tileOutput[ thread.y ][ thread.x ] = 0.0;
+ for( i = thread.y; i < TILE_HEIGHT; i += THREADS_Y )
+ tile0[ i ][ thread.x ] = 0.0;
+ if( threadFlat < THREADS_Y )
+ tile1[ threadFlat ] = 0.0;
+
+ const uint2 layer = group.yz;
+ uint rsi0 = group.x * arg0panel + layer.x * arg0LayerStrides.x + layer.y * arg0LayerStrides.y;
+ uint rsi1 = layer.x * arg1Strides.z + layer.y * arg1Strides.w;
+
+ const uint threadOffset = thread.y * TILE_SIZE + thread.x;
+ rsi0 += threadOffset;
+ rsi1 += threadFlat * arg1Strides.x;
+
+ const uint completeTiles = arg0Size.x / TILE_HEIGHT;
+ for( i = 0; i < completeTiles; i++ )
+ {
+ // Load [ TILE_SIZE, TILE_HEIGHT ] block from the first source tensor into the groupshared buffer
+ for( uint j = thread.y; j < TILE_HEIGHT; j += THREADS_Y )
+ {
+ tile0[ j ][ thread.x ] = arg0[ rsi0 ];
+ rsi0 += THREADS_Y * TILE_SIZE;
+ }
+ // Load [ TILE_HEIGHT ] row from the second source into another groupshared buffer
+ [ branch ]
+ if( threadFlat < TILE_HEIGHT )
+ tile1[ threadFlat ] = arg1[ rsi1 ];
+ rsi1 += TILE_HEIGHT * arg1Strides.x;
+
+ GroupMemoryBarrierWithGroupSync();
+
+ multiplyTiles( thread );
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ const uint rem = arg0Size.x % TILE_HEIGHT;
+ if( rem != 0 )
+ {
+ for( uint j = thread.y; j < TILE_HEIGHT; j += THREADS_Y )
+ {
+ float a;
+ [branch]
+ if( j < rem )
+ {
+ a = arg0[ rsi0 ];
+ rsi0 += THREADS_Y * TILE_SIZE;
+ }
+ else
+ a = 0.0;
+ tile0[ j ][ thread.x ] = a;
+ }
+
+ if( threadFlat < TILE_HEIGHT )
+ {
+ float b;
+ [branch]
+ if( threadFlat < rem )
+ b = arg1[ rsi1 ];
+ else
+ b = 0.0;
+ tile1[ threadFlat ] = b;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ multiplyTiles( thread );
+
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ reduceOutput( thread );
+
+ const uint resultPos = group.x * TILE_SIZE;
+ const uint outputSize = min( TILE_SIZE, resultSize.x - resultPos );
+ storeTile( threadFlat, uint4( resultPos, 0, layer ), outputSize );
+} \ No newline at end of file
diff --git a/ComputeShaders/mulMatByScalar.hlsl b/ComputeShaders/mulMatByScalar.hlsl
new file mode 100644
index 0000000..82df5d4
--- /dev/null
+++ b/ComputeShaders/mulMatByScalar.hlsl
@@ -0,0 +1,41 @@
+// Matrix * scalar product, like [ 1, E1, E2, E3 ] * [ 1, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
+// Dispatch [ E2, E3, 1 ] thread groups of this shader
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 arg0Size: packoffset( c0 );
+ uint4 arg0Strides: packoffset( c1 );
+ uint4 arg1Size: packoffset( c2 );
+ uint4 arg1Strides: packoffset( c3 );
+ uint4 resultSize: packoffset( c4 );
+ uint4 resultStrides: packoffset( c5 );
+}
+
+inline uint hadd( uint2 vec )
+{
+ return vec.x + vec.y;
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const float scalarValue = arg1[ hadd( group.xy * arg1Strides.zw ) ];
+
+ uint s0 = hadd( group.xy * arg0Strides.zw );
+ const uint s0Inc = 32 * arg0Strides.y;
+ s0 += thread * arg0Strides.y;
+
+ uint rdi = hadd( group.xy * resultStrides.zw );
+ const uint rdiEnd = rdi + arg0Size.y;
+ rdi += thread;
+
+ for( ; rdi < rdiEnd; rdi += 32, s0 += s0Inc )
+ {
+ float f = arg0[ s0 ];
+ f *= scalarValue;
+ result[ rdi ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/mulMatDotMain.hlsl b/ComputeShaders/mulMatDotMain.hlsl
new file mode 100644
index 0000000..47c6d3e
--- /dev/null
+++ b/ComputeShaders/mulMatDotMain.hlsl
@@ -0,0 +1,95 @@
+// GGML_TASK_COMPUTE step for matrix*matrix product, where nb01 >= nb00;
+// Dispatch with [ ne11, ne01*ne02*ne03 ] thread groups
+// Each thread group computes a single dot product
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ uint4 src1_elements: packoffset( c2 );
+ uint4 result_elements: packoffset( c4 );
+ uint4 result_strides: packoffset( c5 );
+}
+
+inline uint product( uint3 vec )
+{
+ return vec.x * vec.y * vec.z;
+}
+
+inline uint product( uint4 vec )
+{
+ uint2 tmp = vec.xy * vec.zw;
+ return tmp.x * tmp.y;
+}
+
+inline float dotProductInner( uint i0, uint i1, uint length, uint thread )
+{
+ float res = 0;
+ for( uint i = thread; i < length; i += 32 )
+ res = mad( arg0[ i0 + i ], arg1[ i1 + i ], res );
+ return res;
+}
+
+#include "groupReduce.hlsli"
+
+[numthreads( 32, 1, 1 )]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint ne00 = src0_elements.x;
+ const uint ne01 = src0_elements.y;
+ const uint ne02 = src0_elements.z;
+ const uint ne03 = src0_elements.w;
+
+ const uint ne10 = src1_elements.x;
+ const uint ne11 = src1_elements.y;
+ const uint ne12 = src1_elements.z;
+ const uint ne13 = src1_elements.w;
+
+ const int nb00 = src0_strides.x;
+ const int nb01 = src0_strides.y;
+ const int nb02 = src0_strides.z;
+ const int nb03 = src0_strides.w;
+
+ // total rows in src0
+ // const int nr = ne01*ne02*ne03;
+ const uint nr = product( src0_elements.yzw );
+
+ const uint ir = group.y;
+
+ // src0 indices
+ const uint i03 = ir / ( ne02 * ne01 );
+ const uint i02 = ( ir - i03 * ne02 * ne01 ) / ne01;
+ const uint i01 = ( ir - i03 * ne02 * ne01 - i02 * ne01 );
+
+ const uint i13 = i03;
+ const uint i12 = i02;
+
+ const uint i0 = i01;
+ const uint i2 = i02;
+ const uint i3 = i03;
+
+ // src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+ // src1_col = wdata + ( i13 * ne12 * ne11 + i12 * ne11 + 0 ) * ne00;
+ const uint src0_row = i01 * nb01 + i02 * nb02 + i03 * nb03;
+ const uint src1_col = ( i13 * ne12 * ne11 + i12 * ne11 ) * ne00;
+
+ const uint ic = group.x;
+ float curr = dotProductInner( src0_row, src1_col + ic * ne00, ne00, thread );
+ horizontalSumCompatNew( thread, curr );
+
+ if( 0 != thread )
+ return;
+
+ const uint nb0 = result_strides.x;
+ const uint nb1 = result_strides.y;
+ const uint nb2 = result_strides.z;
+ const uint nb3 = result_strides.w;
+
+ const uint ne0 = result_elements.x;
+ // float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+ const uint dst_col = i0 * nb0 + i2 * nb2 + i3 * nb3;
+ result[ dst_col + ic * ne0 ] = curr;
+} \ No newline at end of file
diff --git a/ComputeShaders/mulMatDotReshape.hlsl b/ComputeShaders/mulMatDotReshape.hlsl
new file mode 100644
index 0000000..ffb6f83
--- /dev/null
+++ b/ComputeShaders/mulMatDotReshape.hlsl
@@ -0,0 +1,33 @@
+// GGML_TASK_INIT step for matrix*matrix product, where nb01 >= nb00;
+// Dispatch with [ ne11, ne12 ] groups
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+}
+
+#include "miscUtils.hlsli"
+
+// Each thread group of this shader copies a single rows of the matrix
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint i12 = group.y;
+ const uint i11 = group.x;
+ const uint ne10 = src0_elements.x;
+ const uint ne11 = src0_elements.y;
+ const uint nb12 = src0_strides.z;
+ const uint nb11 = src0_strides.y;
+
+ uint rdi = i11 * ne10 + i12 * ne10 * ne11;
+ const uint rdiEnd = rdi + ne10;
+ uint rsi = i12 * nb12 + i11 * nb11;
+ rdi += thread;
+ rsi += thread;
+
+ for( ; rdi < rdiEnd; rdi += 32, rsi += 32 )
+ result[ rdi ] = adjustFp16( arg0[ rsi ] );
+} \ No newline at end of file
diff --git a/ComputeShaders/mulMatMadMain.hlsl b/ComputeShaders/mulMatMadMain.hlsl
new file mode 100644
index 0000000..0bd5753
--- /dev/null
+++ b/ComputeShaders/mulMatMadMain.hlsl
@@ -0,0 +1,154 @@
+// GGML_TASK_COMPUTE step for matrix*matrix product, where nb01 < nb00
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> resultTensor: register( u0 );
+RWBuffer<float> tempBuffer: register( u1 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 aSize: packoffset( c0 );
+ uint4 aStride: packoffset( c1 );
+ uint4 bSize: packoffset( c2 );
+ uint4 bStride: packoffset( c3 );
+ uint4 resSize: packoffset( c4 );
+ bool resultFp16 : packoffset( c5.x );
+ uint ne: packoffset( c5.y );
+}
+
+#include "miscUtils.hlsli"
+
+// tempBuffer[ rdi .. ] = 0.0
+inline void writeTempZeros( uint rdi, const uint len, const uint thread )
+{
+ const uint rdiEnd = rdi + len;
+ for( rdi += thread; rdi < rdiEnd; rdi += 32 )
+ tempBuffer[ rdi ] = 0.0;
+}
+
+// tempBuffer[ rdi .. ] += mul * arg0[ rsi .. ]
+inline void vectorMad( uint rsi, uint rdi, const uint len, const float mul, const uint thread )
+{
+ const uint rsiEnd = rsi + len;
+ rsi += thread;
+ rdi += thread;
+ for( ; rsi < rsiEnd; rsi += 32, rdi += 32 )
+ {
+ float f = tempBuffer[ rdi ];
+ f = mad( mul, arg0[ rsi ], f );
+ [branch]
+ if( resultFp16 )
+ f = adjustFp16( f );
+ tempBuffer[ rdi ] = f;
+ }
+}
+
+// resultTensor[ rdi .. ] = tempBuffer[ rsi .. ]
+inline void copyRow( uint rsi, uint rdi, const uint len, const uint thread )
+{
+ const uint rsiEnd = rsi + len;
+ rsi += thread;
+ rdi += thread;
+ for( ; rsi < rsiEnd; rsi += 32, rdi += 32 )
+ {
+ float f = tempBuffer[ rsi ];
+ resultTensor[ rdi ] = f;
+ }
+}
+
+// resultTensor[ rdi .. ] += tempBuffer[ rsi .. ]
+inline void addRow( uint rsi, uint rdi, const uint len, const uint thread )
+{
+ const uint rsiEnd = rsi + len;
+ rsi += thread;
+ rdi += thread;
+ for( ; rsi < rsiEnd; rsi += 32, rdi += 32 )
+ {
+ float f = resultTensor[ rdi ];
+ f += tempBuffer[ rsi ];
+ resultTensor[ rdi ] = f;
+ }
+}
+
+[numthreads( 32, 1, 1 )]
+void main( const uint3 group: SV_GroupID, const uint thread : SV_GroupIndex )
+{
+ const uint i1 = group[ 0 ];
+ const uint i2 = group[ 1 ];
+ const uint i3 = group[ 2 ];
+
+ const uint ne00 = aSize[ 0 ];
+ const uint ne01 = aSize[ 1 ];
+ const uint ne02 = aSize[ 2 ];
+ const uint ne03 = aSize[ 3 ];
+
+ const uint ne10 = bSize[ 0 ];
+ const uint ne11 = bSize[ 1 ];
+ const uint ne12 = bSize[ 2 ];
+ const uint ne13 = bSize[ 3 ];
+
+ const uint ne0 = resSize[ 0 ];
+ const uint ne1 = resSize[ 1 ];
+ const uint ne2 = resSize[ 2 ];
+ const uint ne3 = resSize[ 3 ];
+
+ const uint nb00 = aStride[ 0 ];
+ const uint nb01 = aStride[ 1 ];
+ const uint nb02 = aStride[ 2 ];
+ const uint nb03 = aStride[ 3 ];
+
+ const uint nb10 = bStride[ 0 ];
+ const uint nb11 = bStride[ 1 ];
+ const uint nb12 = bStride[ 2 ];
+ const uint nb13 = bStride[ 3 ];
+
+ // dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+ const uint tempRowThread0 = i3 * ne2 * ne1 * ne0 + i2 * ne1 * ne0 + i1 * ne0;
+
+ // Faking 4 CPU threads trying to achieve bitwise compatibility with the CPU version
+ const uint nth = 4;
+
+ // GGML_TASK_COMPUTE
+ {
+ // src0_col = src0->data + ( i00 * nb00 + i02 * nb02 + i03 * nb03 );
+ const uint aBase = i2 * nb02 + i3 * nb03;
+ // src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ const uint bBase = i1 * nb11 + i2 * nb12 + i3 * nb13;
+
+ // total columns in src1
+ const uint nc = ne10;
+ // columns per thread
+ const uint dc = ( nc + nth - 1 ) / nth;
+
+ uint tempRow = tempRowThread0;
+ for( uint ith = 0; ith < nth; ith++, tempRow += ne )
+ {
+ writeTempZeros( tempRow, ne01, thread );
+
+ // column range for this thread
+ const uint ic0 = dc * ith;
+ const uint ic1 = min( ic0 + dc, nc );
+
+ for( uint ic = ic0; ic < ic1; ic++ )
+ {
+ const uint idxA = aBase + ic * aStride[ 0 ];
+ const uint idxB = bBase + ic * bStride[ 0 ];
+ const float bValue = arg1[ idxB ];
+ vectorMad( idxA, tempRow, ne01, bValue, thread );
+ }
+ }
+ }
+
+ // GGML_TASK_FINALIZE
+ {
+ const uint rdi = tempRowThread0;
+ // const uint rdi = i1 * resSize[ 0 ] + i2 * resSize[ 0 ] * resSize[ 1 ] + i3 * resSize[ 0 ] * resSize[ 1 ] * resSize[ 2 ];
+ // const uint rdi = ( ( i3 * resSize[ 2 ] + i2 ) * resSize[ 1 ] + i1 ) * resSize[ 0 ];
+
+ uint tempRow = tempRowThread0;
+ copyRow( tempRow, rdi, ne01, thread );
+
+ tempRow += ne;
+ for( uint ith = 1; ith < nth; ith++, tempRow += ne )
+ addRow( tempRow, rdi, ne01, thread );
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/mulMatTiled.hlsl b/ComputeShaders/mulMatTiled.hlsl
new file mode 100644
index 0000000..7e4d7d8
--- /dev/null
+++ b/ComputeShaders/mulMatTiled.hlsl
@@ -0,0 +1,236 @@
+// This compute shader implements matrix*matrix product, using tiling and other tricks to improve the performance
+#ifndef TILE_SIZE
+static const uint TILE_SIZE = 32;
+#endif
+
+#ifndef THREADS_Y
+// Performance measures on Ryzen 7 5700G iGPU, the time is just for this shader:
+// 1 (32 threads per group) - 17.1 seconds, 2 - 9.02424 seconds, 4 - 6.95762 seconds, 6 - 6.79011 seconds, 8 - 6.67279 seconds, 10 - 6.9456 seconds, 16 - 7.20502 seconds
+// On nVidia, 8 is also the fastest option.
+static const uint THREADS_Y = 8;
+#endif
+
+#ifndef STREAM_SECOND_MATRIX
+#define STREAM_SECOND_MATRIX 0
+#endif
+
+#ifndef LOAD_ORDER
+
+// Load with coalesced loads from global memory whenever possible, store into groupshared buffer with random stores
+// #define LOAD_ORDER bool2( ( 1 == arg0Strides[ 0 ] ) || ( 1 != arg0Strides[ 1 ] ), ( 1 == arg1Strides[ 0 ] ) || ( 1 != arg1Strides[ 1 ] ) )
+
+// Load with random loads from global memory, store into groupshared buffer with coalesced stores
+// On my AMD iGPU inside Ryzen 7 5700G, there's whopping 15% performance win with that tactics, from 6.67 to 5.66 seconds for this shader.
+// My nVidia GPU does about the same
+#define LOAD_ORDER bool2( false, true )
+
+#endif
+
+Buffer<float> arg0: register( t0 );
+Buffer<float> arg1: register( t1 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 arg0Size: packoffset( c0 );
+ uint4 arg0Strides: packoffset( c1 );
+ uint4 arg1Strides: packoffset( c3 );
+ uint4 resultSize: packoffset( c4 );
+ uint4 resultStrides: packoffset( c5 );
+}
+
+groupshared float tile0[ TILE_SIZE ][ TILE_SIZE ];
+#if !STREAM_SECOND_MATRIX
+groupshared float tile1[ TILE_SIZE ][ TILE_SIZE ];
+#endif
+groupshared float resTemp[ TILE_SIZE ][ TILE_SIZE ];
+
+#if STREAM_SECOND_MATRIX
+void multiplyTiles( uint rsi, const uint3 thread, const uint w, const uint h )
+{
+ for( uint i = thread.y; i < h; i += THREADS_Y, rsi += THREADS_Y * arg1Strides.y )
+ {
+ float r = 0;
+ uint rsiRow = rsi;
+ for( uint j = 0; j < w; j++, rsiRow += arg1Strides.x )
+ {
+ // One TILE_SIZE * 4 bytes coalesced load, broadcasted into THREADS_Y copies
+ const float s0 = tile0[ j ][ thread.x ];
+ // THREADS_Y broadcasts from global memory, each one is 4 bytes broadcasted into TILE_SIZE copies
+ const float s1 = arg1[ rsiRow ];
+ // Multiply and accumulate
+ r = mad( s0, s1, r );
+ }
+ // Accumulate into the output tile
+ // THREADS_Y * 128 bytes coalesced loads and stores
+ resTemp[ i ][ thread.x ] += r;
+ }
+}
+#else
+// Compute resTemp += tile0 * tile1, for TILE_SIZE^2 square matrices
+// The group size is TILE_SIZE*THREADS_Y threads in this shader
+void multiplyTiles( const uint3 thread )
+{
+ for( uint i = thread.y; i < TILE_SIZE; i += THREADS_Y )
+ {
+ float r = 0;
+ for( uint j = 0; j < TILE_SIZE; j++ )
+ {
+ // One TILE_SIZE * 4 bytes coalesced load, broadcasted into THREADS_Y copies
+ const float s0 = tile0[ j ][ thread.x ];
+ // THREADS_Y broadcasts, each one is 4 bytes broadcasted into TILE_SIZE copies
+ const float s1 = tile1[ i ][ j ];
+ // Multiply and accumulate
+ r = mad( s0, s1, r );
+ }
+ // Accumulate into the output tile
+ // THREADS_Y * 128 bytes coalesced loads and stores
+ resTemp[ i ][ thread.x ] += r;
+ }
+}
+#endif
+
+// Note we transposed these tiles while loading
+void loadTile0( uint rsi, const uint3 thread, const uint w, const uint h, const bool rowMajor )
+{
+ uint i;
+ if( rowMajor )
+ {
+ rsi += arg0Strides.y * thread.y;
+ for( i = thread.y; i < h; i += THREADS_Y, rsi += arg0Strides.y * THREADS_Y )
+ {
+ if( thread.x < w )
+ tile0[ thread.x ][ i ] = arg0[ rsi + thread.x * arg0Strides.x ];
+ else
+ tile0[ thread.x ][ i ] = 0.0;
+ }
+ }
+ else
+ {
+ // Unlike width which is smaller for the last tile, the height is always the same, and all these tiles are zero-initialized
+ if( thread.x >= h )
+ return;
+
+ rsi += arg0Strides.x * thread.y;
+ for( i = thread.y; i < w; i += THREADS_Y, rsi += arg0Strides.x * THREADS_Y )
+ tile0[ i ][ thread.x ] = arg0[ rsi + thread.x * arg0Strides.y ];
+
+ if( i >= TILE_SIZE )
+ return;
+ for( ; i < TILE_SIZE; i += THREADS_Y )
+ tile0[ i ][ thread.x ] = 0.0;
+ }
+}
+
+#if !STREAM_SECOND_MATRIX
+void loadTile1( uint rsi, const uint3 thread, const uint w, const uint h, const bool rowMajor )
+{
+ uint i;
+ if( rowMajor )
+ {
+ rsi += thread.y * arg1Strides.y;
+
+ for( i = thread.y; i < h; i += THREADS_Y, rsi += arg1Strides.y * THREADS_Y )
+ {
+ if( thread.x < w )
+ tile1[ i ][ thread.x ] = arg1[ rsi + thread.x * arg1Strides.x ];
+ else
+ tile1[ i ][ thread.x ] = 0.0;
+ }
+ }
+ else
+ {
+ // Unlike width which is smaller for the last tile, the height is always the same, and all these tiles are zero-initialized
+ if( thread.x >= h )
+ return;
+
+ rsi += thread.y * arg1Strides.x;
+ for( i = thread.y; i < w; i += THREADS_Y, rsi += arg1Strides.x * THREADS_Y )
+ tile1[ thread.x ][ i ] = arg1[ rsi + thread.x * arg0Strides.y ];
+ if( i >= TILE_SIZE )
+ return;
+ for( ; i < TILE_SIZE; i += THREADS_Y )
+ tile1[ thread.x ][ i ] = 0.0;
+ }
+}
+#endif
+
+void storeTile( const uint3 thread, const uint4 pos, const uint2 size )
+{
+ if( thread.x >= size.x )
+ return;
+ const uint4 prod4 = pos * resultStrides;
+ const uint2 prod2 = prod4.xy + prod4.zw;
+ uint rdi = prod2.x + prod2.y;
+ rdi += resultStrides.y * thread.y;
+ for( uint i = thread.y; i < size.y; i += THREADS_Y, rdi += resultStrides.y * THREADS_Y )
+ result[ rdi + thread.x * resultStrides.x ] = resTemp[ i ][ thread.x ];
+}
+
+[ numthreads( TILE_SIZE, THREADS_Y, 1 ) ]
+void main( uint3 group: SV_GroupID, uint3 thread : SV_GroupThreadID )
+{
+ // Zero out these shared buffers
+ for( uint i = 0; i < TILE_SIZE; i += THREADS_Y )
+ {
+ tile0[ i + thread.y ][ thread.x ] = 0.0;
+#if !STREAM_SECOND_MATRIX
+ tile1[ i + thread.y ][ thread.x ] = 0.0;
+#endif
+ resTemp[ i + thread.y ][ thread.x ] = 0.0;
+ }
+
+ const uint2 resultPos = group.xy * TILE_SIZE;
+ const uint2 layer = uint2( group.z % resultSize.z, group.z / resultSize.z );
+ uint rsi0 = resultPos.x * arg0Strides.y + layer.x * arg0Strides.z + layer.y * arg0Strides.w;
+ uint rsi1 = resultPos.y * arg1Strides.y + layer.x * arg1Strides.z + layer.y * arg1Strides.w;
+
+ const uint rsi0Inc = TILE_SIZE * arg0Strides.x;
+ const uint rsi1Inc = TILE_SIZE * arg1Strides.x;
+
+ const uint completeTiles = arg0Size.x / TILE_SIZE;
+ const uint rsi0AndAligned = rsi0 + rsi0Inc * completeTiles;
+ // Output tile size
+ // Normally TILE_SIZE^2, less than that for the tiles at the right and bottom edges of the output matrix
+ const uint2 outputSize = min( TILE_SIZE, resultSize.xy - resultPos );
+
+ const bool2 loadOrder = LOAD_ORDER;
+
+#if STREAM_SECOND_MATRIX
+ rsi1 += thread.y * arg1Strides.y;
+#endif
+ for( ; rsi0 < rsi0AndAligned; rsi0 += rsi0Inc, rsi1 += rsi1Inc )
+ {
+ loadTile0( rsi0, thread, TILE_SIZE, outputSize.x, loadOrder.x );
+#if STREAM_SECOND_MATRIX
+ GroupMemoryBarrierWithGroupSync();
+ multiplyTiles( rsi1, thread, TILE_SIZE, outputSize.y );
+#else
+ loadTile1( rsi1, thread, TILE_SIZE, outputSize.y, loadOrder.y );
+ GroupMemoryBarrierWithGroupSync();
+ multiplyTiles( thread );
+#endif
+ // Need one moar barrier here.
+ // Otherwise, some threads of the group are loading the next tile into tile0/tile1 groupshared buffers on the next iteration of the loop,
+ // while other threads of the same group are still computing the matrix product, and getting incorrect values from that groupshared buffer.
+ // The missing barrier only caused a bug on AMD, and only with "ggml-large.bin" model; no idea why that is.
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ const uint rem = arg0Size.x % TILE_SIZE;
+ if( 0 != rem )
+ {
+ loadTile0( rsi0, thread, rem, outputSize.x, loadOrder.x );
+#if STREAM_SECOND_MATRIX
+ GroupMemoryBarrierWithGroupSync();
+ multiplyTiles( rsi1, thread, rem, outputSize.y );
+#else
+ loadTile1( rsi1, thread, rem, outputSize.y, loadOrder.y );
+ GroupMemoryBarrierWithGroupSync();
+ multiplyTiles( thread );
+#endif
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ storeTile( thread, uint4( resultPos, layer ), outputSize );
+} \ No newline at end of file
diff --git a/ComputeShaders/mulMatTiled64.hlsl b/ComputeShaders/mulMatTiled64.hlsl
new file mode 100644
index 0000000..45d77b1
--- /dev/null
+++ b/ComputeShaders/mulMatTiled64.hlsl
@@ -0,0 +1,3 @@
+#define TILE_SIZE 64
+#define STREAM_SECOND_MATRIX 1
+#include "mulMatTiled.hlsl" \ No newline at end of file
diff --git a/ComputeShaders/mulMatTiledEx.hlsl b/ComputeShaders/mulMatTiledEx.hlsl
new file mode 100644
index 0000000..0f23da2
--- /dev/null
+++ b/ComputeShaders/mulMatTiledEx.hlsl
@@ -0,0 +1,194 @@
+// This compute shader implements yet another version of matrix*matrix product
+// For optimal VRAM access pattern, it requires both arguments to be reshaped into a sequence of horizontal column major panels.
+// The panel height is TILE_SIZE, and the last panel of the matrix needs to be padded with zeros; see matReshapePanels.hlsl shader for the reshaping.
+// So far, it's only used when running on AMD GPUs.
+#ifndef TILE_SIZE
+static const uint TILE_SIZE = 32;
+#endif
+#ifndef TILE_HEIGHT
+static const uint TILE_HEIGHT = 32;
+#endif
+#ifndef THREADS_Y
+static const uint THREADS_Y = 16;
+#endif
+
+#ifndef STREAM_SECOND_MATRIX
+#define STREAM_SECOND_MATRIX 1
+#endif
+
+// First tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
+Buffer<float> arg0: register( t0 );
+// Second tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
+Buffer<float> arg1: register( t1 );
+// FP32 output tensor, row major and continuous
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 arg0Size: packoffset( c0 );
+ uint arg0panel: packoffset( c1.y );
+ uint2 arg0LayerStrides: packoffset( c1.z );
+
+ // uint4 arg1Size: packoffset( c2 );
+ uint arg1panel: packoffset( c3.y );
+ uint2 arg1LayerStrides: packoffset( c3.z );
+
+ uint4 resultSize: packoffset( c4 );
+ uint4 resultStrides: packoffset( c5 );
+}
+
+// Accumulator for the output tile
+// That last `+1` helps a bit, I'm not sure why exactly but probebly because memory bank conflicts.
+groupshared float tileOutput[ TILE_SIZE ][ TILE_SIZE + 1 ];
+// A smaller tile loaded from the first source matrix
+groupshared float tile0[ TILE_HEIGHT ][ TILE_SIZE ];
+#if !STREAM_SECOND_MATRIX
+// A smaller tile loaded from the second source matrix
+groupshared float tile1[ TILE_HEIGHT ][ TILE_SIZE ];
+#endif
+
+#if STREAM_SECOND_MATRIX
+void multiplyTiles( const uint3 thread, uint rsi, const uint h )
+{
+ uint2 i = uint2( thread.y, rsi );
+ const uint2 iInc = uint2( THREADS_Y, THREADS_Y );
+ for( ; i.x < TILE_SIZE; i += iInc )
+ {
+ float r = 0.0;
+ uint2 j = uint2( 0, i.y );
+ const uint2 jInc = uint2( 1, TILE_SIZE );
+ for( ; j.x < h; j += jInc )
+ {
+ float a = tile0[ j.x ][ thread.x ];
+ float b = arg1[ j.y ];
+ r = mad( a, b, r );
+ }
+ tileOutput[ i.x ][ thread.x ] += r;
+ }
+}
+#else
+void multiplyTiles( const uint3 thread )
+{
+ for( uint row = thread.y; row < TILE_SIZE; row += THREADS_Y )
+ {
+ float r = 0.0;
+ for( uint j = 0; j < TILE_HEIGHT; j++ )
+ {
+ float a = tile0[ j ][ thread.x ];
+ float b = tile1[ j ][ row ];
+ r = mad( a, b, r );
+ }
+ tileOutput[ row ][ thread.x ] += r;
+ }
+}
+#endif
+
+void storeTile( const uint3 thread, const uint4 pos, const uint2 size )
+{
+ if( thread.x >= size.x )
+ return;
+ const uint4 prod4 = pos * resultStrides;
+ const uint2 prod2 = prod4.xy + prod4.zw;
+ uint rdi = prod2.x + prod2.y;
+ rdi += resultStrides.y * thread.y;
+ rdi += thread.x;
+ for( uint i = thread.y; i < size.y; i += THREADS_Y, rdi += resultStrides.y * THREADS_Y )
+ result[ rdi ] = tileOutput[ i ][ thread.x ];
+}
+
+[numthreads( TILE_SIZE, THREADS_Y, 1 )]
+void main( const uint3 group: SV_GroupID, const uint3 thread : SV_GroupThreadID )
+{
+ uint i;
+ // Zero all 3 shared buffers
+ for( i = thread.y; i < TILE_SIZE; i += THREADS_Y )
+ tileOutput[ i ][ thread.x ] = 0.0;
+ for( i = thread.y; i < TILE_HEIGHT; i += THREADS_Y )
+ {
+ tile0[ i ][ thread.x ] = 0.0;
+#if !STREAM_SECOND_MATRIX
+ tile1[ i ][ thread.x ] = 0.0;
+#endif
+ }
+
+ const uint2 layer = uint2( group.z % resultSize.z, group.z / resultSize.z );
+
+ uint rsi0 = group.x * arg0panel + layer.x * arg0LayerStrides.x + layer.y * arg0LayerStrides.y;
+ uint rsi1 = group.y * arg1panel + layer.x * arg1LayerStrides.x + layer.y * arg1LayerStrides.y;
+
+ const uint threadOffset = thread.y * TILE_SIZE + thread.x;
+ rsi0 += threadOffset;
+#if STREAM_SECOND_MATRIX
+ rsi1 += thread.y;
+#else
+ rsi1 += threadOffset;
+#endif
+
+ const uint completeTiles = arg0Size.x / TILE_HEIGHT;
+ for( i = 0; i < completeTiles; i++ )
+ {
+ // Load [ TILE_SIZE, TILE_HEIGHT ] block from both source tensors into these groupshared buffers
+ for( uint j = thread.y; j < TILE_HEIGHT; j += THREADS_Y )
+ {
+ tile0[ j ][ thread.x ] = arg0[ rsi0 ];
+ rsi0 += THREADS_Y * TILE_SIZE;
+#if !STREAM_SECOND_MATRIX
+ tile1[ j ][ thread.x ] = arg1[ rsi1 ];
+ rsi1 += THREADS_Y * TILE_SIZE;
+#endif
+ }
+
+ // Wait for all threads in the group to complete these loads
+ GroupMemoryBarrierWithGroupSync();
+
+#if STREAM_SECOND_MATRIX
+ multiplyTiles( thread, rsi1, TILE_HEIGHT );
+ rsi1 += TILE_HEIGHT * TILE_SIZE;
+#else
+ // Multiply + accumulate the elements collected in the groupshared buffers
+ multiplyTiles( thread );
+#endif
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ const uint rem = arg0Size.x % TILE_HEIGHT;
+ if( rem != 0 )
+ {
+ // Load [ TILE_SIZE, rem ] block from both source tensors, and zero out the padding elements
+ for( uint j = thread.y; j < TILE_HEIGHT; j += THREADS_Y )
+ {
+ [branch]
+ if( j < rem )
+ {
+ tile0[ j ][ thread.x ] = arg0[ rsi0 ];
+ rsi0 += THREADS_Y * TILE_SIZE;
+#if !STREAM_SECOND_MATRIX
+ tile1[ j ][ thread.x ] = arg1[ rsi1 ];
+ rsi1 += THREADS_Y * TILE_SIZE;
+#endif
+ }
+ else
+ {
+ tile0[ j ][ thread.x ] = 0.0;
+#if !STREAM_SECOND_MATRIX
+ tile1[ j ][ thread.x ] = 0.0;
+#endif
+ }
+ }
+
+ // Wait for all threads in the group to complete these loads
+ GroupMemoryBarrierWithGroupSync();
+
+ // Multiply + accumulate the elements collected in the groupshared buffers
+#if STREAM_SECOND_MATRIX
+ multiplyTiles( thread, rsi1, rem );
+#else
+ multiplyTiles( thread );
+#endif
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ const uint2 resultPos = group.xy * TILE_SIZE;
+ const uint2 outputSize = min( TILE_SIZE, resultSize.xy - resultPos );
+ storeTile( thread, uint4( resultPos, layer ), outputSize );
+} \ No newline at end of file
diff --git a/ComputeShaders/norm.hlsl b/ComputeShaders/norm.hlsl
new file mode 100644
index 0000000..eeb82b7
--- /dev/null
+++ b/ComputeShaders/norm.hlsl
@@ -0,0 +1,86 @@
+// Ported from ggml_compute_forward_norm_f32
+// Dispatch [ ne01, ne02, ne03 ] thread groups of this shader
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ uint4 result_strides: packoffset( c3 );
+}
+
+static const float eps = 1e-5f; // TODO: make this a parameter
+
+#include "groupReduce.hlsli"
+
+float computeVectorSum( uint i, const uint length, const uint thread )
+{
+ float res = 0.0;
+
+ const uint iEnd = i + length;
+ i += thread;
+ for( ; i < iEnd; i += 32 )
+ res += arg0[ i ];
+
+ horizontalSumBroadcast( thread, res );
+ return res;
+}
+
+float offsetAndComputeSumSquares( uint rsi, uint rdi, const float mean, const uint length, const uint thread )
+{
+ float sum2 = 0.0;
+
+ const uint rsiEnd = rsi + length;
+ rsi += thread;
+ rdi += thread;
+ for( ; rsi < rsiEnd; rsi += 32, rdi += 32 )
+ {
+ float v = arg0[ rsi ] - mean;
+ result[ rdi ] = v;
+ sum2 = mad( v, v, sum2 );
+ }
+
+ horizontalSumBroadcast( thread, sum2 );
+ return sum2;
+}
+
+void scaleVector( uint rdi, const float scale, const uint length, const uint thread )
+{
+ const uint rdiEnd = rdi + length;
+ for( rdi += thread; rdi < rdiEnd; rdi += 32 )
+ {
+ float f = result[ rdi ];
+ f *= scale;
+ result[ rdi ] = f;
+ }
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint i03 = group.z;
+ const uint i02 = group.y;
+ const uint i01 = group.x;
+
+ const uint nb01 = src0_strides[ 1 ];
+ const uint nb02 = src0_strides[ 2 ];
+ const uint nb03 = src0_strides[ 3 ];
+
+ const uint p = i01 * nb01 + i02 * nb02 + i03 * nb03;
+
+ const uint ne00 = src0_elements[ 0 ];
+
+ float mean = computeVectorSum( p, ne00, thread );
+ mean /= (float)(int)ne00;
+
+ const uint nb1 = result_strides[ 1 ];
+ const uint nb2 = result_strides[ 2 ];
+ const uint nb3 = result_strides[ 3 ];
+ const uint y = i01 * nb1 + i02 * nb2 + i03 * nb3;
+
+ float sum2 = offsetAndComputeSumSquares( p, y, mean, ne00, thread );
+ const float scale = 1.0 / sqrt( sum2 / (float)(int)ne00 + eps );
+
+ scaleVector( y, scale, ne00, thread );
+} \ No newline at end of file
diff --git a/ComputeShaders/normCompat.hlsl b/ComputeShaders/normCompat.hlsl
new file mode 100644
index 0000000..23e1228
--- /dev/null
+++ b/ComputeShaders/normCompat.hlsl
@@ -0,0 +1,82 @@
+// Ported from ggml_compute_forward_norm_f32
+// Dispatch [ ( ne01 + 31 ) / 32, ne02, ne03 ] thread groups of this shader
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ uint4 result_strides: packoffset( c3 );
+}
+
+static const double eps = 1e-5; // TODO: make this a parameter
+
+#include "groupReduce.hlsli"
+
+double computeVectorSum( uint i, const uint length )
+{
+ double res = 0.0;
+ const uint iEnd = i + length;
+ for( ; i < iEnd; i++ )
+ res += arg0[ i ];
+ return res;
+}
+
+double offsetAndComputeSumSquares( uint rsi, uint rdi, const double mean, const uint length )
+{
+ precise double sum2 = 0.0;
+ const uint rsiEnd = rsi + length;
+ for( ; rsi < rsiEnd; rsi++, rdi++ )
+ {
+ double v = arg0[ rsi ];
+ v -= mean;
+ result[ rdi ] = (float)v;
+ double prod = v * v;
+ sum2 += prod;
+ }
+ return sum2;
+}
+
+void scaleVector( uint rdi, const float scale, const uint length )
+{
+ const uint rdiEnd = rdi + length;
+ for( ; rdi < rdiEnd; rdi++ )
+ {
+ float f = result[ rdi ];
+ f *= scale;
+ result[ rdi ] = f;
+ }
+}
+
+#include "fp64Utils.hlsli"
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 dtid: SV_DispatchThreadID )
+{
+ const uint i03 = dtid.z;
+ const uint i02 = dtid.y;
+ const uint i01 = dtid.x;
+ if( i01 >= src0_elements[ 1 ] )
+ return;
+
+ const uint nb01 = src0_strides[ 1 ];
+ const uint nb02 = src0_strides[ 2 ];
+ const uint nb03 = src0_strides[ 3 ];
+
+ const uint p = i01 * nb01 + i02 * nb02 + i03 * nb03;
+ const uint ne00 = src0_elements[ 0 ];
+
+ double mean = computeVectorSum( p, ne00 );
+ mean = div64( mean, (double)(int)ne00 );
+
+ const uint nb1 = result_strides[ 1 ];
+ const uint nb2 = result_strides[ 2 ];
+ const uint nb3 = result_strides[ 3 ];
+ const uint y = i01 * nb1 + i02 * nb2 + i03 * nb3;
+
+ const double sum2 = offsetAndComputeSumSquares( p, y, mean, ne00 );
+ const float scale = (float)div64( 1.0, sqrt64( sum2 / (float)(int)ne00 + eps ) );
+
+ scaleVector( y, scale, ne00 );
+} \ No newline at end of file
diff --git a/ComputeShaders/normFixed.hlsl b/ComputeShaders/normFixed.hlsl
new file mode 100644
index 0000000..8f2267f
--- /dev/null
+++ b/ComputeShaders/normFixed.hlsl
@@ -0,0 +1,124 @@
+// Ported from ggml_compute_forward_norm_f32
+// Dispatch [ ne01, ne02, ne03 ] thread groups of this shader
+Buffer<float> arg0: register( t0 );
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ uint4 result_strides: packoffset( c3 );
+}
+
+static const float eps = 1e-5f; // TODO: make this a parameter
+
+// #include "groupReduce.hlsli"
+
+#ifndef THREADS
+static const uint THREADS = 32;
+#endif
+static const uint ROW_LENGTH = 1024;
+groupshared float rowBuffer[ ROW_LENGTH ];
+
+static const uint REDUCTION_BUFFER = 32;
+groupshared float sharedAccumulators[ REDUCTION_BUFFER ];
+
+// Compute horisontal sum of the numbers. The result is only correct on the thread #0 of the group.
+void horizontalSum( const uint thread, inout float sum )
+{
+ if( THREADS > REDUCTION_BUFFER )
+ {
+ for( uint t = REDUCTION_BUFFER; t < THREADS; t += REDUCTION_BUFFER )
+ {
+ // Threads [ t .. t + REDUCTION_BUFFER ] store into the buffer
+ if( thread >= t && thread < t + REDUCTION_BUFFER )
+ sharedAccumulators[ thread - t ] = sum;
+
+ GroupMemoryBarrierWithGroupSync();
+
+ // Threads [ 0 .. REDUCTION_BUFFER ] increment their local sum with the value loaded from the buffer
+ if( thread < REDUCTION_BUFFER )
+ sum += sharedAccumulators[ thread ];
+ }
+ }
+
+ if( thread < REDUCTION_BUFFER )
+ sharedAccumulators[ thread ] = sum;
+
+ for( uint i = REDUCTION_BUFFER / 2; i > 1; i /= 2 )
+ {
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < i )
+ {
+ sum += sharedAccumulators[ thread + i ];
+ sharedAccumulators[ thread ] = sum;
+ }
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+ if( 0 == thread )
+ sum += sharedAccumulators[ 1 ];
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint i03 = group.z;
+ const uint i02 = group.y;
+ const uint i01 = group.x;
+ const uint ne00 = ROW_LENGTH;
+
+ // First pass: copy the data to local buffer, and compute sum
+ {
+ const uint nb01 = src0_strides[ 1 ];
+ const uint nb02 = src0_strides[ 2 ];
+ const uint nb03 = src0_strides[ 3 ];
+ const uint p = i01 * nb01 + i02 * nb02 + i03 * nb03;
+
+ float sum = 0;
+ for( uint i = thread; i < ne00; i += THREADS )
+ {
+ float f = arg0[ p + i ];
+ rowBuffer[ i ] = f;
+ sum += f;
+ }
+ horizontalSum( thread, sum );
+ if( 0 == thread )
+ sharedAccumulators[ 0 ] = sum / (float)(int)ne00;
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ // Second pass: offset and compute sum of squares
+ {
+ const float mean = sharedAccumulators[ 0 ];
+ float sum2 = 0;
+ for( uint i = thread; i < ne00; i += THREADS )
+ {
+ float v = rowBuffer[ i ];
+ v -= mean;
+ rowBuffer[ i ] = v;
+ sum2 = mad( v, v, sum2 );
+ }
+ horizontalSum( thread, sum2 );
+ if( 0 == thread )
+ sharedAccumulators[ 0 ] = 1.0 / sqrt( sum2 / (float)(int)ne00 + eps );
+ GroupMemoryBarrierWithGroupSync();
+ }
+
+ // Final pass: apply the scale, and copy from group shared buffer to the destination
+ {
+ const float scale = sharedAccumulators[ 0 ];
+
+ const uint nb1 = result_strides[ 1 ];
+ const uint nb2 = result_strides[ 2 ];
+ const uint nb3 = result_strides[ 3 ];
+ const uint y = i01 * nb1 + i02 * nb2 + i03 * nb3;
+
+ for( uint i = thread; i < ne00; i += THREADS )
+ {
+ float v = rowBuffer[ i ];
+ v *= scale;
+ result[ y + i ] = v;
+ }
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/normFixed64.hlsl b/ComputeShaders/normFixed64.hlsl
new file mode 100644
index 0000000..14aab3d
--- /dev/null
+++ b/ComputeShaders/normFixed64.hlsl
@@ -0,0 +1,2 @@
+#define THREADS 64
+#include "normFixed.hlsl" \ No newline at end of file
diff --git a/ComputeShaders/repeatUtils.hlsli b/ComputeShaders/repeatUtils.hlsli
new file mode 100644
index 0000000..1181501
--- /dev/null
+++ b/ComputeShaders/repeatUtils.hlsli
@@ -0,0 +1,21 @@
+inline uint rowOffset( uint3 idx, uint4 strides )
+{
+ return idx[ 0 ] * strides[ 1 ] + idx[ 1 ] * strides[ 2 ] + idx[ 2 ] * strides[ 3 ];
+}
+
+// Initial iterator state for a row of the output tensor
+// x = current index, y = index increment, z = end of the index
+inline uint3 tensorIteratorState( uint3 group, uint thread, uint4 size, uint4 stride )
+{
+ uint3 res;
+ res.x = rowOffset( group, stride );
+ res.y = THREADS * stride[ 0 ];
+ res.z = res.x + size[ 0 ] * stride[ 0 ];
+ res.x += thread * stride[ 0 ];
+ return res;
+}
+
+// Handle a complete row of output tensor, using the iterator made by tensorIteratorState() function
+#define ROW_LOOP( ts ) for( ; ts.x < ts.z; ts.x += ts.y )
+// Same as above, using different row length
+#define ROW_LOOP_EX( ts, len, stride ) for( ; ts.x < ts.z; ts.x += len * stride[ 0 ] ) \ No newline at end of file
diff --git a/ComputeShaders/scaleInPlace.hlsl b/ComputeShaders/scaleInPlace.hlsl
new file mode 100644
index 0000000..d0320b4
--- /dev/null
+++ b/ComputeShaders/scaleInPlace.hlsl
@@ -0,0 +1,23 @@
+RWBuffer<float> buffer: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 src0_elements: packoffset( c0 );
+ uint4 src0_strides: packoffset( c1 );
+ float multiplier: packoffset( c2.x );
+}
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint nc0 = src0_elements[ 0 ];
+ uint i = group.x * src0_strides[ 1 ];
+ const uint iEnd = i + nc0;
+ const float mul = multiplier;
+ for( i += thread; i < iEnd; i += 32 )
+ {
+ float f = buffer[ i ];
+ f *= mul;
+ buffer[ i ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/softMax.hlsl b/ComputeShaders/softMax.hlsl
new file mode 100644
index 0000000..6ebd0f2
--- /dev/null
+++ b/ComputeShaders/softMax.hlsl
@@ -0,0 +1,71 @@
+// Dispatch [ nr, 1, 1 ] thread groups of this shader
+RWBuffer<float> result: register( u0 );
+
+// table_exp_f16
+Buffer<uint> lookupTable: register( t0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 elements: packoffset( c0 );
+ uint4 strides: packoffset( c1 );
+ uint nr: packoffset( c2.x );
+ float inputScale: packoffset( c2.y );
+}
+
+#include "miscUtils.hlsli"
+#include "groupReduce.hlsli"
+
+static const float negativeInfinity = asfloat( 0xff800000 );
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint p = group.x * strides[ 1 ];
+ const uint nc = elements[ 0 ];
+ const uint pEnd = p + nc;
+ uint i;
+
+ float m = negativeInfinity;
+ for( i = p + thread; i < pEnd; i += 32 )
+ m = max( m, result[ i ] );
+ horizontalMaxBroadcast( thread, m );
+
+ float sum = 0;
+ for( i = p + thread; i < pEnd; i += 32 )
+ {
+ float f = result[ i ];
+
+ [branch]
+ if( f != negativeInfinity )
+ {
+ f = ( f - m ) * inputScale;
+#if 1
+ // Similar to Radeon Graphics, computing the exponent on nVidia 1080Ti is also slightly faster than loading from the lookup table
+ f = exp( f );
+#else
+ uint s = fp16Rounded( f );
+ s = lookupTable[ s ];
+ f = f16tof32( s );
+#endif
+ sum += f;
+ }
+ else
+ f = 0;
+
+ result[ i ] = f;
+ }
+
+ horizontalSum( thread, sum );
+ if( 0 == thread )
+ sharedAccumulators[ 0 ] = 1.0 / sum;
+ GroupMemoryBarrierWithGroupSync();
+ const float scale = sharedAccumulators[ 0 ];
+
+ // ggml_vec_scale_f32
+ for( i = p + thread; i < pEnd; i += 32 )
+ {
+ float f = result[ i ];
+ f *= scale;
+ result[ i ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/softMax64.hlsl b/ComputeShaders/softMax64.hlsl
new file mode 100644
index 0000000..7ecd2ef
--- /dev/null
+++ b/ComputeShaders/softMax64.hlsl
@@ -0,0 +1,71 @@
+// Dispatch [ nr, 1, 1 ] thread groups of this shader
+RWBuffer<float> result: register( u0 );
+
+// table_exp_f16
+Buffer<uint> lookupTable: register( t0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 elements: packoffset( c0 );
+ uint4 strides: packoffset( c1 );
+ uint nr: packoffset( c2.x );
+ float inputScale: packoffset( c2.y );
+}
+
+#include "miscUtils.hlsli"
+#include "groupReduce64.hlsli"
+
+static const float negativeInfinity = asfloat( 0xff800000 );
+
+[ numthreads( 64, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint p = group.x * strides[ 1 ];
+ const uint nc = elements[ 0 ];
+ const uint pEnd = p + nc;
+ uint i;
+
+ float m = negativeInfinity;
+ for( i = p + thread; i < pEnd; i += 64 )
+ m = max( m, result[ i ] );
+ horizontalMaxBroadcast( thread, m );
+
+ float sum = 0;
+ for( i = p + thread; i < pEnd; i += 64 )
+ {
+ float f = result[ i ];
+
+ [branch]
+ if( f != negativeInfinity )
+ {
+ f = ( f - m ) * inputScale;
+#if 1
+ // At least on Radeon Graphics GPU inside Ryzen 7 5700G, computing exponent instead of loading from the buffer improves the performance
+ f = exp( f );
+#else
+ uint s = fp16Rounded( f );
+ s = lookupTable[ s ];
+ f = f16tof32( s );
+#endif
+ sum += f;
+ }
+ else
+ f = 0;
+
+ result[ i ] = f;
+ }
+
+ horizontalSum( thread, sum );
+ if( 0 == thread )
+ sharedAccumulators[ 0 ] = 1.0 / sum;
+ GroupMemoryBarrierWithGroupSync();
+ const float scale = sharedAccumulators[ 0 ];
+
+ // ggml_vec_scale_f32
+ for( i = p + thread; i < pEnd; i += 64 )
+ {
+ float f = result[ i ];
+ f *= scale;
+ result[ i ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/softMaxCompat.hlsl b/ComputeShaders/softMaxCompat.hlsl
new file mode 100644
index 0000000..2215ebd
--- /dev/null
+++ b/ComputeShaders/softMaxCompat.hlsl
@@ -0,0 +1,62 @@
+// ggml_compute_forward_soft_max_f32
+// Dispatch [ ( nr + 31 ) / 32, 1, 1 ] thread groups of this shader
+RWBuffer<float> result: register( u0 );
+
+// table_exp_f16
+Buffer<uint> lookupTable: register( t0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 elements: packoffset( c0 );
+ uint4 strides: packoffset( c1 );
+ uint nr: packoffset( c2.x );
+}
+
+#include "miscUtils.hlsli"
+#include "fp64Utils.hlsli"
+
+static const float negativeInfinity = asfloat( 0xff800000 );
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 dtid: SV_DispatchThreadID )
+{
+ if( dtid.x >= nr )
+ return;
+
+ const uint p = dtid.x * strides[ 1 ];
+ const uint nc = elements[ 0 ];
+ const uint pEnd = p + nc;
+ uint i;
+
+ float m = negativeInfinity;
+ for( i = p; i < pEnd; i++ )
+ m = max( m, result[ i ] );
+
+ double sum = 0;
+ for( i = p; i < pEnd; i++ )
+ {
+ float f = result[ i ];
+
+ [branch]
+ if( f != negativeInfinity )
+ {
+ uint s = fp16Rounded( f - m );
+ s = lookupTable[ s ];
+ f = f16tof32( s );
+ sum += f;
+ }
+ else
+ f = 0;
+
+ result[ i ] = f;
+ }
+
+ const float scale = (float)div64( 1.0, sum );
+ // ggml_vec_scale_f32
+ for( i = p; i < pEnd; i++ )
+ {
+ float f = result[ i ];
+ f *= scale;
+ result[ i ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/softMaxFixed.hlsl b/ComputeShaders/softMaxFixed.hlsl
new file mode 100644
index 0000000..7b4add2
--- /dev/null
+++ b/ComputeShaders/softMaxFixed.hlsl
@@ -0,0 +1,79 @@
+// Special softMax shader for matrices with rows of 1500 elements.
+// Uses group shared buffer of that length to save global memory bandwidth, more than 2x faster than the original.
+// Dispatch [ nr, 1, 1 ] thread groups of this shader
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 elements: packoffset( c0 );
+ uint4 strides: packoffset( c1 );
+ uint nr: packoffset( c2.x );
+ float inputScale: packoffset( c2.y );
+}
+
+#include "miscUtils.hlsli"
+#include "groupReduce64.hlsli"
+
+static const uint THREADS = 64;
+static const uint ROW_LENGTH = 1500;
+groupshared float rowBuffer[ ROW_LENGTH ];
+
+static const float negativeInfinity = asfloat( 0xff800000 );
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint p = group.x * strides[ 1 ];
+ const uint nc = ROW_LENGTH;
+ uint i;
+
+ float m = negativeInfinity;
+ // First pass: compute maximum, and copy the row into the group shared buffer
+ for( i = thread; i < nc; i += THREADS )
+ {
+ float f = result[ p + i ];
+ m = max( m, f );
+ rowBuffer[ i ] = f;
+ }
+ horizontalMaxBroadcast( thread, m );
+
+ // Second pass: apply initial scale, compute the exponent, and compute total sum over the row
+ float sum = 0;
+ for( i = thread; i < nc; i += THREADS )
+ {
+ float f = rowBuffer[ i ];
+
+ [branch]
+ if( f != negativeInfinity )
+ {
+ f = ( f - m ) * inputScale;
+#if 1
+ // At least on Radeon Graphics GPU inside Ryzen 7 5700G, computing exponent instead of loading from the buffer improves the performance
+ f = exp( f );
+#else
+ uint s = fp16Rounded( f );
+ s = lookupTable[ s ];
+ f = f16tof32( s );
+#endif
+ sum += f;
+ }
+ else
+ f = 0;
+
+ rowBuffer[ i ] = f;
+ }
+
+ horizontalSum( thread, sum );
+ if( 0 == thread )
+ sharedAccumulators[ 0 ] = 1.0 / sum;
+ GroupMemoryBarrierWithGroupSync();
+ const float scale = sharedAccumulators[ 0 ];
+
+ // Final pass: apply the final scale, and copy the row from the group shared buffer back into the global memory
+ for( i = thread; i < nc; i += THREADS )
+ {
+ float f = rowBuffer[ i ];
+ f *= scale;
+ result[ p + i ] = f;
+ }
+} \ No newline at end of file
diff --git a/ComputeShaders/zeroMemory.hlsl b/ComputeShaders/zeroMemory.hlsl
new file mode 100644
index 0000000..c486636
--- /dev/null
+++ b/ComputeShaders/zeroMemory.hlsl
@@ -0,0 +1,27 @@
+RWBuffer<float> result: register( u0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint elements: packoffset( c0.x );
+}
+
+// Thread group index is 16 bits per coordinate:
+// https://learn.microsoft.com/en-us/windows/win32/api/d3d11/nf-d3d11-id3d11devicecontext-dispatch
+// We want this shader to support buffers up to 2 GB.
+#ifndef THREADS
+static const uint THREADS = 512;
+#endif
+#ifndef ITERATIONS
+static const uint ITERATIONS = 128;
+#endif
+
+static const uint itemsPerGroup = THREADS * ITERATIONS;
+
+[numthreads( THREADS, 1, 1 )]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint rdi = group.x * itemsPerGroup;
+ const uint rdiEnd = min( rdi + itemsPerGroup, elements );
+ for( rdi += thread; rdi < rdiEnd; rdi += THREADS )
+ result[ rdi ] = 0.0;
+} \ No newline at end of file
diff --git a/Examples/MicrophoneCS/CaptureThread.cs b/Examples/MicrophoneCS/CaptureThread.cs
new file mode 100644
index 0000000..b76a929
--- /dev/null
+++ b/Examples/MicrophoneCS/CaptureThread.cs
@@ -0,0 +1,61 @@
+using System.Runtime.ExceptionServices;
+using Whisper;
+
+namespace MicrophoneCS
+{
+ sealed class CaptureThread: CaptureCallbacks
+ {
+ public CaptureThread( CommandLineArgs args, Context context, iAudioCapture source )
+ {
+ callbacks = new TranscribeCallbacks( args );
+ this.context = context;
+ this.source = source;
+
+ thread = new Thread( threadMain ) { Name = "Capture Thread" };
+ Console.WriteLine( "Press any key to quit" );
+ thread.Start();
+ }
+
+ static void readKeyCallback( object? state )
+ {
+ CaptureThread ct = ( state as CaptureThread ) ?? throw new ApplicationException();
+ Console.ReadKey();
+ ct.shouldQuit = true;
+ }
+
+ public void join()
+ {
+ ThreadPool.QueueUserWorkItem( readKeyCallback, this );
+ thread.Join();
+ edi?.Throw();
+ }
+
+ volatile bool shouldQuit = false;
+
+ protected override bool shouldCancel( Context sender ) =>
+ shouldQuit;
+
+ protected override void captureStatusChanged( Context sender, eCaptureStatus status )
+ {
+ Console.WriteLine( $"CaptureStatusChanged: {status}" );
+ }
+
+ readonly TranscribeCallbacks callbacks;
+ readonly Thread thread;
+ readonly Context context;
+ readonly iAudioCapture source;
+ ExceptionDispatchInfo? edi = null;
+
+ void threadMain()
+ {
+ try
+ {
+ context.runCapture( source, callbacks, this );
+ }
+ catch( Exception ex )
+ {
+ edi = ExceptionDispatchInfo.Capture( ex );
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Examples/MicrophoneCS/CommandLineArgs.cs b/Examples/MicrophoneCS/CommandLineArgs.cs
new file mode 100644
index 0000000..be5fbe9
--- /dev/null
+++ b/Examples/MicrophoneCS/CommandLineArgs.cs
@@ -0,0 +1,145 @@
+using System.Globalization;
+using System.Reflection;
+using Whisper;
+
+namespace MicrophoneCS
+{
+ sealed record class CommandLineArgs
+ {
+ public int n_threads = Environment.ProcessorCount;
+ public int offset_t_ms = 0;
+ public int offset_n = 0;
+ public int duration_ms = 0;
+ public int max_context = -1;
+ public int max_len = 0;
+
+ public float word_thold = 0.01f;
+
+ public bool speed_up = false;
+ public bool translate = false;
+ public bool diarize = false;
+ public bool output_txt = false;
+ public bool print_special = false;
+ public bool print_progress = false;
+ public bool print_colors = true;
+ public bool no_timestamps = false;
+ public int[]? prompt = null;
+ public int captureDeviceIndex = 0;
+
+ public eLanguage language = eLanguage.English;
+ public string model = string.Empty;
+
+ const bool output_wts = false;
+ public bool listDevices = false;
+
+ public void apply( ref Parameters p )
+ {
+ p.setFlag( eFullParamsFlags.PrintRealtime, false );
+ p.setFlag( eFullParamsFlags.PrintProgress, print_progress );
+ p.setFlag( eFullParamsFlags.PrintTimestamps, !no_timestamps );
+ p.setFlag( eFullParamsFlags.PrintSpecial, print_special );
+ p.setFlag( eFullParamsFlags.Translate, translate );
+ p.language = language;
+ p.cpuThreads = n_threads;
+ if( max_context >= 0 )
+ p.n_max_text_ctx = max_context;
+ p.offset_ms = offset_t_ms;
+ p.duration_ms = duration_ms;
+ p.setFlag( eFullParamsFlags.TokenTimestamps, output_wts || max_len > 0 );
+ p.thold_pt = word_thold;
+ p.max_len = output_wts && max_len == 0 ? 60 : max_len;
+ p.setFlag( eFullParamsFlags.SpeedupAudio, speed_up );
+ }
+
+ public eResultFlags resultFlags()
+ {
+ eResultFlags flags = eResultFlags.None;
+ bool wts = output_wts || max_len > 0;
+ if( !no_timestamps || wts )
+ flags |= eResultFlags.Timestamps;
+ if( wts || print_colors )
+ flags |= eResultFlags.Tokens;
+ return flags;
+ }
+
+ static eLanguage parseLanguage( string lang ) =>
+ Library.languageFromCode( lang ) ?? throw new ArgumentException( $"Unknown language code \"{lang}\"" );
+
+ public CommandLineArgs( string[] argv )
+ {
+ for( int i = 0; i < argv.Length; i++ )
+ {
+ string arg = argv[ i ];
+ if( arg == "-h" || arg == "--help" )
+ {
+ printUsage();
+ throw new OperationCanceledException();
+ }
+ else if( arg == "-c" || arg == "--capture" ) captureDeviceIndex = int.Parse( argv[ ++i ] );
+ else if( arg == "-ld" || arg == "--list-devices" ) listDevices = true;
+ else if( arg == "-t" || arg == "--threads" ) n_threads = int.Parse( argv[ ++i ] );
+ else if( arg == "-ot" || arg == "--offset-t" ) offset_t_ms = int.Parse( argv[ ++i ] );
+ else if( arg == "-on" || arg == "--offset-n" ) offset_n = int.Parse( argv[ ++i ] );
+ else if( arg == "-d" || arg == "--duration" ) duration_ms = int.Parse( argv[ ++i ] );
+ else if( arg == "-mc" || arg == "--max-context" ) max_context = int.Parse( argv[ ++i ] );
+ else if( arg == "-ml" || arg == "--max-len" ) max_len = int.Parse( argv[ ++i ] );
+ else if( arg == "-wt" || arg == "--word-thold" ) word_thold = float.Parse( argv[ ++i ], CultureInfo.InvariantCulture );
+ else if( arg == "-su" || arg == "--speed-up" ) speed_up = true;
+ else if( arg == "-tr" || arg == "--translate" ) translate = true;
+ else if( arg == "-di" || arg == "--diarize" ) diarize = true;
+ else if( arg == "-otxt" || arg == "--output-txt" ) output_txt = true;
+ else if( arg == "-ps" || arg == "--print-special" ) print_special = true;
+ else if( arg == "-nc" || arg == "--no-colors" ) print_colors = false;
+ else if( arg == "-pp" || arg == "--print-progress" ) print_progress = true;
+ else if( arg == "-nt" || arg == "--no-timestamps" ) no_timestamps = true;
+ else if( arg == "-l" || arg == "--language" ) language = parseLanguage( argv[ ++i ] );
+ else if( arg == "--prompt" ) prompt = parsePrompt( argv[ ++i ] );
+ else if( arg == "-m" || arg == "--model" ) model = argv[ ++i ];
+ else
+ throw new ArgumentException( $"Unknown argument: \"{arg}\"" );
+ }
+ if( string.IsNullOrWhiteSpace( model ) )
+ throw new ArgumentException( "The model file is not provided in the arguments" );
+ if( !File.Exists( model ) )
+ throw new FileNotFoundException( "Model not found", model );
+ }
+
+ static string cstr( bool b ) => b.ToString();
+
+ static int[]? parsePrompt( string str )
+ {
+ if( string.IsNullOrWhiteSpace( str ) )
+ return null;
+ // TODO: expose whisper_tokenize function, as a method of iModel COM interface
+ throw new NotImplementedException();
+ }
+
+ void printUsage()
+ {
+ Console.WriteLine();
+
+ Console.WriteLine( "usage: {0} [options] file0.mp3 file1.wma ...", Path.GetFileName( Assembly.GetExecutingAssembly().Location ) );
+ Console.WriteLine();
+ Console.WriteLine( "options:" );
+ Console.WriteLine( " -h, --help [default] show this help message and exit" );
+ Console.WriteLine( " -t N, --threads N [{0,-7:D}] number of threads to use during computation", n_threads );
+ Console.WriteLine( " -ot N, --offset-t N [{0,-7:D}] time offset in milliseconds", offset_t_ms );
+ Console.WriteLine( " -on N, --offset-n N [{0,-7:D}] segment index offset", offset_n );
+ Console.WriteLine( " -d N, --duration N [{0,-7:D}] duration of audio to process in milliseconds", duration_ms );
+ Console.WriteLine( " -mc N, --max-context N [{0,-7:D}] maximum number of text context tokens to store", max_context );
+ Console.WriteLine( " -ml N, --max-len N [{0,-7:D}] maximum segment length in characters", max_len );
+ Console.WriteLine( " -wt N, --word-thold N [{0,-7:F2}] word timestamp probability threshold", word_thold );
+ Console.WriteLine( " -su, --speed-up [{0,-7}] speed up audio by x2 (reduced accuracy)", cstr( speed_up ) );
+ Console.WriteLine( " -tr, --translate [{0,-7}] translate from source language to english", cstr( translate ) );
+ Console.WriteLine( " -di, --diarize [{0,-7}] stereo audio diarization", cstr( diarize ) );
+ Console.WriteLine( " -otxt, --output-txt [{0,-7}] output result in a text file", cstr( output_txt ) );
+ Console.WriteLine( " -ps, --print-special [{0,-7}] print special tokens", cstr( print_special ) );
+ Console.WriteLine( " -nc, --no-colors [{0,-7}] do not print colors", cstr( !print_colors ) );
+ Console.WriteLine( " -nt, --no-timestamps [{0,-7}] do not print timestamps", cstr( no_timestamps ) );
+ Console.WriteLine( " -l LANG, --language LANG [{0,-7}] spoken language", language.getCode() );
+ Console.WriteLine( " --prompt PROMPT [ ] initial prompt" );
+ Console.WriteLine( " -m FNAME, --model FNAME [{0,-7}] model path", model );
+ Console.WriteLine( " -f FNAME, --file FNAME [{0,-7}] path of the input audio file", "" );
+ }
+ }
+} \ No newline at end of file
diff --git a/Examples/MicrophoneCS/MicrophoneCS.cs b/Examples/MicrophoneCS/MicrophoneCS.cs
new file mode 100644
index 0000000..c095ee1
--- /dev/null
+++ b/Examples/MicrophoneCS/MicrophoneCS.cs
@@ -0,0 +1,56 @@
+using Whisper;
+
+namespace MicrophoneCS
+{
+ static class Program
+ {
+ static int Main( string[] args )
+ {
+ try
+ {
+ CommandLineArgs cla;
+ try
+ {
+ cla = new CommandLineArgs( args );
+ }
+ catch( OperationCanceledException )
+ {
+ return 1;
+ }
+ const eLoggerFlags loggerFlags = eLoggerFlags.UseStandardError | eLoggerFlags.SkipFormatMessage;
+ Library.setLogSink( eLogLevel.Debug, loggerFlags );
+
+ using iMediaFoundation mf = Library.initMediaFoundation();
+ CaptureDeviceId[] devices = mf.listCaptureDevices() ??
+ throw new ApplicationException( "This computer has no audio capture devices" );
+
+ if( cla.listDevices )
+ {
+ for( int i = 0; i < devices.Length; i++ )
+ Console.WriteLine( "#{0}: {1}", i, devices[ i ].displayName );
+ return 0;
+ }
+ if( cla.captureDeviceIndex < 0 || cla.captureDeviceIndex >= devices.Length )
+ throw new ApplicationException( $"Capture device index is out of range; the valid range is [ 0 .. {devices.Length - 1} ]" );
+
+ using iAudioCapture captureDev = mf.openCaptureDevice( devices[ cla.captureDeviceIndex ] );
+
+ using iModel model = Library.loadModel( cla.model );
+ using Context context = model.createContext();
+ cla.apply( ref context.parameters );
+
+ CaptureThread thread = new CaptureThread( cla, context, captureDev );
+ thread.join();
+
+ context.timingsPrint();
+ return 0;
+ }
+ catch( Exception ex )
+ {
+ // Console.WriteLine( ex.Message );
+ Console.WriteLine( ex.ToString() );
+ return ex.HResult;
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Examples/MicrophoneCS/MicrophoneCS.csproj b/Examples/MicrophoneCS/MicrophoneCS.csproj
new file mode 100644
index 0000000..f417d20
--- /dev/null
+++ b/Examples/MicrophoneCS/MicrophoneCS.csproj
@@ -0,0 +1,27 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+ <PropertyGroup>
+ <OutputType>Exe</OutputType>
+ <TargetFramework>net6.0-windows</TargetFramework>
+ <ImplicitUsings>enable</ImplicitUsings>
+ <Nullable>enable</Nullable>
+ <CheckForOverflowUnderflow>true</CheckForOverflowUnderflow>
+ <AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath>
+ <Platforms>x64</Platforms>
+ </PropertyGroup>
+
+ <ItemGroup>
+ <Compile Include="..\TranscribeCS\AnsiCodes.cs" Link="AnsiCodes.cs" />
+ </ItemGroup>
+
+ <ItemGroup>
+ <Content Include="..\..\x64\$(Configuration)\Whisper.dll" Link="Whisper.dll">
+ <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+ </Content>
+ </ItemGroup>
+
+ <ItemGroup>
+ <ProjectReference Include="..\..\WhisperNet\WhisperNet.csproj" />
+ </ItemGroup>
+
+</Project> \ No newline at end of file
diff --git a/Examples/MicrophoneCS/TranscribeCallbacks.cs b/Examples/MicrophoneCS/TranscribeCallbacks.cs
new file mode 100644
index 0000000..e4d14f4
--- /dev/null
+++ b/Examples/MicrophoneCS/TranscribeCallbacks.cs
@@ -0,0 +1,114 @@
+using System.Globalization;
+using Whisper;
+
+namespace MicrophoneCS
+{
+ /// <summary>Implementation of Callbacks abstract class, to print these segments as soon as they’re produced by the library.</summary>
+ sealed class TranscribeCallbacks: Callbacks
+ {
+ readonly CommandLineArgs args;
+ readonly eResultFlags resultFlags;
+
+ public TranscribeCallbacks( CommandLineArgs args )
+ {
+ this.args = args;
+ resultFlags = args.resultFlags();
+ Console.OutputEncoding = System.Text.Encoding.UTF8;
+ }
+
+ // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
+ // Lowest is red, middle is yellow, highest is green.
+ readonly string[] k_colors = new string[]
+ {
+ "\x1B[38;5;196m", "\x1B[38;5;202m", "\x1B[38;5;208m", "\x1B[38;5;214m", "\x1B[38;5;220m",
+ "\x1B[38;5;226m", "\x1B[38;5;190m", "\x1B[38;5;154m", "\x1B[38;5;118m", "\x1B[38;5;82m"
+ };
+
+ int colorIndex( in sToken tok )
+ {
+ float p = tok.probability;
+ float p3 = p * p * p;
+ int col = (int)( p3 * k_colors.Length );
+ col = Math.Clamp( col, 0, k_colors.Length - 1 );
+ return col;
+ }
+
+ public static string printTime( TimeSpan ts ) =>
+ ts.ToString( "hh':'mm':'ss'.'fff", CultureInfo.InvariantCulture );
+ public static string printTimeWithComma( TimeSpan ts ) =>
+ ts.ToString( "hh':'mm':'ss','fff", CultureInfo.InvariantCulture );
+
+ protected override void onNewSegment( Context sender, int countNew )
+ {
+ TranscribeResult res = sender.results( resultFlags );
+ ReadOnlySpan<sToken> tokens = res.tokens;
+
+ int s0 = res.segments.Length - countNew;
+ if( s0 == 0 )
+ Console.WriteLine();
+
+ for( int i = s0; i < res.segments.Length; i++ )
+ {
+ sSegment seg = res.segments[ i ];
+
+ if( args.no_timestamps )
+ {
+ if( args.print_colors && AnsiCodes.enabled )
+ {
+ foreach( sToken tok in res.getTokens( seg ) )
+ {
+ if( !args.print_special && tok.hasFlag( eTokenFlags.Special ) )
+ continue;
+ Console.Write( "{0}{1}{2}", k_colors[ colorIndex( tok ) ], tok.text, "\x1B[0m" );
+ }
+ }
+ else
+ Console.Write( seg.text );
+ Console.Out.Flush();
+ continue;
+ }
+
+ string speaker = "";
+#if false
+ if( args.diarize && pcmf32s.size() == 2 )
+ {
+ const size_t n_samples = pcmf32s[ 0 ].size();
+ const int64_t is0 = SourceAudio::sampleFromTimestamp( seg.time.begin, n_samples );
+ const int64_t is1 = SourceAudio::sampleFromTimestamp( seg.time.end, n_samples );
+
+ double energy0 = 0.0f;
+ double energy1 = 0.0f;
+
+ for( int64_t j = is0; j < is1; j++ )
+ {
+ energy0 += fabs( pcmf32s[ 0 ][ j ] );
+ energy1 += fabs( pcmf32s[ 1 ][ j ] );
+ }
+
+ if( energy0 > 1.1 * energy1 )
+ speaker = "(speaker 0)";
+ else if( energy1 > 1.1 * energy0 )
+ speaker = "(speaker 1)";
+ else
+ speaker = "(speaker ?)";
+
+ //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+ }
+#endif
+ if( args.print_colors && AnsiCodes.enabled )
+ {
+ Console.Write( "[{0} --> {1}] ", printTime( seg.time.begin ), printTime( seg.time.end ) );
+ foreach( sToken tok in res.getTokens( seg ) )
+ {
+ if( !args.print_special && tok.hasFlag( eTokenFlags.Special ) )
+ continue;
+ Console.Write( "{0}{1}{2}{3}", speaker, k_colors[ colorIndex( tok ) ], tok.text, "\x1B[0m" );
+ }
+ Console.WriteLine();
+ }
+ else
+ Console.WriteLine( "[{0} --> {1}] {2}{3}", printTime( seg.time.begin ), printTime( seg.time.end ), speaker, seg.text );
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Examples/OldMain/OldMain.vcxproj b/Examples/OldMain/OldMain.vcxproj
new file mode 100644
index 0000000..26f2c71
--- /dev/null
+++ b/Examples/OldMain/OldMain.vcxproj
@@ -0,0 +1,101 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <VCProjectVersion>16.0</VCProjectVersion>
+ <Keyword>Win32Proj</Keyword>
+ <ProjectGuid>{596f9770-9aeb-49d3-86ca-4200197df12b}</ProjectGuid>
+ <RootNamespace>OldMain</RootNamespace>
+ <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="Shared">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <IncludePath>$(SolutionDir)Whisper\Source\;$(IncludePath)</IncludePath>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <IncludePath>$(SolutionDir)Whisper\Source\;$(IncludePath)</IncludePath>
+ </PropertyGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <EnableEnhancedInstructionSet>AdvancedVectorExtensions2</EnableEnhancedInstructionSet>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <EnableEnhancedInstructionSet>AdvancedVectorExtensions2</EnableEnhancedInstructionSet>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ <LinkTimeCodeGeneration>UseLinkTimeCodeGeneration</LinkTimeCodeGeneration>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemGroup>
+ <ClCompile Include="..\..\Whisper\source\ggml.c">
+ <ExcludedFromBuild>true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\..\Whisper\source\ggmlMsvc.c" />
+ <ClCompile Include="..\..\Whisper\source\whisper.cpp" />
+ <ClCompile Include="main.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="..\..\Whisper\source\ggml.h" />
+ <ClInclude Include="..\..\Whisper\source\whisper.h" />
+ <ClInclude Include="dr_wav.h" />
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/Examples/OldMain/OldMain.vcxproj.filters b/Examples/OldMain/OldMain.vcxproj.filters
new file mode 100644
index 0000000..78f29f0
--- /dev/null
+++ b/Examples/OldMain/OldMain.vcxproj.filters
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <ClCompile Include="..\..\Whisper\source\ggmlMsvc.c" />
+ <ClCompile Include="..\..\Whisper\source\whisper.cpp" />
+ <ClCompile Include="main.cpp" />
+ <ClCompile Include="..\..\Whisper\source\ggml.c" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="..\..\Whisper\source\ggml.h" />
+ <ClInclude Include="..\..\Whisper\source\whisper.h" />
+ <ClInclude Include="dr_wav.h" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/Examples/OldMain/dr_wav.h b/Examples/OldMain/dr_wav.h
new file mode 100644
index 0000000..fd3e95b
--- /dev/null
+++ b/Examples/OldMain/dr_wav.h
@@ -0,0 +1,6434 @@
+/*
+WAV audio loader and writer. Choice of public domain or MIT-0. See license statements at the end of this file.
+dr_wav - v0.12.16 - 2020-12-02
+
+David Reid - mackron@gmail.com
+
+GitHub: https://github.com/mackron/dr_libs
+*/
+
+/*
+RELEASE NOTES - VERSION 0.12
+============================
+Version 0.12 includes breaking changes to custom chunk handling.
+
+
+Changes to Chunk Callback
+-------------------------
+dr_wav supports the ability to fire a callback when a chunk is encounted (except for WAVE and FMT chunks). The callback has been updated to include both the
+container (RIFF or Wave64) and the FMT chunk which contains information about the format of the data in the wave file.
+
+Previously, there was no direct way to determine the container, and therefore no way to discriminate against the different IDs in the chunk header (RIFF and
+Wave64 containers encode chunk ID's differently). The `container` parameter can be used to know which ID to use.
+
+Sometimes it can be useful to know the data format at the time the chunk callback is fired. A pointer to a `drwav_fmt` object is now passed into the chunk
+callback which will give you information about the data format. To determine the sample format, use `drwav_fmt_get_format()`. This will return one of the
+`DR_WAVE_FORMAT_*` tokens.
+*/
+
+/*
+Introduction
+============
+This is a single file library. To use it, do something like the following in one .c file.
+
+ ```c
+ #define DR_WAV_IMPLEMENTATION
+ #include "dr_wav.h"
+ ```
+
+You can then #include this file in other parts of the program as you would with any other header file. Do something like the following to read audio data:
+
+ ```c
+ drwav wav;
+ if (!drwav_init_file(&wav, "my_song.wav", NULL)) {
+ // Error opening WAV file.
+ }
+
+ drwav_int32* pDecodedInterleavedPCMFrames = malloc(wav.totalPCMFrameCount * wav.channels * sizeof(drwav_int32));
+ size_t numberOfSamplesActuallyDecoded = drwav_read_pcm_frames_s32(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames);
+
+ ...
+
+ drwav_uninit(&wav);
+ ```
+
+If you just want to quickly open and read the audio data in a single operation you can do something like this:
+
+ ```c
+ unsigned int channels;
+ unsigned int sampleRate;
+ drwav_uint64 totalPCMFrameCount;
+ float* pSampleData = drwav_open_file_and_read_pcm_frames_f32("my_song.wav", &channels, &sampleRate, &totalPCMFrameCount, NULL);
+ if (pSampleData == NULL) {
+ // Error opening and reading WAV file.
+ }
+
+ ...
+
+ drwav_free(pSampleData);
+ ```
+
+The examples above use versions of the API that convert the audio data to a consistent format (32-bit signed PCM, in this case), but you can still output the
+audio data in its internal format (see notes below for supported formats):
+
+ ```c
+ size_t framesRead = drwav_read_pcm_frames(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames);
+ ```
+
+You can also read the raw bytes of audio data, which could be useful if dr_wav does not have native support for a particular data format:
+
+ ```c
+ size_t bytesRead = drwav_read_raw(&wav, bytesToRead, pRawDataBuffer);
+ ```
+
+dr_wav can also be used to output WAV files. This does not currently support compressed formats. To use this, look at `drwav_init_write()`,
+`drwav_init_file_write()`, etc. Use `drwav_write_pcm_frames()` to write samples, or `drwav_write_raw()` to write raw data in the "data" chunk.
+
+ ```c
+ drwav_data_format format;
+ format.container = drwav_container_riff; // <-- drwav_container_riff = normal WAV files, drwav_container_w64 = Sony Wave64.
+ format.format = DR_WAVE_FORMAT_PCM; // <-- Any of the DR_WAVE_FORMAT_* codes.
+ format.channels = 2;
+ format.sampleRate = 44100;
+ format.bitsPerSample = 16;
+ drwav_init_file_write(&wav, "data/recording.wav", &format, NULL);
+
+ ...
+
+ drwav_uint64 framesWritten = drwav_write_pcm_frames(pWav, frameCount, pSamples);
+ ```
+
+dr_wav has seamless support the Sony Wave64 format. The decoder will automatically detect it and it should Just Work without any manual intervention.
+
+
+Build Options
+=============
+#define these options before including this file.
+
+#define DR_WAV_NO_CONVERSION_API
+ Disables conversion APIs such as `drwav_read_pcm_frames_f32()` and `drwav_s16_to_f32()`.
+
+#define DR_WAV_NO_STDIO
+ Disables APIs that initialize a decoder from a file such as `drwav_init_file()`, `drwav_init_file_write()`, etc.
+
+
+
+Notes
+=====
+- Samples are always interleaved.
+- The default read function does not do any data conversion. Use `drwav_read_pcm_frames_f32()`, `drwav_read_pcm_frames_s32()` and `drwav_read_pcm_frames_s16()`
+ to read and convert audio data to 32-bit floating point, signed 32-bit integer and signed 16-bit integer samples respectively. Tested and supported internal
+ formats include the following:
+ - Unsigned 8-bit PCM
+ - Signed 12-bit PCM
+ - Signed 16-bit PCM
+ - Signed 24-bit PCM
+ - Signed 32-bit PCM
+ - IEEE 32-bit floating point
+ - IEEE 64-bit floating point
+ - A-law and u-law
+ - Microsoft ADPCM
+ - IMA ADPCM (DVI, format code 0x11)
+- dr_wav will try to read the WAV file as best it can, even if it's not strictly conformant to the WAV format.
+*/
+
+#ifndef dr_wav_h
+#define dr_wav_h
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define DRWAV_STRINGIFY(x) #x
+#define DRWAV_XSTRINGIFY(x) DRWAV_STRINGIFY(x)
+
+#define DRWAV_VERSION_MAJOR 0
+#define DRWAV_VERSION_MINOR 12
+#define DRWAV_VERSION_REVISION 16
+#define DRWAV_VERSION_STRING DRWAV_XSTRINGIFY(DRWAV_VERSION_MAJOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_MINOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_REVISION)
+
+#include <stddef.h> /* For size_t. */
+
+/* Sized types. */
+typedef signed char drwav_int8;
+typedef unsigned char drwav_uint8;
+typedef signed short drwav_int16;
+typedef unsigned short drwav_uint16;
+typedef signed int drwav_int32;
+typedef unsigned int drwav_uint32;
+#if defined(_MSC_VER)
+ typedef signed __int64 drwav_int64;
+ typedef unsigned __int64 drwav_uint64;
+#else
+ #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)))
+ #pragma GCC diagnostic push
+ #pragma GCC diagnostic ignored "-Wlong-long"
+ #if defined(__clang__)
+ #pragma GCC diagnostic ignored "-Wc++11-long-long"
+ #endif
+ #endif
+ typedef signed long long drwav_int64;
+ typedef unsigned long long drwav_uint64;
+ #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)))
+ #pragma GCC diagnostic pop
+ #endif
+#endif
+#if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__)) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__)
+ typedef drwav_uint64 drwav_uintptr;
+#else
+ typedef drwav_uint32 drwav_uintptr;
+#endif
+typedef drwav_uint8 drwav_bool8;
+typedef drwav_uint32 drwav_bool32;
+#define DRWAV_TRUE 1
+#define DRWAV_FALSE 0
+
+#if !defined(DRWAV_API)
+ #if defined(DRWAV_DLL)
+ #if defined(_WIN32)
+ #define DRWAV_DLL_IMPORT __declspec(dllimport)
+ #define DRWAV_DLL_EXPORT __declspec(dllexport)
+ #define DRWAV_DLL_PRIVATE static
+ #else
+ #if defined(__GNUC__) && __GNUC__ >= 4
+ #define DRWAV_DLL_IMPORT __attribute__((visibility("default")))
+ #define DRWAV_DLL_EXPORT __attribute__((visibility("default")))
+ #define DRWAV_DLL_PRIVATE __attribute__((visibility("hidden")))
+ #else
+ #define DRWAV_DLL_IMPORT
+ #define DRWAV_DLL_EXPORT
+ #define DRWAV_DLL_PRIVATE static
+ #endif
+ #endif
+
+ #if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION)
+ #define DRWAV_API DRWAV_DLL_EXPORT
+ #else
+ #define DRWAV_API DRWAV_DLL_IMPORT
+ #endif
+ #define DRWAV_PRIVATE DRWAV_DLL_PRIVATE
+ #else
+ #define DRWAV_API extern
+ #define DRWAV_PRIVATE static
+ #endif
+#endif
+
+typedef drwav_int32 drwav_result;
+#define DRWAV_SUCCESS 0
+#define DRWAV_ERROR -1 /* A generic error. */
+#define DRWAV_INVALID_ARGS -2
+#define DRWAV_INVALID_OPERATION -3
+#define DRWAV_OUT_OF_MEMORY -4
+#define DRWAV_OUT_OF_RANGE -5
+#define DRWAV_ACCESS_DENIED -6
+#define DRWAV_DOES_NOT_EXIST -7
+#define DRWAV_ALREADY_EXISTS -8
+#define DRWAV_TOO_MANY_OPEN_FILES -9
+#define DRWAV_INVALID_FILE -10
+#define DRWAV_TOO_BIG -11
+#define DRWAV_PATH_TOO_LONG -12
+#define DRWAV_NAME_TOO_LONG -13
+#define DRWAV_NOT_DIRECTORY -14
+#define DRWAV_IS_DIRECTORY -15
+#define DRWAV_DIRECTORY_NOT_EMPTY -16
+#define DRWAV_END_OF_FILE -17
+#define DRWAV_NO_SPACE -18
+#define DRWAV_BUSY -19
+#define DRWAV_IO_ERROR -20
+#define DRWAV_INTERRUPT -21
+#define DRWAV_UNAVAILABLE -22
+#define DRWAV_ALREADY_IN_USE -23
+#define DRWAV_BAD_ADDRESS -24
+#define DRWAV_BAD_SEEK -25
+#define DRWAV_BAD_PIPE -26
+#define DRWAV_DEADLOCK -27
+#define DRWAV_TOO_MANY_LINKS -28
+#define DRWAV_NOT_IMPLEMENTED -29
+#define DRWAV_NO_MESSAGE -30
+#define DRWAV_BAD_MESSAGE -31
+#define DRWAV_NO_DATA_AVAILABLE -32
+#define DRWAV_INVALID_DATA -33
+#define DRWAV_TIMEOUT -34
+#define DRWAV_NO_NETWORK -35
+#define DRWAV_NOT_UNIQUE -36
+#define DRWAV_NOT_SOCKET -37
+#define DRWAV_NO_ADDRESS -38
+#define DRWAV_BAD_PROTOCOL -39
+#define DRWAV_PROTOCOL_UNAVAILABLE -40
+#define DRWAV_PROTOCOL_NOT_SUPPORTED -41
+#define DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED -42
+#define DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED -43
+#define DRWAV_SOCKET_NOT_SUPPORTED -44
+#define DRWAV_CONNECTION_RESET -45
+#define DRWAV_ALREADY_CONNECTED -46
+#define DRWAV_NOT_CONNECTED -47
+#define DRWAV_CONNECTION_REFUSED -48
+#define DRWAV_NO_HOST -49
+#define DRWAV_IN_PROGRESS -50
+#define DRWAV_CANCELLED -51
+#define DRWAV_MEMORY_ALREADY_MAPPED -52
+#define DRWAV_AT_END -53
+
+/* Common data formats. */
+#define DR_WAVE_FORMAT_PCM 0x1
+#define DR_WAVE_FORMAT_ADPCM 0x2
+#define DR_WAVE_FORMAT_IEEE_FLOAT 0x3
+#define DR_WAVE_FORMAT_ALAW 0x6
+#define DR_WAVE_FORMAT_MULAW 0x7
+#define DR_WAVE_FORMAT_DVI_ADPCM 0x11
+#define DR_WAVE_FORMAT_EXTENSIBLE 0xFFFE
+
+/* Constants. */
+#ifndef DRWAV_MAX_SMPL_LOOPS
+#define DRWAV_MAX_SMPL_LOOPS 1
+#endif
+
+/* Flags to pass into drwav_init_ex(), etc. */
+#define DRWAV_SEQUENTIAL 0x00000001
+
+DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision);
+DRWAV_API const char* drwav_version_string(void);
+
+typedef enum
+{
+ drwav_seek_origin_start,
+ drwav_seek_origin_current
+} drwav_seek_origin;
+
+typedef enum
+{
+ drwav_container_riff,
+ drwav_container_w64,
+ drwav_container_rf64
+} drwav_container;
+
+typedef struct
+{
+ union
+ {
+ drwav_uint8 fourcc[4];
+ drwav_uint8 guid[16];
+ } id;
+
+ /* The size in bytes of the chunk. */
+ drwav_uint64 sizeInBytes;
+
+ /*
+ RIFF = 2 byte alignment.
+ W64 = 8 byte alignment.
+ */
+ unsigned int paddingSize;
+} drwav_chunk_header;
+
+typedef struct
+{
+ /*
+ The format tag exactly as specified in the wave file's "fmt" chunk. This can be used by applications
+ that require support for data formats not natively supported by dr_wav.
+ */
+ drwav_uint16 formatTag;
+
+ /* The number of channels making up the audio data. When this is set to 1 it is mono, 2 is stereo, etc. */
+ drwav_uint16 channels;
+
+ /* The sample rate. Usually set to something like 44100. */
+ drwav_uint32 sampleRate;
+
+ /* Average bytes per second. You probably don't need this, but it's left here for informational purposes. */
+ drwav_uint32 avgBytesPerSec;
+
+ /* Block align. This is equal to the number of channels * bytes per sample. */
+ drwav_uint16 blockAlign;
+
+ /* Bits per sample. */
+ drwav_uint16 bitsPerSample;
+
+ /* The size of the extended data. Only used internally for validation, but left here for informational purposes. */
+ drwav_uint16 extendedSize;
+
+ /*
+ The number of valid bits per sample. When <formatTag> is equal to WAVE_FORMAT_EXTENSIBLE, <bitsPerSample>
+ is always rounded up to the nearest multiple of 8. This variable contains information about exactly how
+ many bits are valid per sample. Mainly used for informational purposes.
+ */
+ drwav_uint16 validBitsPerSample;
+
+ /* The channel mask. Not used at the moment. */
+ drwav_uint32 channelMask;
+
+ /* The sub-format, exactly as specified by the wave file. */
+ drwav_uint8 subFormat[16];
+} drwav_fmt;
+
+DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT);
+
+
+/*
+Callback for when data is read. Return value is the number of bytes actually read.
+
+pUserData [in] The user data that was passed to drwav_init() and family.
+pBufferOut [out] The output buffer.
+bytesToRead [in] The number of bytes to read.
+
+Returns the number of bytes actually read.
+
+A return value of less than bytesToRead indicates the end of the stream. Do _not_ return from this callback until
+either the entire bytesToRead is filled or you have reached the end of the stream.
+*/
+typedef size_t (* drwav_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead);
+
+/*
+Callback for when data is written. Returns value is the number of bytes actually written.
+
+pUserData [in] The user data that was passed to drwav_init_write() and family.
+pData [out] A pointer to the data to write.
+bytesToWrite [in] The number of bytes to write.
+
+Returns the number of bytes actually written.
+
+If the return value differs from bytesToWrite, it indicates an error.
+*/
+typedef size_t (* drwav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite);
+
+/*
+Callback for when data needs to be seeked.
+
+pUserData [in] The user data that was passed to drwav_init() and family.
+offset [in] The number of bytes to move, relative to the origin. Will never be negative.
+origin [in] The origin of the seek - the current position or the start of the stream.
+
+Returns whether or not the seek was successful.
+
+Whether or not it is relative to the beginning or current position is determined by the "origin" parameter which will be either drwav_seek_origin_start or
+drwav_seek_origin_current.
+*/
+typedef drwav_bool32 (* drwav_seek_proc)(void* pUserData, int offset, drwav_seek_origin origin);
+
+/*
+Callback for when drwav_init_ex() finds a chunk.
+
+pChunkUserData [in] The user data that was passed to the pChunkUserData parameter of drwav_init_ex() and family.
+onRead [in] A pointer to the function to call when reading.
+onSeek [in] A pointer to the function to call when seeking.
+pReadSeekUserData [in] The user data that was passed to the pReadSeekUserData parameter of drwav_init_ex() and family.
+pChunkHeader [in] A pointer to an object containing basic header information about the chunk. Use this to identify the chunk.
+container [in] Whether or not the WAV file is a RIFF or Wave64 container. If you're unsure of the difference, assume RIFF.
+pFMT [in] A pointer to the object containing the contents of the "fmt" chunk.
+
+Returns the number of bytes read + seeked.
+
+To read data from the chunk, call onRead(), passing in pReadSeekUserData as the first parameter. Do the same for seeking with onSeek(). The return value must
+be the total number of bytes you have read _plus_ seeked.
+
+Use the `container` argument to discriminate the fields in `pChunkHeader->id`. If the container is `drwav_container_riff` or `drwav_container_rf64` you should
+use `id.fourcc`, otherwise you should use `id.guid`.
+
+The `pFMT` parameter can be used to determine the data format of the wave file. Use `drwav_fmt_get_format()` to get the sample format, which will be one of the
+`DR_WAVE_FORMAT_*` identifiers.
+
+The read pointer will be sitting on the first byte after the chunk's header. You must not attempt to read beyond the boundary of the chunk.
+*/
+typedef drwav_uint64 (* drwav_chunk_proc)(void* pChunkUserData, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_chunk_header* pChunkHeader, drwav_container container, const drwav_fmt* pFMT);
+
+typedef struct
+{
+ void* pUserData;
+ void* (* onMalloc)(size_t sz, void* pUserData);
+ void* (* onRealloc)(void* p, size_t sz, void* pUserData);
+ void (* onFree)(void* p, void* pUserData);
+} drwav_allocation_callbacks;
+
+/* Structure for internal use. Only used for loaders opened with drwav_init_memory(). */
+typedef struct
+{
+ const drwav_uint8* data;
+ size_t dataSize;
+ size_t currentReadPos;
+} drwav__memory_stream;
+
+/* Structure for internal use. Only used for writers opened with drwav_init_memory_write(). */
+typedef struct
+{
+ void** ppData;
+ size_t* pDataSize;
+ size_t dataSize;
+ size_t dataCapacity;
+ size_t currentWritePos;
+} drwav__memory_stream_write;
+
+typedef struct
+{
+ drwav_container container; /* RIFF, W64. */
+ drwav_uint32 format; /* DR_WAVE_FORMAT_* */
+ drwav_uint32 channels;
+ drwav_uint32 sampleRate;
+ drwav_uint32 bitsPerSample;
+} drwav_data_format;
+
+
+/* See the following for details on the 'smpl' chunk: https://sites.google.com/site/musicgapi/technical-documents/wav-file-format#smpl */
+typedef struct
+{
+ drwav_uint32 cuePointId;
+ drwav_uint32 type;
+ drwav_uint32 start;
+ drwav_uint32 end;
+ drwav_uint32 fraction;
+ drwav_uint32 playCount;
+} drwav_smpl_loop;
+
+ typedef struct
+{
+ drwav_uint32 manufacturer;
+ drwav_uint32 product;
+ drwav_uint32 samplePeriod;
+ drwav_uint32 midiUnityNotes;
+ drwav_uint32 midiPitchFraction;
+ drwav_uint32 smpteFormat;
+ drwav_uint32 smpteOffset;
+ drwav_uint32 numSampleLoops;
+ drwav_uint32 samplerData;
+ drwav_smpl_loop loops[DRWAV_MAX_SMPL_LOOPS];
+} drwav_smpl;
+
+typedef struct
+{
+ /* A pointer to the function to call when more data is needed. */
+ drwav_read_proc onRead;
+
+ /* A pointer to the function to call when data needs to be written. Only used when the drwav object is opened in write mode. */
+ drwav_write_proc onWrite;
+
+ /* A pointer to the function to call when the wav file needs to be seeked. */
+ drwav_seek_proc onSeek;
+
+ /* The user data to pass to callbacks. */
+ void* pUserData;
+
+ /* Allocation callbacks. */
+ drwav_allocation_callbacks allocationCallbacks;
+
+
+ /* Whether or not the WAV file is formatted as a standard RIFF file or W64. */
+ drwav_container container;
+
+
+ /* Structure containing format information exactly as specified by the wav file. */
+ drwav_fmt fmt;
+
+ /* The sample rate. Will be set to something like 44100. */
+ drwav_uint32 sampleRate;
+
+ /* The number of channels. This will be set to 1 for monaural streams, 2 for stereo, etc. */
+ drwav_uint16 channels;
+
+ /* The bits per sample. Will be set to something like 16, 24, etc. */
+ drwav_uint16 bitsPerSample;
+
+ /* Equal to fmt.formatTag, or the value specified by fmt.subFormat if fmt.formatTag is equal to 65534 (WAVE_FORMAT_EXTENSIBLE). */
+ drwav_uint16 translatedFormatTag;
+
+ /* The total number of PCM frames making up the audio data. */
+ drwav_uint64 totalPCMFrameCount;
+
+
+ /* The size in bytes of the data chunk. */
+ drwav_uint64 dataChunkDataSize;
+
+ /* The position in the stream of the first byte of the data chunk. This is used for seeking. */
+ drwav_uint64 dataChunkDataPos;
+
+ /* The number of bytes remaining in the data chunk. */
+ drwav_uint64 bytesRemaining;
+
+
+ /*
+ Only used in sequential write mode. Keeps track of the desired size of the "data" chunk at the point of initialization time. Always
+ set to 0 for non-sequential writes and when the drwav object is opened in read mode. Used for validation.
+ */
+ drwav_uint64 dataChunkDataSizeTargetWrite;
+
+ /* Keeps track of whether or not the wav writer was initialized in sequential mode. */
+ drwav_bool32 isSequentialWrite;
+
+
+ /* smpl chunk. */
+ drwav_smpl smpl;
+
+
+ /* A hack to avoid a DRWAV_MALLOC() when opening a decoder with drwav_init_memory(). */
+ drwav__memory_stream memoryStream;
+ drwav__memory_stream_write memoryStreamWrite;
+
+ /* Generic data for compressed formats. This data is shared across all block-compressed formats. */
+ struct
+ {
+ drwav_uint64 iCurrentPCMFrame; /* The index of the next PCM frame that will be read by drwav_read_*(). This is used with "totalPCMFrameCount" to ensure we don't read excess samples at the end of the last block. */
+ } compressed;
+
+ /* Microsoft ADPCM specific data. */
+ struct
+ {
+ drwav_uint32 bytesRemainingInBlock;
+ drwav_uint16 predictor[2];
+ drwav_int32 delta[2];
+ drwav_int32 cachedFrames[4]; /* Samples are stored in this cache during decoding. */
+ drwav_uint32 cachedFrameCount;
+ drwav_int32 prevFrames[2][2]; /* The previous 2 samples for each channel (2 channels at most). */
+ } msadpcm;
+
+ /* IMA ADPCM specific data. */
+ struct
+ {
+ drwav_uint32 bytesRemainingInBlock;
+ drwav_int32 predictor[2];
+ drwav_int32 stepIndex[2];
+ drwav_int32 cachedFrames[16]; /* Samples are stored in this cache during decoding. */
+ drwav_uint32 cachedFrameCount;
+ } ima;
+} drwav;
+
+
+/*
+Initializes a pre-allocated drwav object for reading.
+
+pWav [out] A pointer to the drwav object being initialized.
+onRead [in] The function to call when data needs to be read from the client.
+onSeek [in] The function to call when the read position of the client data needs to move.
+onChunk [in, optional] The function to call when a chunk is enumerated at initialized time.
+pUserData, pReadSeekUserData [in, optional] A pointer to application defined data that will be passed to onRead and onSeek.
+pChunkUserData [in, optional] A pointer to application defined data that will be passed to onChunk.
+flags [in, optional] A set of flags for controlling how things are loaded.
+
+Returns true if successful; false otherwise.
+
+Close the loader with drwav_uninit().
+
+This is the lowest level function for initializing a WAV file. You can also use drwav_init_file() and drwav_init_memory()
+to open the stream from a file or from a block of memory respectively.
+
+Possible values for flags:
+ DRWAV_SEQUENTIAL: Never perform a backwards seek while loading. This disables the chunk callback and will cause this function
+ to return as soon as the data chunk is found. Any chunks after the data chunk will be ignored.
+
+drwav_init() is equivalent to "drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0);".
+
+The onChunk callback is not called for the WAVE or FMT chunks. The contents of the FMT chunk can be read from pWav->fmt
+after the function returns.
+
+See also: drwav_init_file(), drwav_init_memory(), drwav_uninit()
+*/
+DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/*
+Initializes a pre-allocated drwav object for writing.
+
+onWrite [in] The function to call when data needs to be written.
+onSeek [in] The function to call when the write position needs to move.
+pUserData [in, optional] A pointer to application defined data that will be passed to onWrite and onSeek.
+
+Returns true if successful; false otherwise.
+
+Close the writer with drwav_uninit().
+
+This is the lowest level function for initializing a WAV file. You can also use drwav_init_file_write() and drwav_init_memory_write()
+to open the stream from a file or from a block of memory respectively.
+
+If the total sample count is known, you can use drwav_init_write_sequential(). This avoids the need for dr_wav to perform
+a post-processing step for storing the total sample count and the size of the data chunk which requires a backwards seek.
+
+See also: drwav_init_file_write(), drwav_init_memory_write(), drwav_uninit()
+*/
+DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/*
+Utility function to determine the target size of the entire data to be written (including all headers and chunks).
+
+Returns the target size in bytes.
+
+Useful if the application needs to know the size to allocate.
+
+Only writing to the RIFF chunk and one data chunk is currently supported.
+
+See also: drwav_init_write(), drwav_init_file_write(), drwav_init_memory_write()
+*/
+DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount);
+
+/*
+Uninitializes the given drwav object.
+
+Use this only for objects initialized with drwav_init*() functions (drwav_init(), drwav_init_ex(), drwav_init_write(), drwav_init_write_sequential()).
+*/
+DRWAV_API drwav_result drwav_uninit(drwav* pWav);
+
+
+/*
+Reads raw audio data.
+
+This is the lowest level function for reading audio data. It simply reads the given number of
+bytes of the raw internal sample data.
+
+Consider using drwav_read_pcm_frames_s16(), drwav_read_pcm_frames_s32() or drwav_read_pcm_frames_f32() for
+reading sample data in a consistent format.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of bytes actually read.
+*/
+DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut);
+
+/*
+Reads up to the specified number of PCM frames from the WAV file.
+
+The output data will be in the file's internal format, converted to native-endian byte order. Use
+drwav_read_pcm_frames_s16/f32/s32() to read data in a specific format.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached or
+you have requested more PCM frames than can possibly fit in the output buffer.
+
+This function will only work when sample data is of a fixed size and uncompressed. If you are
+using a compressed format consider using drwav_read_raw() or drwav_read_pcm_frames_s16/s32/f32().
+
+pBufferOut can be NULL in which case a seek will be performed.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+
+/*
+Seeks to the given PCM frame.
+
+Returns true if successful; false otherwise.
+*/
+DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex);
+
+
+/*
+Writes raw audio data.
+
+Returns the number of bytes actually written. If this differs from bytesToWrite, it indicates an error.
+*/
+DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData);
+
+/*
+Writes PCM frames.
+
+Returns the number of PCM frames written.
+
+Input samples need to be in native-endian byte order. On big-endian architectures the input data will be converted to
+little-endian. Use drwav_write_raw() to write raw audio data without performing any conversion.
+*/
+DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+
+
+/* Conversion Utilities */
+#ifndef DR_WAV_NO_CONVERSION_API
+
+/*
+Reads a chunk of audio data and converts it to signed 16-bit PCM samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+
+/* Low-level function for converting unsigned 8-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 24-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 32-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 32-bit floating point samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 64-bit floating point samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount);
+
+/* Low-level function for converting A-law samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting u-law samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+
+/*
+Reads a chunk of audio data and converts it to IEEE 32-bit floating point samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+
+/* Low-level function for converting unsigned 8-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 16-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 24-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 32-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 64-bit floating point samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount);
+
+/* Low-level function for converting A-law samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting u-law samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+
+/*
+Reads a chunk of audio data and converts it to signed 32-bit PCM samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut);
+
+/* Low-level function for converting unsigned 8-bit PCM samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 16-bit PCM samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 24-bit PCM samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 32-bit floating point samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 64-bit floating point samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount);
+
+/* Low-level function for converting A-law samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting u-law samples to signed 32-bit PCM samples. */
+DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+#endif /* DR_WAV_NO_CONVERSION_API */
+
+
+/* High-Level Convenience Helpers */
+
+#ifndef DR_WAV_NO_STDIO
+/*
+Helper for initializing a wave file for reading using stdio.
+
+This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav
+objects because the operating system may restrict the number of file handles an application can have open at
+any given time.
+*/
+DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/*
+Helper for initializing a wave file for writing using stdio.
+
+This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav
+objects because the operating system may restrict the number of file handles an application can have open at
+any given time.
+*/
+DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+#endif /* DR_WAV_NO_STDIO */
+
+/*
+Helper for initializing a loader from a pre-allocated memory buffer.
+
+This does not create a copy of the data. It is up to the application to ensure the buffer remains valid for
+the lifetime of the drwav object.
+
+The buffer should contain the contents of the entire wave file, not just the sample data.
+*/
+DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/*
+Helper for initializing a writer which outputs data to a memory buffer.
+
+dr_wav will manage the memory allocations, however it is up to the caller to free the data with drwav_free().
+
+The buffer will remain allocated even after drwav_uninit() is called. The buffer should not be considered valid
+until after drwav_uninit() has been called.
+*/
+DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+
+#ifndef DR_WAV_NO_CONVERSION_API
+/*
+Opens and reads an entire wav file in a single operation.
+
+The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer.
+*/
+DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+#ifndef DR_WAV_NO_STDIO
+/*
+Opens and decodes an entire wav file in a single operation.
+
+The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer.
+*/
+DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+#endif
+/*
+Opens and decodes an entire wav file from a block of memory in a single operation.
+
+The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer.
+*/
+DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks);
+#endif
+
+/* Frees data that was allocated internally by dr_wav. */
+DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/* Converts bytes from a wav stream to a sized type of native endian. */
+DRWAV_API drwav_uint16 drwav_bytes_to_u16(const drwav_uint8* data);
+DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data);
+DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data);
+DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data);
+DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data);
+DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data);
+
+/* Compares a GUID for the purpose of checking the type of a Wave64 chunk. */
+DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]);
+
+/* Compares a four-character-code for the purpose of checking the type of a RIFF chunk. */
+DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b);
+
+#ifdef __cplusplus
+}
+#endif
+#endif /* dr_wav_h */
+
+
+/************************************************************************************************************************************************************
+ ************************************************************************************************************************************************************
+
+ IMPLEMENTATION
+
+ ************************************************************************************************************************************************************
+ ************************************************************************************************************************************************************/
+#if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION)
+#ifndef dr_wav_c
+#define dr_wav_c
+
+#include <stdlib.h>
+#include <string.h> /* For memcpy(), memset() */
+#include <limits.h> /* For INT_MAX */
+
+#ifndef DR_WAV_NO_STDIO
+#include <stdio.h>
+#include <wchar.h>
+#endif
+
+/* Standard library stuff. */
+#ifndef DRWAV_ASSERT
+#include <assert.h>
+#define DRWAV_ASSERT(expression) assert(expression)
+#endif
+#ifndef DRWAV_MALLOC
+#define DRWAV_MALLOC(sz) malloc((sz))
+#endif
+#ifndef DRWAV_REALLOC
+#define DRWAV_REALLOC(p, sz) realloc((p), (sz))
+#endif
+#ifndef DRWAV_FREE
+#define DRWAV_FREE(p) free((p))
+#endif
+#ifndef DRWAV_COPY_MEMORY
+#define DRWAV_COPY_MEMORY(dst, src, sz) memcpy((dst), (src), (sz))
+#endif
+#ifndef DRWAV_ZERO_MEMORY
+#define DRWAV_ZERO_MEMORY(p, sz) memset((p), 0, (sz))
+#endif
+#ifndef DRWAV_ZERO_OBJECT
+#define DRWAV_ZERO_OBJECT(p) DRWAV_ZERO_MEMORY((p), sizeof(*p))
+#endif
+
+#define drwav_countof(x) (sizeof(x) / sizeof(x[0]))
+#define drwav_align(x, a) ((((x) + (a) - 1) / (a)) * (a))
+#define drwav_min(a, b) (((a) < (b)) ? (a) : (b))
+#define drwav_max(a, b) (((a) > (b)) ? (a) : (b))
+#define drwav_clamp(x, lo, hi) (drwav_max((lo), drwav_min((hi), (x))))
+
+#define DRWAV_MAX_SIMD_VECTOR_SIZE 64 /* 64 for AVX-512 in the future. */
+
+/* CPU architecture. */
+#if defined(__x86_64__) || defined(_M_X64)
+ #define DRWAV_X64
+#elif defined(__i386) || defined(_M_IX86)
+ #define DRWAV_X86
+#elif defined(__arm__) || defined(_M_ARM)
+ #define DRWAV_ARM
+#endif
+
+#ifdef _MSC_VER
+ #define DRWAV_INLINE __forceinline
+#elif defined(__GNUC__)
+ /*
+ I've had a bug report where GCC is emitting warnings about functions possibly not being inlineable. This warning happens when
+ the __attribute__((always_inline)) attribute is defined without an "inline" statement. I think therefore there must be some
+ case where "__inline__" is not always defined, thus the compiler emitting these warnings. When using -std=c89 or -ansi on the
+ command line, we cannot use the "inline" keyword and instead need to use "__inline__". In an attempt to work around this issue
+ I am using "__inline__" only when we're compiling in strict ANSI mode.
+ */
+ #if defined(__STRICT_ANSI__)
+ #define DRWAV_INLINE __inline__ __attribute__((always_inline))
+ #else
+ #define DRWAV_INLINE inline __attribute__((always_inline))
+ #endif
+#elif defined(__WATCOMC__)
+ #define DRWAV_INLINE __inline
+#else
+ #define DRWAV_INLINE
+#endif
+
+#if defined(SIZE_MAX)
+ #define DRWAV_SIZE_MAX SIZE_MAX
+#else
+ #if defined(_WIN64) || defined(_LP64) || defined(__LP64__)
+ #define DRWAV_SIZE_MAX ((drwav_uint64)0xFFFFFFFFFFFFFFFF)
+ #else
+ #define DRWAV_SIZE_MAX 0xFFFFFFFF
+ #endif
+#endif
+
+#if defined(_MSC_VER) && _MSC_VER >= 1400
+ #define DRWAV_HAS_BYTESWAP16_INTRINSIC
+ #define DRWAV_HAS_BYTESWAP32_INTRINSIC
+ #define DRWAV_HAS_BYTESWAP64_INTRINSIC
+#elif defined(__clang__)
+ #if defined(__has_builtin)
+ #if __has_builtin(__builtin_bswap16)
+ #define DRWAV_HAS_BYTESWAP16_INTRINSIC
+ #endif
+ #if __has_builtin(__builtin_bswap32)
+ #define DRWAV_HAS_BYTESWAP32_INTRINSIC
+ #endif
+ #if __has_builtin(__builtin_bswap64)
+ #define DRWAV_HAS_BYTESWAP64_INTRINSIC
+ #endif
+ #endif
+#elif defined(__GNUC__)
+ #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3))
+ #define DRWAV_HAS_BYTESWAP32_INTRINSIC
+ #define DRWAV_HAS_BYTESWAP64_INTRINSIC
+ #endif
+ #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))
+ #define DRWAV_HAS_BYTESWAP16_INTRINSIC
+ #endif
+#endif
+
+DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision)
+{
+ if (pMajor) {
+ *pMajor = DRWAV_VERSION_MAJOR;
+ }
+
+ if (pMinor) {
+ *pMinor = DRWAV_VERSION_MINOR;
+ }
+
+ if (pRevision) {
+ *pRevision = DRWAV_VERSION_REVISION;
+ }
+}
+
+DRWAV_API const char* drwav_version_string(void)
+{
+ return DRWAV_VERSION_STRING;
+}
+
+/*
+These limits are used for basic validation when initializing the decoder. If you exceed these limits, first of all: what on Earth are
+you doing?! (Let me know, I'd be curious!) Second, you can adjust these by #define-ing them before the dr_wav implementation.
+*/
+#ifndef DRWAV_MAX_SAMPLE_RATE
+#define DRWAV_MAX_SAMPLE_RATE 384000
+#endif
+#ifndef DRWAV_MAX_CHANNELS
+#define DRWAV_MAX_CHANNELS 256
+#endif
+#ifndef DRWAV_MAX_BITS_PER_SAMPLE
+#define DRWAV_MAX_BITS_PER_SAMPLE 64
+#endif
+
+static const drwav_uint8 drwavGUID_W64_RIFF[16] = {0x72,0x69,0x66,0x66, 0x2E,0x91, 0xCF,0x11, 0xA5,0xD6, 0x28,0xDB,0x04,0xC1,0x00,0x00}; /* 66666972-912E-11CF-A5D6-28DB04C10000 */
+static const drwav_uint8 drwavGUID_W64_WAVE[16] = {0x77,0x61,0x76,0x65, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 65766177-ACF3-11D3-8CD1-00C04F8EDB8A */
+/*static const drwav_uint8 drwavGUID_W64_JUNK[16] = {0x6A,0x75,0x6E,0x6B, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A};*/ /* 6B6E756A-ACF3-11D3-8CD1-00C04F8EDB8A */
+static const drwav_uint8 drwavGUID_W64_FMT [16] = {0x66,0x6D,0x74,0x20, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 20746D66-ACF3-11D3-8CD1-00C04F8EDB8A */
+static const drwav_uint8 drwavGUID_W64_FACT[16] = {0x66,0x61,0x63,0x74, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 74636166-ACF3-11D3-8CD1-00C04F8EDB8A */
+static const drwav_uint8 drwavGUID_W64_DATA[16] = {0x64,0x61,0x74,0x61, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 61746164-ACF3-11D3-8CD1-00C04F8EDB8A */
+static const drwav_uint8 drwavGUID_W64_SMPL[16] = {0x73,0x6D,0x70,0x6C, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 6C706D73-ACF3-11D3-8CD1-00C04F8EDB8A */
+
+static DRWAV_INLINE drwav_bool32 drwav__guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16])
+{
+ int i;
+ for (i = 0; i < 16; i += 1) {
+ if (a[i] != b[i]) {
+ return DRWAV_FALSE;
+ }
+ }
+
+ return DRWAV_TRUE;
+}
+
+static DRWAV_INLINE drwav_bool32 drwav__fourcc_equal(const drwav_uint8* a, const char* b)
+{
+ return
+ a[0] == b[0] &&
+ a[1] == b[1] &&
+ a[2] == b[2] &&
+ a[3] == b[3];
+}
+
+
+
+static DRWAV_INLINE int drwav__is_little_endian(void)
+{
+#if defined(DRWAV_X86) || defined(DRWAV_X64)
+ return DRWAV_TRUE;
+#elif defined(__BYTE_ORDER) && defined(__LITTLE_ENDIAN) && __BYTE_ORDER == __LITTLE_ENDIAN
+ return DRWAV_TRUE;
+#else
+ int n = 1;
+ return (*(char*)&n) == 1;
+#endif
+}
+
+static DRWAV_INLINE drwav_uint16 drwav__bytes_to_u16(const drwav_uint8* data)
+{
+ return (data[0] << 0) | (data[1] << 8);
+}
+
+static DRWAV_INLINE drwav_int16 drwav__bytes_to_s16(const drwav_uint8* data)
+{
+ return (short)drwav__bytes_to_u16(data);
+}
+
+static DRWAV_INLINE drwav_uint32 drwav__bytes_to_u32(const drwav_uint8* data)
+{
+ return (data[0] << 0) | (data[1] << 8) | (data[2] << 16) | (data[3] << 24);
+}
+
+static DRWAV_INLINE drwav_int32 drwav__bytes_to_s32(const drwav_uint8* data)
+{
+ return (drwav_int32)drwav__bytes_to_u32(data);
+}
+
+static DRWAV_INLINE drwav_uint64 drwav__bytes_to_u64(const drwav_uint8* data)
+{
+ return
+ ((drwav_uint64)data[0] << 0) | ((drwav_uint64)data[1] << 8) | ((drwav_uint64)data[2] << 16) | ((drwav_uint64)data[3] << 24) |
+ ((drwav_uint64)data[4] << 32) | ((drwav_uint64)data[5] << 40) | ((drwav_uint64)data[6] << 48) | ((drwav_uint64)data[7] << 56);
+}
+
+static DRWAV_INLINE drwav_int64 drwav__bytes_to_s64(const drwav_uint8* data)
+{
+ return (drwav_int64)drwav__bytes_to_u64(data);
+}
+
+static DRWAV_INLINE void drwav__bytes_to_guid(const drwav_uint8* data, drwav_uint8* guid)
+{
+ int i;
+ for (i = 0; i < 16; ++i) {
+ guid[i] = data[i];
+ }
+}
+
+
+static DRWAV_INLINE drwav_uint16 drwav__bswap16(drwav_uint16 n)
+{
+#ifdef DRWAV_HAS_BYTESWAP16_INTRINSIC
+ #if defined(_MSC_VER)
+ return _byteswap_ushort(n);
+ #elif defined(__GNUC__) || defined(__clang__)
+ return __builtin_bswap16(n);
+ #else
+ #error "This compiler does not support the byte swap intrinsic."
+ #endif
+#else
+ return ((n & 0xFF00) >> 8) |
+ ((n & 0x00FF) << 8);
+#endif
+}
+
+static DRWAV_INLINE drwav_uint32 drwav__bswap32(drwav_uint32 n)
+{
+#ifdef DRWAV_HAS_BYTESWAP32_INTRINSIC
+ #if defined(_MSC_VER)
+ return _byteswap_ulong(n);
+ #elif defined(__GNUC__) || defined(__clang__)
+ #if defined(DRWAV_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 6) && !defined(DRWAV_64BIT) /* <-- 64-bit inline assembly has not been tested, so disabling for now. */
+ /* Inline assembly optimized implementation for ARM. In my testing, GCC does not generate optimized code with __builtin_bswap32(). */
+ drwav_uint32 r;
+ __asm__ __volatile__ (
+ #if defined(DRWAV_64BIT)
+ "rev %w[out], %w[in]" : [out]"=r"(r) : [in]"r"(n) /* <-- This is untested. If someone in the community could test this, that would be appreciated! */
+ #else
+ "rev %[out], %[in]" : [out]"=r"(r) : [in]"r"(n)
+ #endif
+ );
+ return r;
+ #else
+ return __builtin_bswap32(n);
+ #endif
+ #else
+ #error "This compiler does not support the byte swap intrinsic."
+ #endif
+#else
+ return ((n & 0xFF000000) >> 24) |
+ ((n & 0x00FF0000) >> 8) |
+ ((n & 0x0000FF00) << 8) |
+ ((n & 0x000000FF) << 24);
+#endif
+}
+
+static DRWAV_INLINE drwav_uint64 drwav__bswap64(drwav_uint64 n)
+{
+#ifdef DRWAV_HAS_BYTESWAP64_INTRINSIC
+ #if defined(_MSC_VER)
+ return _byteswap_uint64(n);
+ #elif defined(__GNUC__) || defined(__clang__)
+ return __builtin_bswap64(n);
+ #else
+ #error "This compiler does not support the byte swap intrinsic."
+ #endif
+#else
+ /* Weird "<< 32" bitshift is required for C89 because it doesn't support 64-bit constants. Should be optimized out by a good compiler. */
+ return ((n & ((drwav_uint64)0xFF000000 << 32)) >> 56) |
+ ((n & ((drwav_uint64)0x00FF0000 << 32)) >> 40) |
+ ((n & ((drwav_uint64)0x0000FF00 << 32)) >> 24) |
+ ((n & ((drwav_uint64)0x000000FF << 32)) >> 8) |
+ ((n & ((drwav_uint64)0xFF000000 )) << 8) |
+ ((n & ((drwav_uint64)0x00FF0000 )) << 24) |
+ ((n & ((drwav_uint64)0x0000FF00 )) << 40) |
+ ((n & ((drwav_uint64)0x000000FF )) << 56);
+#endif
+}
+
+
+static DRWAV_INLINE drwav_int16 drwav__bswap_s16(drwav_int16 n)
+{
+ return (drwav_int16)drwav__bswap16((drwav_uint16)n);
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_s16(drwav_int16* pSamples, drwav_uint64 sampleCount)
+{
+ drwav_uint64 iSample;
+ for (iSample = 0; iSample < sampleCount; iSample += 1) {
+ pSamples[iSample] = drwav__bswap_s16(pSamples[iSample]);
+ }
+}
+
+
+static DRWAV_INLINE void drwav__bswap_s24(drwav_uint8* p)
+{
+ drwav_uint8 t;
+ t = p[0];
+ p[0] = p[2];
+ p[2] = t;
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_s24(drwav_uint8* pSamples, drwav_uint64 sampleCount)
+{
+ drwav_uint64 iSample;
+ for (iSample = 0; iSample < sampleCount; iSample += 1) {
+ drwav_uint8* pSample = pSamples + (iSample*3);
+ drwav__bswap_s24(pSample);
+ }
+}
+
+
+static DRWAV_INLINE drwav_int32 drwav__bswap_s32(drwav_int32 n)
+{
+ return (drwav_int32)drwav__bswap32((drwav_uint32)n);
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_s32(drwav_int32* pSamples, drwav_uint64 sampleCount)
+{
+ drwav_uint64 iSample;
+ for (iSample = 0; iSample < sampleCount; iSample += 1) {
+ pSamples[iSample] = drwav__bswap_s32(pSamples[iSample]);
+ }
+}
+
+
+static DRWAV_INLINE float drwav__bswap_f32(float n)
+{
+ union {
+ drwav_uint32 i;
+ float f;
+ } x;
+ x.f = n;
+ x.i = drwav__bswap32(x.i);
+
+ return x.f;
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_f32(float* pSamples, drwav_uint64 sampleCount)
+{
+ drwav_uint64 iSample;
+ for (iSample = 0; iSample < sampleCount; iSample += 1) {
+ pSamples[iSample] = drwav__bswap_f32(pSamples[iSample]);
+ }
+}
+
+
+static DRWAV_INLINE double drwav__bswap_f64(double n)
+{
+ union {
+ drwav_uint64 i;
+ double f;
+ } x;
+ x.f = n;
+ x.i = drwav__bswap64(x.i);
+
+ return x.f;
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_f64(double* pSamples, drwav_uint64 sampleCount)
+{
+ drwav_uint64 iSample;
+ for (iSample = 0; iSample < sampleCount; iSample += 1) {
+ pSamples[iSample] = drwav__bswap_f64(pSamples[iSample]);
+ }
+}
+
+
+static DRWAV_INLINE void drwav__bswap_samples_pcm(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample)
+{
+ /* Assumes integer PCM. Floating point PCM is done in drwav__bswap_samples_ieee(). */
+ switch (bytesPerSample)
+ {
+ case 2: /* s16, s12 (loosely packed) */
+ {
+ drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount);
+ } break;
+ case 3: /* s24 */
+ {
+ drwav__bswap_samples_s24((drwav_uint8*)pSamples, sampleCount);
+ } break;
+ case 4: /* s32 */
+ {
+ drwav__bswap_samples_s32((drwav_int32*)pSamples, sampleCount);
+ } break;
+ default:
+ {
+ /* Unsupported format. */
+ DRWAV_ASSERT(DRWAV_FALSE);
+ } break;
+ }
+}
+
+static DRWAV_INLINE void drwav__bswap_samples_ieee(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample)
+{
+ switch (bytesPerSample)
+ {
+ #if 0 /* Contributions welcome for f16 support. */
+ case 2: /* f16 */
+ {
+ drwav__bswap_samples_f16((drwav_float16*)pSamples, sampleCount);
+ } break;
+ #endif
+ case 4: /* f32 */
+ {
+ drwav__bswap_samples_f32((float*)pSamples, sampleCount);
+ } break;
+ case 8: /* f64 */
+ {
+ drwav__bswap_samples_f64((double*)pSamples, sampleCount);
+ } break;
+ default:
+ {
+ /* Unsupported format. */
+ DRWAV_ASSERT(DRWAV_FALSE);
+ } break;
+ }
+}
+
+static DRWAV_INLINE void drwav__bswap_samples(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample, drwav_uint16 format)
+{
+ switch (format)
+ {
+ case DR_WAVE_FORMAT_PCM:
+ {
+ drwav__bswap_samples_pcm(pSamples, sampleCount, bytesPerSample);
+ } break;
+
+ case DR_WAVE_FORMAT_IEEE_FLOAT:
+ {
+ drwav__bswap_samples_ieee(pSamples, sampleCount, bytesPerSample);
+ } break;
+
+ case DR_WAVE_FORMAT_ALAW:
+ case DR_WAVE_FORMAT_MULAW:
+ {
+ drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount);
+ } break;
+
+ case DR_WAVE_FORMAT_ADPCM:
+ case DR_WAVE_FORMAT_DVI_ADPCM:
+ default:
+ {
+ /* Unsupported format. */
+ DRWAV_ASSERT(DRWAV_FALSE);
+ } break;
+ }
+}
+
+
+static void* drwav__malloc_default(size_t sz, void* pUserData)
+{
+ (void)pUserData;
+ return DRWAV_MALLOC(sz);
+}
+
+static void* drwav__realloc_default(void* p, size_t sz, void* pUserData)
+{
+ (void)pUserData;
+ return DRWAV_REALLOC(p, sz);
+}
+
+static void drwav__free_default(void* p, void* pUserData)
+{
+ (void)pUserData;
+ DRWAV_FREE(p);
+}
+
+
+static void* drwav__malloc_from_callbacks(size_t sz, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pAllocationCallbacks == NULL) {
+ return NULL;
+ }
+
+ if (pAllocationCallbacks->onMalloc != NULL) {
+ return pAllocationCallbacks->onMalloc(sz, pAllocationCallbacks->pUserData);
+ }
+
+ /* Try using realloc(). */
+ if (pAllocationCallbacks->onRealloc != NULL) {
+ return pAllocationCallbacks->onRealloc(NULL, sz, pAllocationCallbacks->pUserData);
+ }
+
+ return NULL;
+}
+
+static void* drwav__realloc_from_callbacks(void* p, size_t szNew, size_t szOld, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pAllocationCallbacks == NULL) {
+ return NULL;
+ }
+
+ if (pAllocationCallbacks->onRealloc != NULL) {
+ return pAllocationCallbacks->onRealloc(p, szNew, pAllocationCallbacks->pUserData);
+ }
+
+ /* Try emulating realloc() in terms of malloc()/free(). */
+ if (pAllocationCallbacks->onMalloc != NULL && pAllocationCallbacks->onFree != NULL) {
+ void* p2;
+
+ p2 = pAllocationCallbacks->onMalloc(szNew, pAllocationCallbacks->pUserData);
+ if (p2 == NULL) {
+ return NULL;
+ }
+
+ if (p != NULL) {
+ DRWAV_COPY_MEMORY(p2, p, szOld);
+ pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData);
+ }
+
+ return p2;
+ }
+
+ return NULL;
+}
+
+static void drwav__free_from_callbacks(void* p, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (p == NULL || pAllocationCallbacks == NULL) {
+ return;
+ }
+
+ if (pAllocationCallbacks->onFree != NULL) {
+ pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData);
+ }
+}
+
+
+static drwav_allocation_callbacks drwav_copy_allocation_callbacks_or_defaults(const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pAllocationCallbacks != NULL) {
+ /* Copy. */
+ return *pAllocationCallbacks;
+ } else {
+ /* Defaults. */
+ drwav_allocation_callbacks allocationCallbacks;
+ allocationCallbacks.pUserData = NULL;
+ allocationCallbacks.onMalloc = drwav__malloc_default;
+ allocationCallbacks.onRealloc = drwav__realloc_default;
+ allocationCallbacks.onFree = drwav__free_default;
+ return allocationCallbacks;
+ }
+}
+
+
+static DRWAV_INLINE drwav_bool32 drwav__is_compressed_format_tag(drwav_uint16 formatTag)
+{
+ return
+ formatTag == DR_WAVE_FORMAT_ADPCM ||
+ formatTag == DR_WAVE_FORMAT_DVI_ADPCM;
+}
+
+static unsigned int drwav__chunk_padding_size_riff(drwav_uint64 chunkSize)
+{
+ return (unsigned int)(chunkSize % 2);
+}
+
+static unsigned int drwav__chunk_padding_size_w64(drwav_uint64 chunkSize)
+{
+ return (unsigned int)(chunkSize % 8);
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut);
+static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut);
+static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount);
+
+static drwav_result drwav__read_chunk_header(drwav_read_proc onRead, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_chunk_header* pHeaderOut)
+{
+ if (container == drwav_container_riff || container == drwav_container_rf64) {
+ drwav_uint8 sizeInBytes[4];
+
+ if (onRead(pUserData, pHeaderOut->id.fourcc, 4) != 4) {
+ return DRWAV_AT_END;
+ }
+
+ if (onRead(pUserData, sizeInBytes, 4) != 4) {
+ return DRWAV_INVALID_FILE;
+ }
+
+ pHeaderOut->sizeInBytes = drwav__bytes_to_u32(sizeInBytes);
+ pHeaderOut->paddingSize = drwav__chunk_padding_size_riff(pHeaderOut->sizeInBytes);
+ *pRunningBytesReadOut += 8;
+ } else {
+ drwav_uint8 sizeInBytes[8];
+
+ if (onRead(pUserData, pHeaderOut->id.guid, 16) != 16) {
+ return DRWAV_AT_END;
+ }
+
+ if (onRead(pUserData, sizeInBytes, 8) != 8) {
+ return DRWAV_INVALID_FILE;
+ }
+
+ pHeaderOut->sizeInBytes = drwav__bytes_to_u64(sizeInBytes) - 24; /* <-- Subtract 24 because w64 includes the size of the header. */
+ pHeaderOut->paddingSize = drwav__chunk_padding_size_w64(pHeaderOut->sizeInBytes);
+ *pRunningBytesReadOut += 24;
+ }
+
+ return DRWAV_SUCCESS;
+}
+
+static drwav_bool32 drwav__seek_forward(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData)
+{
+ drwav_uint64 bytesRemainingToSeek = offset;
+ while (bytesRemainingToSeek > 0) {
+ if (bytesRemainingToSeek > 0x7FFFFFFF) {
+ if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ bytesRemainingToSeek -= 0x7FFFFFFF;
+ } else {
+ if (!onSeek(pUserData, (int)bytesRemainingToSeek, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ bytesRemainingToSeek = 0;
+ }
+ }
+
+ return DRWAV_TRUE;
+}
+
+static drwav_bool32 drwav__seek_from_start(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData)
+{
+ if (offset <= 0x7FFFFFFF) {
+ return onSeek(pUserData, (int)offset, drwav_seek_origin_start);
+ }
+
+ /* Larger than 32-bit seek. */
+ if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_start)) {
+ return DRWAV_FALSE;
+ }
+ offset -= 0x7FFFFFFF;
+
+ for (;;) {
+ if (offset <= 0x7FFFFFFF) {
+ return onSeek(pUserData, (int)offset, drwav_seek_origin_current);
+ }
+
+ if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ offset -= 0x7FFFFFFF;
+ }
+
+ /* Should never get here. */
+ /*return DRWAV_TRUE; */
+}
+
+
+static drwav_bool32 drwav__read_fmt(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_fmt* fmtOut)
+{
+ drwav_chunk_header header;
+ drwav_uint8 fmt[16];
+
+ if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+
+ /* Skip non-fmt chunks. */
+ while (((container == drwav_container_riff || container == drwav_container_rf64) && !drwav__fourcc_equal(header.id.fourcc, "fmt ")) || (container == drwav_container_w64 && !drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT))) {
+ if (!drwav__seek_forward(onSeek, header.sizeInBytes + header.paddingSize, pUserData)) {
+ return DRWAV_FALSE;
+ }
+ *pRunningBytesReadOut += header.sizeInBytes + header.paddingSize;
+
+ /* Try the next header. */
+ if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+ }
+
+
+ /* Validation. */
+ if (container == drwav_container_riff || container == drwav_container_rf64) {
+ if (!drwav__fourcc_equal(header.id.fourcc, "fmt ")) {
+ return DRWAV_FALSE;
+ }
+ } else {
+ if (!drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT)) {
+ return DRWAV_FALSE;
+ }
+ }
+
+
+ if (onRead(pUserData, fmt, sizeof(fmt)) != sizeof(fmt)) {
+ return DRWAV_FALSE;
+ }
+ *pRunningBytesReadOut += sizeof(fmt);
+
+ fmtOut->formatTag = drwav__bytes_to_u16(fmt + 0);
+ fmtOut->channels = drwav__bytes_to_u16(fmt + 2);
+ fmtOut->sampleRate = drwav__bytes_to_u32(fmt + 4);
+ fmtOut->avgBytesPerSec = drwav__bytes_to_u32(fmt + 8);
+ fmtOut->blockAlign = drwav__bytes_to_u16(fmt + 12);
+ fmtOut->bitsPerSample = drwav__bytes_to_u16(fmt + 14);
+
+ fmtOut->extendedSize = 0;
+ fmtOut->validBitsPerSample = 0;
+ fmtOut->channelMask = 0;
+ memset(fmtOut->subFormat, 0, sizeof(fmtOut->subFormat));
+
+ if (header.sizeInBytes > 16) {
+ drwav_uint8 fmt_cbSize[2];
+ int bytesReadSoFar = 0;
+
+ if (onRead(pUserData, fmt_cbSize, sizeof(fmt_cbSize)) != sizeof(fmt_cbSize)) {
+ return DRWAV_FALSE; /* Expecting more data. */
+ }
+ *pRunningBytesReadOut += sizeof(fmt_cbSize);
+
+ bytesReadSoFar = 18;
+
+ fmtOut->extendedSize = drwav__bytes_to_u16(fmt_cbSize);
+ if (fmtOut->extendedSize > 0) {
+ /* Simple validation. */
+ if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) {
+ if (fmtOut->extendedSize != 22) {
+ return DRWAV_FALSE;
+ }
+ }
+
+ if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) {
+ drwav_uint8 fmtext[22];
+ if (onRead(pUserData, fmtext, fmtOut->extendedSize) != fmtOut->extendedSize) {
+ return DRWAV_FALSE; /* Expecting more data. */
+ }
+
+ fmtOut->validBitsPerSample = drwav__bytes_to_u16(fmtext + 0);
+ fmtOut->channelMask = drwav__bytes_to_u32(fmtext + 2);
+ drwav__bytes_to_guid(fmtext + 6, fmtOut->subFormat);
+ } else {
+ if (!onSeek(pUserData, fmtOut->extendedSize, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ }
+ *pRunningBytesReadOut += fmtOut->extendedSize;
+
+ bytesReadSoFar += fmtOut->extendedSize;
+ }
+
+ /* Seek past any leftover bytes. For w64 the leftover will be defined based on the chunk size. */
+ if (!onSeek(pUserData, (int)(header.sizeInBytes - bytesReadSoFar), drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ *pRunningBytesReadOut += (header.sizeInBytes - bytesReadSoFar);
+ }
+
+ if (header.paddingSize > 0) {
+ if (!onSeek(pUserData, header.paddingSize, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+ *pRunningBytesReadOut += header.paddingSize;
+ }
+
+ return DRWAV_TRUE;
+}
+
+
+static size_t drwav__on_read(drwav_read_proc onRead, void* pUserData, void* pBufferOut, size_t bytesToRead, drwav_uint64* pCursor)
+{
+ size_t bytesRead;
+
+ DRWAV_ASSERT(onRead != NULL);
+ DRWAV_ASSERT(pCursor != NULL);
+
+ bytesRead = onRead(pUserData, pBufferOut, bytesToRead);
+ *pCursor += bytesRead;
+ return bytesRead;
+}
+
+#if 0
+static drwav_bool32 drwav__on_seek(drwav_seek_proc onSeek, void* pUserData, int offset, drwav_seek_origin origin, drwav_uint64* pCursor)
+{
+ DRWAV_ASSERT(onSeek != NULL);
+ DRWAV_ASSERT(pCursor != NULL);
+
+ if (!onSeek(pUserData, offset, origin)) {
+ return DRWAV_FALSE;
+ }
+
+ if (origin == drwav_seek_origin_start) {
+ *pCursor = offset;
+ } else {
+ *pCursor += offset;
+ }
+
+ return DRWAV_TRUE;
+}
+#endif
+
+
+
+static drwav_uint32 drwav_get_bytes_per_pcm_frame(drwav* pWav)
+{
+ /*
+ The bytes per frame is a bit ambiguous. It can be either be based on the bits per sample, or the block align. The way I'm doing it here
+ is that if the bits per sample is a multiple of 8, use floor(bitsPerSample*channels/8), otherwise fall back to the block align.
+ */
+ if ((pWav->bitsPerSample & 0x7) == 0) {
+ /* Bits per sample is a multiple of 8. */
+ return (pWav->bitsPerSample * pWav->fmt.channels) >> 3;
+ } else {
+ return pWav->fmt.blockAlign;
+ }
+}
+
+DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT)
+{
+ if (pFMT == NULL) {
+ return 0;
+ }
+
+ if (pFMT->formatTag != DR_WAVE_FORMAT_EXTENSIBLE) {
+ return pFMT->formatTag;
+ } else {
+ return drwav__bytes_to_u16(pFMT->subFormat); /* Only the first two bytes are required. */
+ }
+}
+
+static drwav_bool32 drwav_preinit(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pWav == NULL || onRead == NULL || onSeek == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav));
+ pWav->onRead = onRead;
+ pWav->onSeek = onSeek;
+ pWav->pUserData = pReadSeekUserData;
+ pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks);
+
+ if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) {
+ return DRWAV_FALSE; /* Invalid allocation callbacks. */
+ }
+
+ return DRWAV_TRUE;
+}
+
+static drwav_bool32 drwav_init__internal(drwav* pWav, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags)
+{
+ /* This function assumes drwav_preinit() has been called beforehand. */
+
+ drwav_uint64 cursor; /* <-- Keeps track of the byte position so we can seek to specific locations. */
+ drwav_bool32 sequential;
+ drwav_uint8 riff[4];
+ drwav_fmt fmt;
+ unsigned short translatedFormatTag;
+ drwav_bool32 foundDataChunk;
+ drwav_uint64 dataChunkSize = 0; /* <-- Important! Don't explicitly set this to 0 anywhere else. Calculation of the size of the data chunk is performed in different paths depending on the container. */
+ drwav_uint64 sampleCountFromFactChunk = 0; /* Same as dataChunkSize - make sure this is the only place this is initialized to 0. */
+ drwav_uint64 chunkSize;
+
+ cursor = 0;
+ sequential = (flags & DRWAV_SEQUENTIAL) != 0;
+
+ /* The first 4 bytes should be the RIFF identifier. */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, riff, sizeof(riff), &cursor) != sizeof(riff)) {
+ return DRWAV_FALSE;
+ }
+
+ /*
+ The first 4 bytes can be used to identify the container. For RIFF files it will start with "RIFF" and for
+ w64 it will start with "riff".
+ */
+ if (drwav__fourcc_equal(riff, "RIFF")) {
+ pWav->container = drwav_container_riff;
+ } else if (drwav__fourcc_equal(riff, "riff")) {
+ int i;
+ drwav_uint8 riff2[12];
+
+ pWav->container = drwav_container_w64;
+
+ /* Check the rest of the GUID for validity. */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, riff2, sizeof(riff2), &cursor) != sizeof(riff2)) {
+ return DRWAV_FALSE;
+ }
+
+ for (i = 0; i < 12; ++i) {
+ if (riff2[i] != drwavGUID_W64_RIFF[i+4]) {
+ return DRWAV_FALSE;
+ }
+ }
+ } else if (drwav__fourcc_equal(riff, "RF64")) {
+ pWav->container = drwav_container_rf64;
+ } else {
+ return DRWAV_FALSE; /* Unknown or unsupported container. */
+ }
+
+
+ if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) {
+ drwav_uint8 chunkSizeBytes[4];
+ drwav_uint8 wave[4];
+
+ /* RIFF/WAVE */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) {
+ return DRWAV_FALSE;
+ }
+
+ if (pWav->container == drwav_container_riff) {
+ if (drwav__bytes_to_u32(chunkSizeBytes) < 36) {
+ return DRWAV_FALSE; /* Chunk size should always be at least 36 bytes. */
+ }
+ } else {
+ if (drwav__bytes_to_u32(chunkSizeBytes) != 0xFFFFFFFF) {
+ return DRWAV_FALSE; /* Chunk size should always be set to -1/0xFFFFFFFF for RF64. The actual size is retrieved later. */
+ }
+ }
+
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) {
+ return DRWAV_FALSE;
+ }
+
+ if (!drwav__fourcc_equal(wave, "WAVE")) {
+ return DRWAV_FALSE; /* Expecting "WAVE". */
+ }
+ } else {
+ drwav_uint8 chunkSizeBytes[8];
+ drwav_uint8 wave[16];
+
+ /* W64 */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) {
+ return DRWAV_FALSE;
+ }
+
+ if (drwav__bytes_to_u64(chunkSizeBytes) < 80) {
+ return DRWAV_FALSE;
+ }
+
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) {
+ return DRWAV_FALSE;
+ }
+
+ if (!drwav__guid_equal(wave, drwavGUID_W64_WAVE)) {
+ return DRWAV_FALSE;
+ }
+ }
+
+
+ /* For RF64, the "ds64" chunk must come next, before the "fmt " chunk. */
+ if (pWav->container == drwav_container_rf64) {
+ drwav_uint8 sizeBytes[8];
+ drwav_uint64 bytesRemainingInChunk;
+ drwav_chunk_header header;
+ drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header);
+ if (result != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+ if (!drwav__fourcc_equal(header.id.fourcc, "ds64")) {
+ return DRWAV_FALSE; /* Expecting "ds64". */
+ }
+
+ bytesRemainingInChunk = header.sizeInBytes + header.paddingSize;
+
+ /* We don't care about the size of the RIFF chunk - skip it. */
+ if (!drwav__seek_forward(pWav->onSeek, 8, pWav->pUserData)) {
+ return DRWAV_FALSE;
+ }
+ bytesRemainingInChunk -= 8;
+ cursor += 8;
+
+
+ /* Next 8 bytes is the size of the "data" chunk. */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) {
+ return DRWAV_FALSE;
+ }
+ bytesRemainingInChunk -= 8;
+ dataChunkSize = drwav__bytes_to_u64(sizeBytes);
+
+
+ /* Next 8 bytes is the same count which we would usually derived from the FACT chunk if it was available. */
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) {
+ return DRWAV_FALSE;
+ }
+ bytesRemainingInChunk -= 8;
+ sampleCountFromFactChunk = drwav__bytes_to_u64(sizeBytes);
+
+
+ /* Skip over everything else. */
+ if (!drwav__seek_forward(pWav->onSeek, bytesRemainingInChunk, pWav->pUserData)) {
+ return DRWAV_FALSE;
+ }
+ cursor += bytesRemainingInChunk;
+ }
+
+
+ /* The next bytes should be the "fmt " chunk. */
+ if (!drwav__read_fmt(pWav->onRead, pWav->onSeek, pWav->pUserData, pWav->container, &cursor, &fmt)) {
+ return DRWAV_FALSE; /* Failed to read the "fmt " chunk. */
+ }
+
+ /* Basic validation. */
+ if ((fmt.sampleRate == 0 || fmt.sampleRate > DRWAV_MAX_SAMPLE_RATE) ||
+ (fmt.channels == 0 || fmt.channels > DRWAV_MAX_CHANNELS) ||
+ (fmt.bitsPerSample == 0 || fmt.bitsPerSample > DRWAV_MAX_BITS_PER_SAMPLE) ||
+ fmt.blockAlign == 0) {
+ return DRWAV_FALSE; /* Probably an invalid WAV file. */
+ }
+
+
+ /* Translate the internal format. */
+ translatedFormatTag = fmt.formatTag;
+ if (translatedFormatTag == DR_WAVE_FORMAT_EXTENSIBLE) {
+ translatedFormatTag = drwav__bytes_to_u16(fmt.subFormat + 0);
+ }
+
+
+ /*
+ We need to enumerate over each chunk for two reasons:
+ 1) The "data" chunk may not be the next one
+ 2) We may want to report each chunk back to the client
+
+ In order to correctly report each chunk back to the client we will need to keep looping until the end of the file.
+ */
+ foundDataChunk = DRWAV_FALSE;
+
+ /* The next chunk we care about is the "data" chunk. This is not necessarily the next chunk so we'll need to loop. */
+ for (;;)
+ {
+ drwav_chunk_header header;
+ drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header);
+ if (result != DRWAV_SUCCESS) {
+ if (!foundDataChunk) {
+ return DRWAV_FALSE;
+ } else {
+ break; /* Probably at the end of the file. Get out of the loop. */
+ }
+ }
+
+ /* Tell the client about this chunk. */
+ if (!sequential && onChunk != NULL) {
+ drwav_uint64 callbackBytesRead = onChunk(pChunkUserData, pWav->onRead, pWav->onSeek, pWav->pUserData, &header, pWav->container, &fmt);
+
+ /*
+ dr_wav may need to read the contents of the chunk, so we now need to seek back to the position before
+ we called the callback.
+ */
+ if (callbackBytesRead > 0) {
+ if (!drwav__seek_from_start(pWav->onSeek, cursor, pWav->pUserData)) {
+ return DRWAV_FALSE;
+ }
+ }
+ }
+
+
+ if (!foundDataChunk) {
+ pWav->dataChunkDataPos = cursor;
+ }
+
+ chunkSize = header.sizeInBytes;
+ if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) {
+ if (drwav__fourcc_equal(header.id.fourcc, "data")) {
+ foundDataChunk = DRWAV_TRUE;
+ if (pWav->container != drwav_container_rf64) { /* The data chunk size for RF64 will always be set to 0xFFFFFFFF here. It was set to it's true value earlier. */
+ dataChunkSize = chunkSize;
+ }
+ }
+ } else {
+ if (drwav__guid_equal(header.id.guid, drwavGUID_W64_DATA)) {
+ foundDataChunk = DRWAV_TRUE;
+ dataChunkSize = chunkSize;
+ }
+ }
+
+ /*
+ If at this point we have found the data chunk and we're running in sequential mode, we need to break out of this loop. The reason for
+ this is that we would otherwise require a backwards seek which sequential mode forbids.
+ */
+ if (foundDataChunk && sequential) {
+ break;
+ }
+
+ /* Optional. Get the total sample count from the FACT chunk. This is useful for compressed formats. */
+ if (pWav->container == drwav_container_riff) {
+ if (drwav__fourcc_equal(header.id.fourcc, "fact")) {
+ drwav_uint32 sampleCount;
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCount, 4, &cursor) != 4) {
+ return DRWAV_FALSE;
+ }
+ chunkSize -= 4;
+
+ if (!foundDataChunk) {
+ pWav->dataChunkDataPos = cursor;
+ }
+
+ /*
+ The sample count in the "fact" chunk is either unreliable, or I'm not understanding it properly. For now I am only enabling this
+ for Microsoft ADPCM formats.
+ */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ sampleCountFromFactChunk = sampleCount;
+ } else {
+ sampleCountFromFactChunk = 0;
+ }
+ }
+ } else if (pWav->container == drwav_container_w64) {
+ if (drwav__guid_equal(header.id.guid, drwavGUID_W64_FACT)) {
+ if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCountFromFactChunk, 8, &cursor) != 8) {
+ return DRWAV_FALSE;
+ }
+ chunkSize -= 8;
+
+ if (!foundDataChunk) {
+ pWav->dataChunkDataPos = cursor;
+ }
+ }
+ } else if (pWav->container == drwav_container_rf64) {
+ /* We retrieved the sample count from the ds64 chunk earlier so no need to do that here. */
+ }
+
+ /* "smpl" chunk. */
+ if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) {
+ if (drwav__fourcc_equal(header.id.fourcc, "smpl")) {
+ drwav_uint8 smplHeaderData[36]; /* 36 = size of the smpl header section, not including the loop data. */
+ if (chunkSize >= sizeof(smplHeaderData)) {
+ drwav_uint64 bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplHeaderData, sizeof(smplHeaderData), &cursor);
+ chunkSize -= bytesJustRead;
+
+ if (bytesJustRead == sizeof(smplHeaderData)) {
+ drwav_uint32 iLoop;
+
+ pWav->smpl.manufacturer = drwav__bytes_to_u32(smplHeaderData+0);
+ pWav->smpl.product = drwav__bytes_to_u32(smplHeaderData+4);
+ pWav->smpl.samplePeriod = drwav__bytes_to_u32(smplHeaderData+8);
+ pWav->smpl.midiUnityNotes = drwav__bytes_to_u32(smplHeaderData+12);
+ pWav->smpl.midiPitchFraction = drwav__bytes_to_u32(smplHeaderData+16);
+ pWav->smpl.smpteFormat = drwav__bytes_to_u32(smplHeaderData+20);
+ pWav->smpl.smpteOffset = drwav__bytes_to_u32(smplHeaderData+24);
+ pWav->smpl.numSampleLoops = drwav__bytes_to_u32(smplHeaderData+28);
+ pWav->smpl.samplerData = drwav__bytes_to_u32(smplHeaderData+32);
+
+ for (iLoop = 0; iLoop < pWav->smpl.numSampleLoops && iLoop < drwav_countof(pWav->smpl.loops); ++iLoop) {
+ drwav_uint8 smplLoopData[24]; /* 24 = size of a loop section in the smpl chunk. */
+ bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplLoopData, sizeof(smplLoopData), &cursor);
+ chunkSize -= bytesJustRead;
+
+ if (bytesJustRead == sizeof(smplLoopData)) {
+ pWav->smpl.loops[iLoop].cuePointId = drwav__bytes_to_u32(smplLoopData+0);
+ pWav->smpl.loops[iLoop].type = drwav__bytes_to_u32(smplLoopData+4);
+ pWav->smpl.loops[iLoop].start = drwav__bytes_to_u32(smplLoopData+8);
+ pWav->smpl.loops[iLoop].end = drwav__bytes_to_u32(smplLoopData+12);
+ pWav->smpl.loops[iLoop].fraction = drwav__bytes_to_u32(smplLoopData+16);
+ pWav->smpl.loops[iLoop].playCount = drwav__bytes_to_u32(smplLoopData+20);
+ } else {
+ break; /* Break from the smpl loop for loop. */
+ }
+ }
+ }
+ } else {
+ /* Looks like invalid data. Ignore the chunk. */
+ }
+ }
+ } else {
+ if (drwav__guid_equal(header.id.guid, drwavGUID_W64_SMPL)) {
+ /*
+ This path will be hit when a W64 WAV file contains a smpl chunk. I don't have a sample file to test this path, so a contribution
+ is welcome to add support for this.
+ */
+ }
+ }
+
+ /* Make sure we seek past the padding. */
+ chunkSize += header.paddingSize;
+ if (!drwav__seek_forward(pWav->onSeek, chunkSize, pWav->pUserData)) {
+ break;
+ }
+ cursor += chunkSize;
+
+ if (!foundDataChunk) {
+ pWav->dataChunkDataPos = cursor;
+ }
+ }
+
+ /* If we haven't found a data chunk, return an error. */
+ if (!foundDataChunk) {
+ return DRWAV_FALSE;
+ }
+
+ /* We may have moved passed the data chunk. If so we need to move back. If running in sequential mode we can assume we are already sitting on the data chunk. */
+ if (!sequential) {
+ if (!drwav__seek_from_start(pWav->onSeek, pWav->dataChunkDataPos, pWav->pUserData)) {
+ return DRWAV_FALSE;
+ }
+ cursor = pWav->dataChunkDataPos;
+ }
+
+
+ /* At this point we should be sitting on the first byte of the raw audio data. */
+
+ pWav->fmt = fmt;
+ pWav->sampleRate = fmt.sampleRate;
+ pWav->channels = fmt.channels;
+ pWav->bitsPerSample = fmt.bitsPerSample;
+ pWav->bytesRemaining = dataChunkSize;
+ pWav->translatedFormatTag = translatedFormatTag;
+ pWav->dataChunkDataSize = dataChunkSize;
+
+ if (sampleCountFromFactChunk != 0) {
+ pWav->totalPCMFrameCount = sampleCountFromFactChunk;
+ } else {
+ pWav->totalPCMFrameCount = dataChunkSize / drwav_get_bytes_per_pcm_frame(pWav);
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ drwav_uint64 totalBlockHeaderSizeInBytes;
+ drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign;
+
+ /* Make sure any trailing partial block is accounted for. */
+ if ((blockCount * fmt.blockAlign) < dataChunkSize) {
+ blockCount += 1;
+ }
+
+ /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */
+ totalBlockHeaderSizeInBytes = blockCount * (6*fmt.channels);
+ pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels;
+ }
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ drwav_uint64 totalBlockHeaderSizeInBytes;
+ drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign;
+
+ /* Make sure any trailing partial block is accounted for. */
+ if ((blockCount * fmt.blockAlign) < dataChunkSize) {
+ blockCount += 1;
+ }
+
+ /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */
+ totalBlockHeaderSizeInBytes = blockCount * (4*fmt.channels);
+ pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels;
+
+ /* The header includes a decoded sample for each channel which acts as the initial predictor sample. */
+ pWav->totalPCMFrameCount += blockCount;
+ }
+ }
+
+ /* Some formats only support a certain number of channels. */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM || pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ if (pWav->channels > 2) {
+ return DRWAV_FALSE;
+ }
+ }
+
+#ifdef DR_WAV_LIBSNDFILE_COMPAT
+ /*
+ I use libsndfile as a benchmark for testing, however in the version I'm using (from the Windows installer on the libsndfile website),
+ it appears the total sample count libsndfile uses for MS-ADPCM is incorrect. It would seem they are computing the total sample count
+ from the number of blocks, however this results in the inclusion of extra silent samples at the end of the last block. The correct
+ way to know the total sample count is to inspect the "fact" chunk, which should always be present for compressed formats, and should
+ always include the sample count. This little block of code below is only used to emulate the libsndfile logic so I can properly run my
+ correctness tests against libsndfile, and is disabled by default.
+ */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign;
+ pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (6*pWav->channels))) * 2)) / fmt.channels; /* x2 because two samples per byte. */
+ }
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign;
+ pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (4*pWav->channels))) * 2) + (blockCount * pWav->channels)) / fmt.channels;
+ }
+#endif
+
+ return DRWAV_TRUE;
+}
+
+DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (!drwav_preinit(pWav, onRead, onSeek, pReadSeekUserData, pAllocationCallbacks)) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init__internal(pWav, onChunk, pChunkUserData, flags);
+}
+
+
+static drwav_uint32 drwav__riff_chunk_size_riff(drwav_uint64 dataChunkSize)
+{
+ drwav_uint64 chunkSize = 4 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 24 = "fmt " chunk. */
+ if (chunkSize > 0xFFFFFFFFUL) {
+ chunkSize = 0xFFFFFFFFUL;
+ }
+
+ return (drwav_uint32)chunkSize; /* Safe cast due to the clamp above. */
+}
+
+static drwav_uint32 drwav__data_chunk_size_riff(drwav_uint64 dataChunkSize)
+{
+ if (dataChunkSize <= 0xFFFFFFFFUL) {
+ return (drwav_uint32)dataChunkSize;
+ } else {
+ return 0xFFFFFFFFUL;
+ }
+}
+
+static drwav_uint64 drwav__riff_chunk_size_w64(drwav_uint64 dataChunkSize)
+{
+ drwav_uint64 dataSubchunkPaddingSize = drwav__chunk_padding_size_w64(dataChunkSize);
+
+ return 80 + 24 + dataChunkSize + dataSubchunkPaddingSize; /* +24 because W64 includes the size of the GUID and size fields. */
+}
+
+static drwav_uint64 drwav__data_chunk_size_w64(drwav_uint64 dataChunkSize)
+{
+ return 24 + dataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */
+}
+
+static drwav_uint64 drwav__riff_chunk_size_rf64(drwav_uint64 dataChunkSize)
+{
+ drwav_uint64 chunkSize = 4 + 36 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 36 = "ds64" chunk. 24 = "fmt " chunk. */
+ if (chunkSize > 0xFFFFFFFFUL) {
+ chunkSize = 0xFFFFFFFFUL;
+ }
+
+ return chunkSize;
+}
+
+static drwav_uint64 drwav__data_chunk_size_rf64(drwav_uint64 dataChunkSize)
+{
+ return dataChunkSize;
+}
+
+
+static size_t drwav__write(drwav* pWav, const void* pData, size_t dataSize)
+{
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->onWrite != NULL);
+
+ /* Generic write. Assumes no byte reordering required. */
+ return pWav->onWrite(pWav->pUserData, pData, dataSize);
+}
+
+static size_t drwav__write_u16ne_to_le(drwav* pWav, drwav_uint16 value)
+{
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->onWrite != NULL);
+
+ if (!drwav__is_little_endian()) {
+ value = drwav__bswap16(value);
+ }
+
+ return drwav__write(pWav, &value, 2);
+}
+
+static size_t drwav__write_u32ne_to_le(drwav* pWav, drwav_uint32 value)
+{
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->onWrite != NULL);
+
+ if (!drwav__is_little_endian()) {
+ value = drwav__bswap32(value);
+ }
+
+ return drwav__write(pWav, &value, 4);
+}
+
+static size_t drwav__write_u64ne_to_le(drwav* pWav, drwav_uint64 value)
+{
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->onWrite != NULL);
+
+ if (!drwav__is_little_endian()) {
+ value = drwav__bswap64(value);
+ }
+
+ return drwav__write(pWav, &value, 8);
+}
+
+
+static drwav_bool32 drwav_preinit_write(drwav* pWav, const drwav_data_format* pFormat, drwav_bool32 isSequential, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pWav == NULL || onWrite == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ if (!isSequential && onSeek == NULL) {
+ return DRWAV_FALSE; /* <-- onSeek is required when in non-sequential mode. */
+ }
+
+ /* Not currently supporting compressed formats. Will need to add support for the "fact" chunk before we enable this. */
+ if (pFormat->format == DR_WAVE_FORMAT_EXTENSIBLE) {
+ return DRWAV_FALSE;
+ }
+ if (pFormat->format == DR_WAVE_FORMAT_ADPCM || pFormat->format == DR_WAVE_FORMAT_DVI_ADPCM) {
+ return DRWAV_FALSE;
+ }
+
+ DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav));
+ pWav->onWrite = onWrite;
+ pWav->onSeek = onSeek;
+ pWav->pUserData = pUserData;
+ pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks);
+
+ if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) {
+ return DRWAV_FALSE; /* Invalid allocation callbacks. */
+ }
+
+ pWav->fmt.formatTag = (drwav_uint16)pFormat->format;
+ pWav->fmt.channels = (drwav_uint16)pFormat->channels;
+ pWav->fmt.sampleRate = pFormat->sampleRate;
+ pWav->fmt.avgBytesPerSec = (drwav_uint32)((pFormat->bitsPerSample * pFormat->sampleRate * pFormat->channels) / 8);
+ pWav->fmt.blockAlign = (drwav_uint16)((pFormat->channels * pFormat->bitsPerSample) / 8);
+ pWav->fmt.bitsPerSample = (drwav_uint16)pFormat->bitsPerSample;
+ pWav->fmt.extendedSize = 0;
+ pWav->isSequentialWrite = isSequential;
+
+ return DRWAV_TRUE;
+}
+
+static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount)
+{
+ /* The function assumes drwav_preinit_write() was called beforehand. */
+
+ size_t runningPos = 0;
+ drwav_uint64 initialDataChunkSize = 0;
+ drwav_uint64 chunkSizeFMT;
+
+ /*
+ The initial values for the "RIFF" and "data" chunks depends on whether or not we are initializing in sequential mode or not. In
+ sequential mode we set this to its final values straight away since they can be calculated from the total sample count. In non-
+ sequential mode we initialize it all to zero and fill it out in drwav_uninit() using a backwards seek.
+ */
+ if (pWav->isSequentialWrite) {
+ initialDataChunkSize = (totalSampleCount * pWav->fmt.bitsPerSample) / 8;
+
+ /*
+ The RIFF container has a limit on the number of samples. drwav is not allowing this. There's no practical limits for Wave64
+ so for the sake of simplicity I'm not doing any validation for that.
+ */
+ if (pFormat->container == drwav_container_riff) {
+ if (initialDataChunkSize > (0xFFFFFFFFUL - 36)) {
+ return DRWAV_FALSE; /* Not enough room to store every sample. */
+ }
+ }
+ }
+
+ pWav->dataChunkDataSizeTargetWrite = initialDataChunkSize;
+
+
+ /* "RIFF" chunk. */
+ if (pFormat->container == drwav_container_riff) {
+ drwav_uint32 chunkSizeRIFF = 28 + (drwav_uint32)initialDataChunkSize; /* +28 = "WAVE" + [sizeof "fmt " chunk] */
+ runningPos += drwav__write(pWav, "RIFF", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeRIFF);
+ runningPos += drwav__write(pWav, "WAVE", 4);
+ } else if (pFormat->container == drwav_container_w64) {
+ drwav_uint64 chunkSizeRIFF = 80 + 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */
+ runningPos += drwav__write(pWav, drwavGUID_W64_RIFF, 16);
+ runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeRIFF);
+ runningPos += drwav__write(pWav, drwavGUID_W64_WAVE, 16);
+ } else if (pFormat->container == drwav_container_rf64) {
+ runningPos += drwav__write(pWav, "RF64", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always 0xFFFFFFFF for RF64. Set to a proper value in the "ds64" chunk. */
+ runningPos += drwav__write(pWav, "WAVE", 4);
+ }
+
+
+ /* "ds64" chunk (RF64 only). */
+ if (pFormat->container == drwav_container_rf64) {
+ drwav_uint32 initialds64ChunkSize = 28; /* 28 = [Size of RIFF (8 bytes)] + [Size of DATA (8 bytes)] + [Sample Count (8 bytes)] + [Table Length (4 bytes)]. Table length always set to 0. */
+ drwav_uint64 initialRiffChunkSize = 8 + initialds64ChunkSize + initialDataChunkSize; /* +8 for the ds64 header. */
+
+ runningPos += drwav__write(pWav, "ds64", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, initialds64ChunkSize); /* Size of ds64. */
+ runningPos += drwav__write_u64ne_to_le(pWav, initialRiffChunkSize); /* Size of RIFF. Set to true value at the end. */
+ runningPos += drwav__write_u64ne_to_le(pWav, initialDataChunkSize); /* Size of DATA. Set to true value at the end. */
+ runningPos += drwav__write_u64ne_to_le(pWav, totalSampleCount); /* Sample count. */
+ runningPos += drwav__write_u32ne_to_le(pWav, 0); /* Table length. Always set to zero in our case since we're not doing any other chunks than "DATA". */
+ }
+
+
+ /* "fmt " chunk. */
+ if (pFormat->container == drwav_container_riff || pFormat->container == drwav_container_rf64) {
+ chunkSizeFMT = 16;
+ runningPos += drwav__write(pWav, "fmt ", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, (drwav_uint32)chunkSizeFMT);
+ } else if (pFormat->container == drwav_container_w64) {
+ chunkSizeFMT = 40;
+ runningPos += drwav__write(pWav, drwavGUID_W64_FMT, 16);
+ runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeFMT);
+ }
+
+ runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.formatTag);
+ runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.channels);
+ runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.sampleRate);
+ runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.avgBytesPerSec);
+ runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.blockAlign);
+ runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.bitsPerSample);
+
+ pWav->dataChunkDataPos = runningPos;
+
+ /* "data" chunk. */
+ if (pFormat->container == drwav_container_riff) {
+ drwav_uint32 chunkSizeDATA = (drwav_uint32)initialDataChunkSize;
+ runningPos += drwav__write(pWav, "data", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeDATA);
+ } else if (pFormat->container == drwav_container_w64) {
+ drwav_uint64 chunkSizeDATA = 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */
+ runningPos += drwav__write(pWav, drwavGUID_W64_DATA, 16);
+ runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeDATA);
+ } else if (pFormat->container == drwav_container_rf64) {
+ runningPos += drwav__write(pWav, "data", 4);
+ runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always set to 0xFFFFFFFF for RF64. The true size of the data chunk is specified in the ds64 chunk. */
+ }
+
+ /*
+ The runningPos variable is incremented in the section above but is left unused which is causing some static analysis tools to detect it
+ as a dead store. I'm leaving this as-is for safety just in case I want to expand this function later to include other tags and want to
+ keep track of the running position for whatever reason. The line below should silence the static analysis tools.
+ */
+ (void)runningPos;
+
+ /* Set some properties for the client's convenience. */
+ pWav->container = pFormat->container;
+ pWav->channels = (drwav_uint16)pFormat->channels;
+ pWav->sampleRate = pFormat->sampleRate;
+ pWav->bitsPerSample = (drwav_uint16)pFormat->bitsPerSample;
+ pWav->translatedFormatTag = (drwav_uint16)pFormat->format;
+
+ return DRWAV_TRUE;
+}
+
+
+DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (!drwav_preinit_write(pWav, pFormat, DRWAV_FALSE, onWrite, onSeek, pUserData, pAllocationCallbacks)) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_write__internal(pWav, pFormat, 0); /* DRWAV_FALSE = Not Sequential */
+}
+
+DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (!drwav_preinit_write(pWav, pFormat, DRWAV_TRUE, onWrite, NULL, pUserData, pAllocationCallbacks)) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_write__internal(pWav, pFormat, totalSampleCount); /* DRWAV_TRUE = Sequential */
+}
+
+DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pFormat == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_write_sequential(pWav, pFormat, totalPCMFrameCount*pFormat->channels, onWrite, pUserData, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount)
+{
+ /* Casting totalSampleCount to drwav_int64 for VC6 compatibility. No issues in practice because nobody is going to exhaust the whole 63 bits. */
+ drwav_uint64 targetDataSizeBytes = (drwav_uint64)((drwav_int64)totalSampleCount * pFormat->channels * pFormat->bitsPerSample/8.0);
+ drwav_uint64 riffChunkSizeBytes;
+ drwav_uint64 fileSizeBytes = 0;
+
+ if (pFormat->container == drwav_container_riff) {
+ riffChunkSizeBytes = drwav__riff_chunk_size_riff(targetDataSizeBytes);
+ fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */
+ } else if (pFormat->container == drwav_container_w64) {
+ riffChunkSizeBytes = drwav__riff_chunk_size_w64(targetDataSizeBytes);
+ fileSizeBytes = riffChunkSizeBytes;
+ } else if (pFormat->container == drwav_container_rf64) {
+ riffChunkSizeBytes = drwav__riff_chunk_size_rf64(targetDataSizeBytes);
+ fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */
+ }
+
+ return fileSizeBytes;
+}
+
+
+#ifndef DR_WAV_NO_STDIO
+
+/* drwav_result_from_errno() is only used for fopen() and wfopen() so putting it inside DR_WAV_NO_STDIO for now. If something else needs this later we can move it out. */
+#include <errno.h>
+static drwav_result drwav_result_from_errno(int e)
+{
+ switch (e)
+ {
+ case 0: return DRWAV_SUCCESS;
+ #ifdef EPERM
+ case EPERM: return DRWAV_INVALID_OPERATION;
+ #endif
+ #ifdef ENOENT
+ case ENOENT: return DRWAV_DOES_NOT_EXIST;
+ #endif
+ #ifdef ESRCH
+ case ESRCH: return DRWAV_DOES_NOT_EXIST;
+ #endif
+ #ifdef EINTR
+ case EINTR: return DRWAV_INTERRUPT;
+ #endif
+ #ifdef EIO
+ case EIO: return DRWAV_IO_ERROR;
+ #endif
+ #ifdef ENXIO
+ case ENXIO: return DRWAV_DOES_NOT_EXIST;
+ #endif
+ #ifdef E2BIG
+ case E2BIG: return DRWAV_INVALID_ARGS;
+ #endif
+ #ifdef ENOEXEC
+ case ENOEXEC: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef EBADF
+ case EBADF: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef ECHILD
+ case ECHILD: return DRWAV_ERROR;
+ #endif
+ #ifdef EAGAIN
+ case EAGAIN: return DRWAV_UNAVAILABLE;
+ #endif
+ #ifdef ENOMEM
+ case ENOMEM: return DRWAV_OUT_OF_MEMORY;
+ #endif
+ #ifdef EACCES
+ case EACCES: return DRWAV_ACCESS_DENIED;
+ #endif
+ #ifdef EFAULT
+ case EFAULT: return DRWAV_BAD_ADDRESS;
+ #endif
+ #ifdef ENOTBLK
+ case ENOTBLK: return DRWAV_ERROR;
+ #endif
+ #ifdef EBUSY
+ case EBUSY: return DRWAV_BUSY;
+ #endif
+ #ifdef EEXIST
+ case EEXIST: return DRWAV_ALREADY_EXISTS;
+ #endif
+ #ifdef EXDEV
+ case EXDEV: return DRWAV_ERROR;
+ #endif
+ #ifdef ENODEV
+ case ENODEV: return DRWAV_DOES_NOT_EXIST;
+ #endif
+ #ifdef ENOTDIR
+ case ENOTDIR: return DRWAV_NOT_DIRECTORY;
+ #endif
+ #ifdef EISDIR
+ case EISDIR: return DRWAV_IS_DIRECTORY;
+ #endif
+ #ifdef EINVAL
+ case EINVAL: return DRWAV_INVALID_ARGS;
+ #endif
+ #ifdef ENFILE
+ case ENFILE: return DRWAV_TOO_MANY_OPEN_FILES;
+ #endif
+ #ifdef EMFILE
+ case EMFILE: return DRWAV_TOO_MANY_OPEN_FILES;
+ #endif
+ #ifdef ENOTTY
+ case ENOTTY: return DRWAV_INVALID_OPERATION;
+ #endif
+ #ifdef ETXTBSY
+ case ETXTBSY: return DRWAV_BUSY;
+ #endif
+ #ifdef EFBIG
+ case EFBIG: return DRWAV_TOO_BIG;
+ #endif
+ #ifdef ENOSPC
+ case ENOSPC: return DRWAV_NO_SPACE;
+ #endif
+ #ifdef ESPIPE
+ case ESPIPE: return DRWAV_BAD_SEEK;
+ #endif
+ #ifdef EROFS
+ case EROFS: return DRWAV_ACCESS_DENIED;
+ #endif
+ #ifdef EMLINK
+ case EMLINK: return DRWAV_TOO_MANY_LINKS;
+ #endif
+ #ifdef EPIPE
+ case EPIPE: return DRWAV_BAD_PIPE;
+ #endif
+ #ifdef EDOM
+ case EDOM: return DRWAV_OUT_OF_RANGE;
+ #endif
+ #ifdef ERANGE
+ case ERANGE: return DRWAV_OUT_OF_RANGE;
+ #endif
+ #ifdef EDEADLK
+ case EDEADLK: return DRWAV_DEADLOCK;
+ #endif
+ #ifdef ENAMETOOLONG
+ case ENAMETOOLONG: return DRWAV_PATH_TOO_LONG;
+ #endif
+ #ifdef ENOLCK
+ case ENOLCK: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOSYS
+ case ENOSYS: return DRWAV_NOT_IMPLEMENTED;
+ #endif
+ #ifdef ENOTEMPTY
+ case ENOTEMPTY: return DRWAV_DIRECTORY_NOT_EMPTY;
+ #endif
+ #ifdef ELOOP
+ case ELOOP: return DRWAV_TOO_MANY_LINKS;
+ #endif
+ #ifdef ENOMSG
+ case ENOMSG: return DRWAV_NO_MESSAGE;
+ #endif
+ #ifdef EIDRM
+ case EIDRM: return DRWAV_ERROR;
+ #endif
+ #ifdef ECHRNG
+ case ECHRNG: return DRWAV_ERROR;
+ #endif
+ #ifdef EL2NSYNC
+ case EL2NSYNC: return DRWAV_ERROR;
+ #endif
+ #ifdef EL3HLT
+ case EL3HLT: return DRWAV_ERROR;
+ #endif
+ #ifdef EL3RST
+ case EL3RST: return DRWAV_ERROR;
+ #endif
+ #ifdef ELNRNG
+ case ELNRNG: return DRWAV_OUT_OF_RANGE;
+ #endif
+ #ifdef EUNATCH
+ case EUNATCH: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOCSI
+ case ENOCSI: return DRWAV_ERROR;
+ #endif
+ #ifdef EL2HLT
+ case EL2HLT: return DRWAV_ERROR;
+ #endif
+ #ifdef EBADE
+ case EBADE: return DRWAV_ERROR;
+ #endif
+ #ifdef EBADR
+ case EBADR: return DRWAV_ERROR;
+ #endif
+ #ifdef EXFULL
+ case EXFULL: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOANO
+ case ENOANO: return DRWAV_ERROR;
+ #endif
+ #ifdef EBADRQC
+ case EBADRQC: return DRWAV_ERROR;
+ #endif
+ #ifdef EBADSLT
+ case EBADSLT: return DRWAV_ERROR;
+ #endif
+ #ifdef EBFONT
+ case EBFONT: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef ENOSTR
+ case ENOSTR: return DRWAV_ERROR;
+ #endif
+ #ifdef ENODATA
+ case ENODATA: return DRWAV_NO_DATA_AVAILABLE;
+ #endif
+ #ifdef ETIME
+ case ETIME: return DRWAV_TIMEOUT;
+ #endif
+ #ifdef ENOSR
+ case ENOSR: return DRWAV_NO_DATA_AVAILABLE;
+ #endif
+ #ifdef ENONET
+ case ENONET: return DRWAV_NO_NETWORK;
+ #endif
+ #ifdef ENOPKG
+ case ENOPKG: return DRWAV_ERROR;
+ #endif
+ #ifdef EREMOTE
+ case EREMOTE: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOLINK
+ case ENOLINK: return DRWAV_ERROR;
+ #endif
+ #ifdef EADV
+ case EADV: return DRWAV_ERROR;
+ #endif
+ #ifdef ESRMNT
+ case ESRMNT: return DRWAV_ERROR;
+ #endif
+ #ifdef ECOMM
+ case ECOMM: return DRWAV_ERROR;
+ #endif
+ #ifdef EPROTO
+ case EPROTO: return DRWAV_ERROR;
+ #endif
+ #ifdef EMULTIHOP
+ case EMULTIHOP: return DRWAV_ERROR;
+ #endif
+ #ifdef EDOTDOT
+ case EDOTDOT: return DRWAV_ERROR;
+ #endif
+ #ifdef EBADMSG
+ case EBADMSG: return DRWAV_BAD_MESSAGE;
+ #endif
+ #ifdef EOVERFLOW
+ case EOVERFLOW: return DRWAV_TOO_BIG;
+ #endif
+ #ifdef ENOTUNIQ
+ case ENOTUNIQ: return DRWAV_NOT_UNIQUE;
+ #endif
+ #ifdef EBADFD
+ case EBADFD: return DRWAV_ERROR;
+ #endif
+ #ifdef EREMCHG
+ case EREMCHG: return DRWAV_ERROR;
+ #endif
+ #ifdef ELIBACC
+ case ELIBACC: return DRWAV_ACCESS_DENIED;
+ #endif
+ #ifdef ELIBBAD
+ case ELIBBAD: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef ELIBSCN
+ case ELIBSCN: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef ELIBMAX
+ case ELIBMAX: return DRWAV_ERROR;
+ #endif
+ #ifdef ELIBEXEC
+ case ELIBEXEC: return DRWAV_ERROR;
+ #endif
+ #ifdef EILSEQ
+ case EILSEQ: return DRWAV_INVALID_DATA;
+ #endif
+ #ifdef ERESTART
+ case ERESTART: return DRWAV_ERROR;
+ #endif
+ #ifdef ESTRPIPE
+ case ESTRPIPE: return DRWAV_ERROR;
+ #endif
+ #ifdef EUSERS
+ case EUSERS: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOTSOCK
+ case ENOTSOCK: return DRWAV_NOT_SOCKET;
+ #endif
+ #ifdef EDESTADDRREQ
+ case EDESTADDRREQ: return DRWAV_NO_ADDRESS;
+ #endif
+ #ifdef EMSGSIZE
+ case EMSGSIZE: return DRWAV_TOO_BIG;
+ #endif
+ #ifdef EPROTOTYPE
+ case EPROTOTYPE: return DRWAV_BAD_PROTOCOL;
+ #endif
+ #ifdef ENOPROTOOPT
+ case ENOPROTOOPT: return DRWAV_PROTOCOL_UNAVAILABLE;
+ #endif
+ #ifdef EPROTONOSUPPORT
+ case EPROTONOSUPPORT: return DRWAV_PROTOCOL_NOT_SUPPORTED;
+ #endif
+ #ifdef ESOCKTNOSUPPORT
+ case ESOCKTNOSUPPORT: return DRWAV_SOCKET_NOT_SUPPORTED;
+ #endif
+ #ifdef EOPNOTSUPP
+ case EOPNOTSUPP: return DRWAV_INVALID_OPERATION;
+ #endif
+ #ifdef EPFNOSUPPORT
+ case EPFNOSUPPORT: return DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED;
+ #endif
+ #ifdef EAFNOSUPPORT
+ case EAFNOSUPPORT: return DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED;
+ #endif
+ #ifdef EADDRINUSE
+ case EADDRINUSE: return DRWAV_ALREADY_IN_USE;
+ #endif
+ #ifdef EADDRNOTAVAIL
+ case EADDRNOTAVAIL: return DRWAV_ERROR;
+ #endif
+ #ifdef ENETDOWN
+ case ENETDOWN: return DRWAV_NO_NETWORK;
+ #endif
+ #ifdef ENETUNREACH
+ case ENETUNREACH: return DRWAV_NO_NETWORK;
+ #endif
+ #ifdef ENETRESET
+ case ENETRESET: return DRWAV_NO_NETWORK;
+ #endif
+ #ifdef ECONNABORTED
+ case ECONNABORTED: return DRWAV_NO_NETWORK;
+ #endif
+ #ifdef ECONNRESET
+ case ECONNRESET: return DRWAV_CONNECTION_RESET;
+ #endif
+ #ifdef ENOBUFS
+ case ENOBUFS: return DRWAV_NO_SPACE;
+ #endif
+ #ifdef EISCONN
+ case EISCONN: return DRWAV_ALREADY_CONNECTED;
+ #endif
+ #ifdef ENOTCONN
+ case ENOTCONN: return DRWAV_NOT_CONNECTED;
+ #endif
+ #ifdef ESHUTDOWN
+ case ESHUTDOWN: return DRWAV_ERROR;
+ #endif
+ #ifdef ETOOMANYREFS
+ case ETOOMANYREFS: return DRWAV_ERROR;
+ #endif
+ #ifdef ETIMEDOUT
+ case ETIMEDOUT: return DRWAV_TIMEOUT;
+ #endif
+ #ifdef ECONNREFUSED
+ case ECONNREFUSED: return DRWAV_CONNECTION_REFUSED;
+ #endif
+ #ifdef EHOSTDOWN
+ case EHOSTDOWN: return DRWAV_NO_HOST;
+ #endif
+ #ifdef EHOSTUNREACH
+ case EHOSTUNREACH: return DRWAV_NO_HOST;
+ #endif
+ #ifdef EALREADY
+ case EALREADY: return DRWAV_IN_PROGRESS;
+ #endif
+ #ifdef EINPROGRESS
+ case EINPROGRESS: return DRWAV_IN_PROGRESS;
+ #endif
+ #ifdef ESTALE
+ case ESTALE: return DRWAV_INVALID_FILE;
+ #endif
+ #ifdef EUCLEAN
+ case EUCLEAN: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOTNAM
+ case ENOTNAM: return DRWAV_ERROR;
+ #endif
+ #ifdef ENAVAIL
+ case ENAVAIL: return DRWAV_ERROR;
+ #endif
+ #ifdef EISNAM
+ case EISNAM: return DRWAV_ERROR;
+ #endif
+ #ifdef EREMOTEIO
+ case EREMOTEIO: return DRWAV_IO_ERROR;
+ #endif
+ #ifdef EDQUOT
+ case EDQUOT: return DRWAV_NO_SPACE;
+ #endif
+ #ifdef ENOMEDIUM
+ case ENOMEDIUM: return DRWAV_DOES_NOT_EXIST;
+ #endif
+ #ifdef EMEDIUMTYPE
+ case EMEDIUMTYPE: return DRWAV_ERROR;
+ #endif
+ #ifdef ECANCELED
+ case ECANCELED: return DRWAV_CANCELLED;
+ #endif
+ #ifdef ENOKEY
+ case ENOKEY: return DRWAV_ERROR;
+ #endif
+ #ifdef EKEYEXPIRED
+ case EKEYEXPIRED: return DRWAV_ERROR;
+ #endif
+ #ifdef EKEYREVOKED
+ case EKEYREVOKED: return DRWAV_ERROR;
+ #endif
+ #ifdef EKEYREJECTED
+ case EKEYREJECTED: return DRWAV_ERROR;
+ #endif
+ #ifdef EOWNERDEAD
+ case EOWNERDEAD: return DRWAV_ERROR;
+ #endif
+ #ifdef ENOTRECOVERABLE
+ case ENOTRECOVERABLE: return DRWAV_ERROR;
+ #endif
+ #ifdef ERFKILL
+ case ERFKILL: return DRWAV_ERROR;
+ #endif
+ #ifdef EHWPOISON
+ case EHWPOISON: return DRWAV_ERROR;
+ #endif
+ default: return DRWAV_ERROR;
+ }
+}
+
+static drwav_result drwav_fopen(FILE** ppFile, const char* pFilePath, const char* pOpenMode)
+{
+#if _MSC_VER && _MSC_VER >= 1400
+ errno_t err;
+#endif
+
+ if (ppFile != NULL) {
+ *ppFile = NULL; /* Safety. */
+ }
+
+ if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) {
+ return DRWAV_INVALID_ARGS;
+ }
+
+#if _MSC_VER && _MSC_VER >= 1400
+ err = fopen_s(ppFile, pFilePath, pOpenMode);
+ if (err != 0) {
+ return drwav_result_from_errno(err);
+ }
+#else
+#if defined(_WIN32) || defined(__APPLE__)
+ *ppFile = fopen(pFilePath, pOpenMode);
+#else
+ #if defined(_FILE_OFFSET_BITS) && _FILE_OFFSET_BITS == 64 && defined(_LARGEFILE64_SOURCE)
+ *ppFile = fopen64(pFilePath, pOpenMode);
+ #else
+ *ppFile = fopen(pFilePath, pOpenMode);
+ #endif
+#endif
+ if (*ppFile == NULL) {
+ drwav_result result = drwav_result_from_errno(errno);
+ if (result == DRWAV_SUCCESS) {
+ result = DRWAV_ERROR; /* Just a safety check to make sure we never ever return success when pFile == NULL. */
+ }
+
+ return result;
+ }
+#endif
+
+ return DRWAV_SUCCESS;
+}
+
+/*
+_wfopen() isn't always available in all compilation environments.
+
+ * Windows only.
+ * MSVC seems to support it universally as far back as VC6 from what I can tell (haven't checked further back).
+ * MinGW-64 (both 32- and 64-bit) seems to support it.
+ * MinGW wraps it in !defined(__STRICT_ANSI__).
+ * OpenWatcom wraps it in !defined(_NO_EXT_KEYS).
+
+This can be reviewed as compatibility issues arise. The preference is to use _wfopen_s() and _wfopen() as opposed to the wcsrtombs()
+fallback, so if you notice your compiler not detecting this properly I'm happy to look at adding support.
+*/
+#if defined(_WIN32)
+ #if defined(_MSC_VER) || defined(__MINGW64__) || (!defined(__STRICT_ANSI__) && !defined(_NO_EXT_KEYS))
+ #define DRWAV_HAS_WFOPEN
+ #endif
+#endif
+
+static drwav_result drwav_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_t* pOpenMode, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (ppFile != NULL) {
+ *ppFile = NULL; /* Safety. */
+ }
+
+ if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) {
+ return DRWAV_INVALID_ARGS;
+ }
+
+#if defined(DRWAV_HAS_WFOPEN)
+ {
+ /* Use _wfopen() on Windows. */
+ #if defined(_MSC_VER) && _MSC_VER >= 1400
+ errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode);
+ if (err != 0) {
+ return drwav_result_from_errno(err);
+ }
+ #else
+ *ppFile = _wfopen(pFilePath, pOpenMode);
+ if (*ppFile == NULL) {
+ return drwav_result_from_errno(errno);
+ }
+ #endif
+ (void)pAllocationCallbacks;
+ }
+#else
+ /*
+ Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can
+ think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for
+ maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility.
+ */
+ {
+ mbstate_t mbs;
+ size_t lenMB;
+ const wchar_t* pFilePathTemp = pFilePath;
+ char* pFilePathMB = NULL;
+ char pOpenModeMB[32] = {0};
+
+ /* Get the length first. */
+ DRWAV_ZERO_OBJECT(&mbs);
+ lenMB = wcsrtombs(NULL, &pFilePathTemp, 0, &mbs);
+ if (lenMB == (size_t)-1) {
+ return drwav_result_from_errno(errno);
+ }
+
+ pFilePathMB = (char*)drwav__malloc_from_callbacks(lenMB + 1, pAllocationCallbacks);
+ if (pFilePathMB == NULL) {
+ return DRWAV_OUT_OF_MEMORY;
+ }
+
+ pFilePathTemp = pFilePath;
+ DRWAV_ZERO_OBJECT(&mbs);
+ wcsrtombs(pFilePathMB, &pFilePathTemp, lenMB + 1, &mbs);
+
+ /* The open mode should always consist of ASCII characters so we should be able to do a trivial conversion. */
+ {
+ size_t i = 0;
+ for (;;) {
+ if (pOpenMode[i] == 0) {
+ pOpenModeMB[i] = '\0';
+ break;
+ }
+
+ pOpenModeMB[i] = (char)pOpenMode[i];
+ i += 1;
+ }
+ }
+
+ *ppFile = fopen(pFilePathMB, pOpenModeMB);
+
+ drwav__free_from_callbacks(pFilePathMB, pAllocationCallbacks);
+ }
+
+ if (*ppFile == NULL) {
+ return DRWAV_ERROR;
+ }
+#endif
+
+ return DRWAV_SUCCESS;
+}
+
+
+static size_t drwav__on_read_stdio(void* pUserData, void* pBufferOut, size_t bytesToRead)
+{
+ return fread(pBufferOut, 1, bytesToRead, (FILE*)pUserData);
+}
+
+static size_t drwav__on_write_stdio(void* pUserData, const void* pData, size_t bytesToWrite)
+{
+ return fwrite(pData, 1, bytesToWrite, (FILE*)pUserData);
+}
+
+static drwav_bool32 drwav__on_seek_stdio(void* pUserData, int offset, drwav_seek_origin origin)
+{
+ return fseek((FILE*)pUserData, offset, (origin == drwav_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0;
+}
+
+DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_ex(pWav, filename, NULL, NULL, 0, pAllocationCallbacks);
+}
+
+
+static drwav_bool32 drwav_init_file__internal_FILE(drwav* pWav, FILE* pFile, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav_bool32 result;
+
+ result = drwav_preinit(pWav, drwav__on_read_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks);
+ if (result != DRWAV_TRUE) {
+ fclose(pFile);
+ return result;
+ }
+
+ result = drwav_init__internal(pWav, onChunk, pChunkUserData, flags);
+ if (result != DRWAV_TRUE) {
+ fclose(pFile);
+ return result;
+ }
+
+ return DRWAV_TRUE;
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ FILE* pFile;
+ if (drwav_fopen(&pFile, filename, "rb") != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+ /* This takes ownership of the FILE* object. */
+ return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_ex_w(pWav, filename, NULL, NULL, 0, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ FILE* pFile;
+ if (drwav_wfopen(&pFile, filename, L"rb", pAllocationCallbacks) != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+ /* This takes ownership of the FILE* object. */
+ return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks);
+}
+
+
+static drwav_bool32 drwav_init_file_write__internal_FILE(drwav* pWav, FILE* pFile, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav_bool32 result;
+
+ result = drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks);
+ if (result != DRWAV_TRUE) {
+ fclose(pFile);
+ return result;
+ }
+
+ result = drwav_init_write__internal(pWav, pFormat, totalSampleCount);
+ if (result != DRWAV_TRUE) {
+ fclose(pFile);
+ return result;
+ }
+
+ return DRWAV_TRUE;
+}
+
+static drwav_bool32 drwav_init_file_write__internal(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ FILE* pFile;
+ if (drwav_fopen(&pFile, filename, "wb") != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+ /* This takes ownership of the FILE* object. */
+ return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks);
+}
+
+static drwav_bool32 drwav_init_file_write_w__internal(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ FILE* pFile;
+ if (drwav_wfopen(&pFile, filename, L"wb", pAllocationCallbacks) != DRWAV_SUCCESS) {
+ return DRWAV_FALSE;
+ }
+
+ /* This takes ownership of the FILE* object. */
+ return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_write__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_write__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pFormat == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_file_write_sequential(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_write_w__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_file_write_w__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pFormat == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_file_write_sequential_w(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks);
+}
+#endif /* DR_WAV_NO_STDIO */
+
+
+static size_t drwav__on_read_memory(void* pUserData, void* pBufferOut, size_t bytesToRead)
+{
+ drwav* pWav = (drwav*)pUserData;
+ size_t bytesRemaining;
+
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->memoryStream.dataSize >= pWav->memoryStream.currentReadPos);
+
+ bytesRemaining = pWav->memoryStream.dataSize - pWav->memoryStream.currentReadPos;
+ if (bytesToRead > bytesRemaining) {
+ bytesToRead = bytesRemaining;
+ }
+
+ if (bytesToRead > 0) {
+ DRWAV_COPY_MEMORY(pBufferOut, pWav->memoryStream.data + pWav->memoryStream.currentReadPos, bytesToRead);
+ pWav->memoryStream.currentReadPos += bytesToRead;
+ }
+
+ return bytesToRead;
+}
+
+static drwav_bool32 drwav__on_seek_memory(void* pUserData, int offset, drwav_seek_origin origin)
+{
+ drwav* pWav = (drwav*)pUserData;
+ DRWAV_ASSERT(pWav != NULL);
+
+ if (origin == drwav_seek_origin_current) {
+ if (offset > 0) {
+ if (pWav->memoryStream.currentReadPos + offset > pWav->memoryStream.dataSize) {
+ return DRWAV_FALSE; /* Trying to seek too far forward. */
+ }
+ } else {
+ if (pWav->memoryStream.currentReadPos < (size_t)-offset) {
+ return DRWAV_FALSE; /* Trying to seek too far backwards. */
+ }
+ }
+
+ /* This will never underflow thanks to the clamps above. */
+ pWav->memoryStream.currentReadPos += offset;
+ } else {
+ if ((drwav_uint32)offset <= pWav->memoryStream.dataSize) {
+ pWav->memoryStream.currentReadPos = offset;
+ } else {
+ return DRWAV_FALSE; /* Trying to seek too far forward. */
+ }
+ }
+
+ return DRWAV_TRUE;
+}
+
+static size_t drwav__on_write_memory(void* pUserData, const void* pDataIn, size_t bytesToWrite)
+{
+ drwav* pWav = (drwav*)pUserData;
+ size_t bytesRemaining;
+
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(pWav->memoryStreamWrite.dataCapacity >= pWav->memoryStreamWrite.currentWritePos);
+
+ bytesRemaining = pWav->memoryStreamWrite.dataCapacity - pWav->memoryStreamWrite.currentWritePos;
+ if (bytesRemaining < bytesToWrite) {
+ /* Need to reallocate. */
+ void* pNewData;
+ size_t newDataCapacity = (pWav->memoryStreamWrite.dataCapacity == 0) ? 256 : pWav->memoryStreamWrite.dataCapacity * 2;
+
+ /* If doubling wasn't enough, just make it the minimum required size to write the data. */
+ if ((newDataCapacity - pWav->memoryStreamWrite.currentWritePos) < bytesToWrite) {
+ newDataCapacity = pWav->memoryStreamWrite.currentWritePos + bytesToWrite;
+ }
+
+ pNewData = drwav__realloc_from_callbacks(*pWav->memoryStreamWrite.ppData, newDataCapacity, pWav->memoryStreamWrite.dataCapacity, &pWav->allocationCallbacks);
+ if (pNewData == NULL) {
+ return 0;
+ }
+
+ *pWav->memoryStreamWrite.ppData = pNewData;
+ pWav->memoryStreamWrite.dataCapacity = newDataCapacity;
+ }
+
+ DRWAV_COPY_MEMORY(((drwav_uint8*)(*pWav->memoryStreamWrite.ppData)) + pWav->memoryStreamWrite.currentWritePos, pDataIn, bytesToWrite);
+
+ pWav->memoryStreamWrite.currentWritePos += bytesToWrite;
+ if (pWav->memoryStreamWrite.dataSize < pWav->memoryStreamWrite.currentWritePos) {
+ pWav->memoryStreamWrite.dataSize = pWav->memoryStreamWrite.currentWritePos;
+ }
+
+ *pWav->memoryStreamWrite.pDataSize = pWav->memoryStreamWrite.dataSize;
+
+ return bytesToWrite;
+}
+
+static drwav_bool32 drwav__on_seek_memory_write(void* pUserData, int offset, drwav_seek_origin origin)
+{
+ drwav* pWav = (drwav*)pUserData;
+ DRWAV_ASSERT(pWav != NULL);
+
+ if (origin == drwav_seek_origin_current) {
+ if (offset > 0) {
+ if (pWav->memoryStreamWrite.currentWritePos + offset > pWav->memoryStreamWrite.dataSize) {
+ offset = (int)(pWav->memoryStreamWrite.dataSize - pWav->memoryStreamWrite.currentWritePos); /* Trying to seek too far forward. */
+ }
+ } else {
+ if (pWav->memoryStreamWrite.currentWritePos < (size_t)-offset) {
+ offset = -(int)pWav->memoryStreamWrite.currentWritePos; /* Trying to seek too far backwards. */
+ }
+ }
+
+ /* This will never underflow thanks to the clamps above. */
+ pWav->memoryStreamWrite.currentWritePos += offset;
+ } else {
+ if ((drwav_uint32)offset <= pWav->memoryStreamWrite.dataSize) {
+ pWav->memoryStreamWrite.currentWritePos = offset;
+ } else {
+ pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; /* Trying to seek too far forward. */
+ }
+ }
+
+ return DRWAV_TRUE;
+}
+
+DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_memory_ex(pWav, data, dataSize, NULL, NULL, 0, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (data == NULL || dataSize == 0) {
+ return DRWAV_FALSE;
+ }
+
+ if (!drwav_preinit(pWav, drwav__on_read_memory, drwav__on_seek_memory, pWav, pAllocationCallbacks)) {
+ return DRWAV_FALSE;
+ }
+
+ pWav->memoryStream.data = (const drwav_uint8*)data;
+ pWav->memoryStream.dataSize = dataSize;
+ pWav->memoryStream.currentReadPos = 0;
+
+ return drwav_init__internal(pWav, onChunk, pChunkUserData, flags);
+}
+
+
+static drwav_bool32 drwav_init_memory_write__internal(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (ppData == NULL || pDataSize == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ *ppData = NULL; /* Important because we're using realloc()! */
+ *pDataSize = 0;
+
+ if (!drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_memory, drwav__on_seek_memory_write, pWav, pAllocationCallbacks)) {
+ return DRWAV_FALSE;
+ }
+
+ pWav->memoryStreamWrite.ppData = ppData;
+ pWav->memoryStreamWrite.pDataSize = pDataSize;
+ pWav->memoryStreamWrite.dataSize = 0;
+ pWav->memoryStreamWrite.dataCapacity = 0;
+ pWav->memoryStreamWrite.currentWritePos = 0;
+
+ return drwav_init_write__internal(pWav, pFormat, totalSampleCount);
+}
+
+DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks);
+}
+
+DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pFormat == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ return drwav_init_memory_write_sequential(pWav, ppData, pDataSize, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks);
+}
+
+
+
+DRWAV_API drwav_result drwav_uninit(drwav* pWav)
+{
+ drwav_result result = DRWAV_SUCCESS;
+
+ if (pWav == NULL) {
+ return DRWAV_INVALID_ARGS;
+ }
+
+ /*
+ If the drwav object was opened in write mode we'll need to finalize a few things:
+ - Make sure the "data" chunk is aligned to 16-bits for RIFF containers, or 64 bits for W64 containers.
+ - Set the size of the "data" chunk.
+ */
+ if (pWav->onWrite != NULL) {
+ drwav_uint32 paddingSize = 0;
+
+ /* Padding. Do not adjust pWav->dataChunkDataSize - this should not include the padding. */
+ if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) {
+ paddingSize = drwav__chunk_padding_size_riff(pWav->dataChunkDataSize);
+ } else {
+ paddingSize = drwav__chunk_padding_size_w64(pWav->dataChunkDataSize);
+ }
+
+ if (paddingSize > 0) {
+ drwav_uint64 paddingData = 0;
+ drwav__write(pWav, &paddingData, paddingSize); /* Byte order does not matter for this. */
+ }
+
+ /*
+ Chunk sizes. When using sequential mode, these will have been filled in at initialization time. We only need
+ to do this when using non-sequential mode.
+ */
+ if (pWav->onSeek && !pWav->isSequentialWrite) {
+ if (pWav->container == drwav_container_riff) {
+ /* The "RIFF" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, 4, drwav_seek_origin_start)) {
+ drwav_uint32 riffChunkSize = drwav__riff_chunk_size_riff(pWav->dataChunkDataSize);
+ drwav__write_u32ne_to_le(pWav, riffChunkSize);
+ }
+
+ /* the "data" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 4, drwav_seek_origin_start)) {
+ drwav_uint32 dataChunkSize = drwav__data_chunk_size_riff(pWav->dataChunkDataSize);
+ drwav__write_u32ne_to_le(pWav, dataChunkSize);
+ }
+ } else if (pWav->container == drwav_container_w64) {
+ /* The "RIFF" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, 16, drwav_seek_origin_start)) {
+ drwav_uint64 riffChunkSize = drwav__riff_chunk_size_w64(pWav->dataChunkDataSize);
+ drwav__write_u64ne_to_le(pWav, riffChunkSize);
+ }
+
+ /* The "data" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 16, drwav_seek_origin_start)) {
+ drwav_uint64 dataChunkSize = drwav__data_chunk_size_w64(pWav->dataChunkDataSize);
+ drwav__write_u64ne_to_le(pWav, dataChunkSize);
+ }
+ } else if (pWav->container == drwav_container_rf64) {
+ /* We only need to update the ds64 chunk. The "RIFF" and "data" chunks always have their sizes set to 0xFFFFFFFF for RF64. */
+ int ds64BodyPos = 12 + 8;
+
+ /* The "RIFF" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, drwav_seek_origin_start)) {
+ drwav_uint64 riffChunkSize = drwav__riff_chunk_size_rf64(pWav->dataChunkDataSize);
+ drwav__write_u64ne_to_le(pWav, riffChunkSize);
+ }
+
+ /* The "data" chunk size. */
+ if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, drwav_seek_origin_start)) {
+ drwav_uint64 dataChunkSize = drwav__data_chunk_size_rf64(pWav->dataChunkDataSize);
+ drwav__write_u64ne_to_le(pWav, dataChunkSize);
+ }
+ }
+ }
+
+ /* Validation for sequential mode. */
+ if (pWav->isSequentialWrite) {
+ if (pWav->dataChunkDataSize != pWav->dataChunkDataSizeTargetWrite) {
+ result = DRWAV_INVALID_FILE;
+ }
+ }
+ }
+
+#ifndef DR_WAV_NO_STDIO
+ /*
+ If we opened the file with drwav_open_file() we will want to close the file handle. We can know whether or not drwav_open_file()
+ was used by looking at the onRead and onSeek callbacks.
+ */
+ if (pWav->onRead == drwav__on_read_stdio || pWav->onWrite == drwav__on_write_stdio) {
+ fclose((FILE*)pWav->pUserData);
+ }
+#endif
+
+ return result;
+}
+
+
+
+DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut)
+{
+ size_t bytesRead;
+
+ if (pWav == NULL || bytesToRead == 0) {
+ return 0;
+ }
+
+ if (bytesToRead > pWav->bytesRemaining) {
+ bytesToRead = (size_t)pWav->bytesRemaining;
+ }
+
+ if (pBufferOut != NULL) {
+ bytesRead = pWav->onRead(pWav->pUserData, pBufferOut, bytesToRead);
+ } else {
+ /* We need to seek. If we fail, we need to read-and-discard to make sure we get a good byte count. */
+ bytesRead = 0;
+ while (bytesRead < bytesToRead) {
+ size_t bytesToSeek = (bytesToRead - bytesRead);
+ if (bytesToSeek > 0x7FFFFFFF) {
+ bytesToSeek = 0x7FFFFFFF;
+ }
+
+ if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, drwav_seek_origin_current) == DRWAV_FALSE) {
+ break;
+ }
+
+ bytesRead += bytesToSeek;
+ }
+
+ /* When we get here we may need to read-and-discard some data. */
+ while (bytesRead < bytesToRead) {
+ drwav_uint8 buffer[4096];
+ size_t bytesSeeked;
+ size_t bytesToSeek = (bytesToRead - bytesRead);
+ if (bytesToSeek > sizeof(buffer)) {
+ bytesToSeek = sizeof(buffer);
+ }
+
+ bytesSeeked = pWav->onRead(pWav->pUserData, buffer, bytesToSeek);
+ bytesRead += bytesSeeked;
+
+ if (bytesSeeked < bytesToSeek) {
+ break; /* Reached the end. */
+ }
+ }
+ }
+
+ pWav->bytesRemaining -= bytesRead;
+ return bytesRead;
+}
+
+
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut)
+{
+ drwav_uint32 bytesPerFrame;
+ drwav_uint64 bytesToRead; /* Intentionally uint64 instead of size_t so we can do a check that we're not reading too much on 32-bit builds. */
+
+ if (pWav == NULL || framesToRead == 0) {
+ return 0;
+ }
+
+ /* Cannot use this function for compressed formats. */
+ if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) {
+ return 0;
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ /* Don't try to read more samples than can potentially fit in the output buffer. */
+ bytesToRead = framesToRead * bytesPerFrame;
+ if (bytesToRead > DRWAV_SIZE_MAX) {
+ bytesToRead = (DRWAV_SIZE_MAX / bytesPerFrame) * bytesPerFrame; /* Round the number of bytes to read to a clean frame boundary. */
+ }
+
+ /*
+ Doing an explicit check here just to make it clear that we don't want to be attempt to read anything if there's no bytes to read. There
+ *could* be a time where it evaluates to 0 due to overflowing.
+ */
+ if (bytesToRead == 0) {
+ return 0;
+ }
+
+ return drwav_read_raw(pWav, (size_t)bytesToRead, pBufferOut) / bytesPerFrame;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut);
+
+ if (pBufferOut != NULL) {
+ drwav__bswap_samples(pBufferOut, framesRead*pWav->channels, drwav_get_bytes_per_pcm_frame(pWav)/pWav->channels, pWav->translatedFormatTag);
+ }
+
+ return framesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut)
+{
+ if (drwav__is_little_endian()) {
+ return drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut);
+ } else {
+ return drwav_read_pcm_frames_be(pWav, framesToRead, pBufferOut);
+ }
+}
+
+
+
+DRWAV_API drwav_bool32 drwav_seek_to_first_pcm_frame(drwav* pWav)
+{
+ if (pWav->onWrite != NULL) {
+ return DRWAV_FALSE; /* No seeking in write mode. */
+ }
+
+ if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, drwav_seek_origin_start)) {
+ return DRWAV_FALSE;
+ }
+
+ if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) {
+ pWav->compressed.iCurrentPCMFrame = 0;
+
+ /* Cached data needs to be cleared for compressed formats. */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ DRWAV_ZERO_OBJECT(&pWav->msadpcm);
+ } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ DRWAV_ZERO_OBJECT(&pWav->ima);
+ } else {
+ DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */
+ }
+ }
+
+ pWav->bytesRemaining = pWav->dataChunkDataSize;
+ return DRWAV_TRUE;
+}
+
+DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex)
+{
+ /* Seeking should be compatible with wave files > 2GB. */
+
+ if (pWav == NULL || pWav->onSeek == NULL) {
+ return DRWAV_FALSE;
+ }
+
+ /* No seeking in write mode. */
+ if (pWav->onWrite != NULL) {
+ return DRWAV_FALSE;
+ }
+
+ /* If there are no samples, just return DRWAV_TRUE without doing anything. */
+ if (pWav->totalPCMFrameCount == 0) {
+ return DRWAV_TRUE;
+ }
+
+ /* Make sure the sample is clamped. */
+ if (targetFrameIndex >= pWav->totalPCMFrameCount) {
+ targetFrameIndex = pWav->totalPCMFrameCount - 1;
+ }
+
+ /*
+ For compressed formats we just use a slow generic seek. If we are seeking forward we just seek forward. If we are going backwards we need
+ to seek back to the start.
+ */
+ if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) {
+ /* TODO: This can be optimized. */
+
+ /*
+ If we're seeking forward it's simple - just keep reading samples until we hit the sample we're requesting. If we're seeking backwards,
+ we first need to seek back to the start and then just do the same thing as a forward seek.
+ */
+ if (targetFrameIndex < pWav->compressed.iCurrentPCMFrame) {
+ if (!drwav_seek_to_first_pcm_frame(pWav)) {
+ return DRWAV_FALSE;
+ }
+ }
+
+ if (targetFrameIndex > pWav->compressed.iCurrentPCMFrame) {
+ drwav_uint64 offsetInFrames = targetFrameIndex - pWav->compressed.iCurrentPCMFrame;
+
+ drwav_int16 devnull[2048];
+ while (offsetInFrames > 0) {
+ drwav_uint64 framesRead = 0;
+ drwav_uint64 framesToRead = offsetInFrames;
+ if (framesToRead > drwav_countof(devnull)/pWav->channels) {
+ framesToRead = drwav_countof(devnull)/pWav->channels;
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ framesRead = drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, devnull);
+ } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ framesRead = drwav_read_pcm_frames_s16__ima(pWav, framesToRead, devnull);
+ } else {
+ DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */
+ }
+
+ if (framesRead != framesToRead) {
+ return DRWAV_FALSE;
+ }
+
+ offsetInFrames -= framesRead;
+ }
+ }
+ } else {
+ drwav_uint64 totalSizeInBytes;
+ drwav_uint64 currentBytePos;
+ drwav_uint64 targetBytePos;
+ drwav_uint64 offset;
+
+ totalSizeInBytes = pWav->totalPCMFrameCount * drwav_get_bytes_per_pcm_frame(pWav);
+ DRWAV_ASSERT(totalSizeInBytes >= pWav->bytesRemaining);
+
+ currentBytePos = totalSizeInBytes - pWav->bytesRemaining;
+ targetBytePos = targetFrameIndex * drwav_get_bytes_per_pcm_frame(pWav);
+
+ if (currentBytePos < targetBytePos) {
+ /* Offset forwards. */
+ offset = (targetBytePos - currentBytePos);
+ } else {
+ /* Offset backwards. */
+ if (!drwav_seek_to_first_pcm_frame(pWav)) {
+ return DRWAV_FALSE;
+ }
+ offset = targetBytePos;
+ }
+
+ while (offset > 0) {
+ int offset32 = ((offset > INT_MAX) ? INT_MAX : (int)offset);
+ if (!pWav->onSeek(pWav->pUserData, offset32, drwav_seek_origin_current)) {
+ return DRWAV_FALSE;
+ }
+
+ pWav->bytesRemaining -= offset32;
+ offset -= offset32;
+ }
+ }
+
+ return DRWAV_TRUE;
+}
+
+
+DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData)
+{
+ size_t bytesWritten;
+
+ if (pWav == NULL || bytesToWrite == 0 || pData == NULL) {
+ return 0;
+ }
+
+ bytesWritten = pWav->onWrite(pWav->pUserData, pData, bytesToWrite);
+ pWav->dataChunkDataSize += bytesWritten;
+
+ return bytesWritten;
+}
+
+
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData)
+{
+ drwav_uint64 bytesToWrite;
+ drwav_uint64 bytesWritten;
+ const drwav_uint8* pRunningData;
+
+ if (pWav == NULL || framesToWrite == 0 || pData == NULL) {
+ return 0;
+ }
+
+ bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8);
+ if (bytesToWrite > DRWAV_SIZE_MAX) {
+ return 0;
+ }
+
+ bytesWritten = 0;
+ pRunningData = (const drwav_uint8*)pData;
+
+ while (bytesToWrite > 0) {
+ size_t bytesJustWritten;
+ drwav_uint64 bytesToWriteThisIteration;
+
+ bytesToWriteThisIteration = bytesToWrite;
+ DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. */
+
+ bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, pRunningData);
+ if (bytesJustWritten == 0) {
+ break;
+ }
+
+ bytesToWrite -= bytesJustWritten;
+ bytesWritten += bytesJustWritten;
+ pRunningData += bytesJustWritten;
+ }
+
+ return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels;
+}
+
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData)
+{
+ drwav_uint64 bytesToWrite;
+ drwav_uint64 bytesWritten;
+ drwav_uint32 bytesPerSample;
+ const drwav_uint8* pRunningData;
+
+ if (pWav == NULL || framesToWrite == 0 || pData == NULL) {
+ return 0;
+ }
+
+ bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8);
+ if (bytesToWrite > DRWAV_SIZE_MAX) {
+ return 0;
+ }
+
+ bytesWritten = 0;
+ pRunningData = (const drwav_uint8*)pData;
+
+ bytesPerSample = drwav_get_bytes_per_pcm_frame(pWav) / pWav->channels;
+
+ while (bytesToWrite > 0) {
+ drwav_uint8 temp[4096];
+ drwav_uint32 sampleCount;
+ size_t bytesJustWritten;
+ drwav_uint64 bytesToWriteThisIteration;
+
+ bytesToWriteThisIteration = bytesToWrite;
+ DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. */
+
+ /*
+ WAV files are always little-endian. We need to byte swap on big-endian architectures. Since our input buffer is read-only we need
+ to use an intermediary buffer for the conversion.
+ */
+ sampleCount = sizeof(temp)/bytesPerSample;
+
+ if (bytesToWriteThisIteration > ((drwav_uint64)sampleCount)*bytesPerSample) {
+ bytesToWriteThisIteration = ((drwav_uint64)sampleCount)*bytesPerSample;
+ }
+
+ DRWAV_COPY_MEMORY(temp, pRunningData, (size_t)bytesToWriteThisIteration);
+ drwav__bswap_samples(temp, sampleCount, bytesPerSample, pWav->translatedFormatTag);
+
+ bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, temp);
+ if (bytesJustWritten == 0) {
+ break;
+ }
+
+ bytesToWrite -= bytesJustWritten;
+ bytesWritten += bytesJustWritten;
+ pRunningData += bytesJustWritten;
+ }
+
+ return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels;
+}
+
+DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData)
+{
+ if (drwav__is_little_endian()) {
+ return drwav_write_pcm_frames_le(pWav, framesToWrite, pData);
+ } else {
+ return drwav_write_pcm_frames_be(pWav, framesToWrite, pData);
+ }
+}
+
+
+static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 totalFramesRead = 0;
+
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(framesToRead > 0);
+
+ /* TODO: Lots of room for optimization here. */
+
+ while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) {
+ /* If there are no cached frames we need to load a new block. */
+ if (pWav->msadpcm.cachedFrameCount == 0 && pWav->msadpcm.bytesRemainingInBlock == 0) {
+ if (pWav->channels == 1) {
+ /* Mono. */
+ drwav_uint8 header[7];
+ if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) {
+ return totalFramesRead;
+ }
+ pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header);
+
+ pWav->msadpcm.predictor[0] = header[0];
+ pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 1);
+ pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 3);
+ pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 5);
+ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][0];
+ pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[0][1];
+ pWav->msadpcm.cachedFrameCount = 2;
+ } else {
+ /* Stereo. */
+ drwav_uint8 header[14];
+ if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) {
+ return totalFramesRead;
+ }
+ pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header);
+
+ pWav->msadpcm.predictor[0] = header[0];
+ pWav->msadpcm.predictor[1] = header[1];
+ pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 2);
+ pWav->msadpcm.delta[1] = drwav__bytes_to_s16(header + 4);
+ pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 6);
+ pWav->msadpcm.prevFrames[1][1] = (drwav_int32)drwav__bytes_to_s16(header + 8);
+ pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 10);
+ pWav->msadpcm.prevFrames[1][0] = (drwav_int32)drwav__bytes_to_s16(header + 12);
+
+ pWav->msadpcm.cachedFrames[0] = pWav->msadpcm.prevFrames[0][0];
+ pWav->msadpcm.cachedFrames[1] = pWav->msadpcm.prevFrames[1][0];
+ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1];
+ pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1];
+ pWav->msadpcm.cachedFrameCount = 2;
+ }
+ }
+
+ /* Output anything that's cached. */
+ while (framesToRead > 0 && pWav->msadpcm.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) {
+ if (pBufferOut != NULL) {
+ drwav_uint32 iSample = 0;
+ for (iSample = 0; iSample < pWav->channels; iSample += 1) {
+ pBufferOut[iSample] = (drwav_int16)pWav->msadpcm.cachedFrames[(drwav_countof(pWav->msadpcm.cachedFrames) - (pWav->msadpcm.cachedFrameCount*pWav->channels)) + iSample];
+ }
+
+ pBufferOut += pWav->channels;
+ }
+
+ framesToRead -= 1;
+ totalFramesRead += 1;
+ pWav->compressed.iCurrentPCMFrame += 1;
+ pWav->msadpcm.cachedFrameCount -= 1;
+ }
+
+ if (framesToRead == 0) {
+ return totalFramesRead;
+ }
+
+
+ /*
+ If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next
+ loop iteration which will trigger the loading of a new block.
+ */
+ if (pWav->msadpcm.cachedFrameCount == 0) {
+ if (pWav->msadpcm.bytesRemainingInBlock == 0) {
+ continue;
+ } else {
+ static drwav_int32 adaptationTable[] = {
+ 230, 230, 230, 230, 307, 409, 512, 614,
+ 768, 614, 512, 409, 307, 230, 230, 230
+ };
+ static drwav_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 };
+ static drwav_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 };
+
+ drwav_uint8 nibbles;
+ drwav_int32 nibble0;
+ drwav_int32 nibble1;
+
+ if (pWav->onRead(pWav->pUserData, &nibbles, 1) != 1) {
+ return totalFramesRead;
+ }
+ pWav->msadpcm.bytesRemainingInBlock -= 1;
+
+ /* TODO: Optimize away these if statements. */
+ nibble0 = ((nibbles & 0xF0) >> 4); if ((nibbles & 0x80)) { nibble0 |= 0xFFFFFFF0UL; }
+ nibble1 = ((nibbles & 0x0F) >> 0); if ((nibbles & 0x08)) { nibble1 |= 0xFFFFFFF0UL; }
+
+ if (pWav->channels == 1) {
+ /* Mono. */
+ drwav_int32 newSample0;
+ drwav_int32 newSample1;
+
+ newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
+ newSample0 += nibble0 * pWav->msadpcm.delta[0];
+ newSample0 = drwav_clamp(newSample0, -32768, 32767);
+
+ pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8;
+ if (pWav->msadpcm.delta[0] < 16) {
+ pWav->msadpcm.delta[0] = 16;
+ }
+
+ pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1];
+ pWav->msadpcm.prevFrames[0][1] = newSample0;
+
+
+ newSample1 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
+ newSample1 += nibble1 * pWav->msadpcm.delta[0];
+ newSample1 = drwav_clamp(newSample1, -32768, 32767);
+
+ pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[0]) >> 8;
+ if (pWav->msadpcm.delta[0] < 16) {
+ pWav->msadpcm.delta[0] = 16;
+ }
+
+ pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1];
+ pWav->msadpcm.prevFrames[0][1] = newSample1;
+
+
+ pWav->msadpcm.cachedFrames[2] = newSample0;
+ pWav->msadpcm.cachedFrames[3] = newSample1;
+ pWav->msadpcm.cachedFrameCount = 2;
+ } else {
+ /* Stereo. */
+ drwav_int32 newSample0;
+ drwav_int32 newSample1;
+
+ /* Left. */
+ newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
+ newSample0 += nibble0 * pWav->msadpcm.delta[0];
+ newSample0 = drwav_clamp(newSample0, -32768, 32767);
+
+ pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8;
+ if (pWav->msadpcm.delta[0] < 16) {
+ pWav->msadpcm.delta[0] = 16;
+ }
+
+ pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1];
+ pWav->msadpcm.prevFrames[0][1] = newSample0;
+
+
+ /* Right. */
+ newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8;
+ newSample1 += nibble1 * pWav->msadpcm.delta[1];
+ newSample1 = drwav_clamp(newSample1, -32768, 32767);
+
+ pWav->msadpcm.delta[1] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[1]) >> 8;
+ if (pWav->msadpcm.delta[1] < 16) {
+ pWav->msadpcm.delta[1] = 16;
+ }
+
+ pWav->msadpcm.prevFrames[1][0] = pWav->msadpcm.prevFrames[1][1];
+ pWav->msadpcm.prevFrames[1][1] = newSample1;
+
+ pWav->msadpcm.cachedFrames[2] = newSample0;
+ pWav->msadpcm.cachedFrames[3] = newSample1;
+ pWav->msadpcm.cachedFrameCount = 1;
+ }
+ }
+ }
+ }
+
+ return totalFramesRead;
+}
+
+
+static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 totalFramesRead = 0;
+ drwav_uint32 iChannel;
+
+ static drwav_int32 indexTable[16] = {
+ -1, -1, -1, -1, 2, 4, 6, 8,
+ -1, -1, -1, -1, 2, 4, 6, 8
+ };
+
+ static drwav_int32 stepTable[89] = {
+ 7, 8, 9, 10, 11, 12, 13, 14, 16, 17,
+ 19, 21, 23, 25, 28, 31, 34, 37, 41, 45,
+ 50, 55, 60, 66, 73, 80, 88, 97, 107, 118,
+ 130, 143, 157, 173, 190, 209, 230, 253, 279, 307,
+ 337, 371, 408, 449, 494, 544, 598, 658, 724, 796,
+ 876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066,
+ 2272, 2499, 2749, 3024, 3327, 3660, 4026, 4428, 4871, 5358,
+ 5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899,
+ 15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767
+ };
+
+ DRWAV_ASSERT(pWav != NULL);
+ DRWAV_ASSERT(framesToRead > 0);
+
+ /* TODO: Lots of room for optimization here. */
+
+ while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) {
+ /* If there are no cached samples we need to load a new block. */
+ if (pWav->ima.cachedFrameCount == 0 && pWav->ima.bytesRemainingInBlock == 0) {
+ if (pWav->channels == 1) {
+ /* Mono. */
+ drwav_uint8 header[4];
+ if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) {
+ return totalFramesRead;
+ }
+ pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header);
+
+ if (header[2] >= drwav_countof(stepTable)) {
+ pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current);
+ pWav->ima.bytesRemainingInBlock = 0;
+ return totalFramesRead; /* Invalid data. */
+ }
+
+ pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0);
+ pWav->ima.stepIndex[0] = header[2];
+ pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[0];
+ pWav->ima.cachedFrameCount = 1;
+ } else {
+ /* Stereo. */
+ drwav_uint8 header[8];
+ if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) {
+ return totalFramesRead;
+ }
+ pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header);
+
+ if (header[2] >= drwav_countof(stepTable) || header[6] >= drwav_countof(stepTable)) {
+ pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current);
+ pWav->ima.bytesRemainingInBlock = 0;
+ return totalFramesRead; /* Invalid data. */
+ }
+
+ pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0);
+ pWav->ima.stepIndex[0] = header[2];
+ pWav->ima.predictor[1] = drwav__bytes_to_s16(header + 4);
+ pWav->ima.stepIndex[1] = header[6];
+
+ pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 2] = pWav->ima.predictor[0];
+ pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[1];
+ pWav->ima.cachedFrameCount = 1;
+ }
+ }
+
+ /* Output anything that's cached. */
+ while (framesToRead > 0 && pWav->ima.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) {
+ if (pBufferOut != NULL) {
+ drwav_uint32 iSample;
+ for (iSample = 0; iSample < pWav->channels; iSample += 1) {
+ pBufferOut[iSample] = (drwav_int16)pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + iSample];
+ }
+ pBufferOut += pWav->channels;
+ }
+
+ framesToRead -= 1;
+ totalFramesRead += 1;
+ pWav->compressed.iCurrentPCMFrame += 1;
+ pWav->ima.cachedFrameCount -= 1;
+ }
+
+ if (framesToRead == 0) {
+ return totalFramesRead;
+ }
+
+ /*
+ If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next
+ loop iteration which will trigger the loading of a new block.
+ */
+ if (pWav->ima.cachedFrameCount == 0) {
+ if (pWav->ima.bytesRemainingInBlock == 0) {
+ continue;
+ } else {
+ /*
+ From what I can tell with stereo streams, it looks like every 4 bytes (8 samples) is for one channel. So it goes 4 bytes for the
+ left channel, 4 bytes for the right channel.
+ */
+ pWav->ima.cachedFrameCount = 8;
+ for (iChannel = 0; iChannel < pWav->channels; ++iChannel) {
+ drwav_uint32 iByte;
+ drwav_uint8 nibbles[4];
+ if (pWav->onRead(pWav->pUserData, &nibbles, 4) != 4) {
+ pWav->ima.cachedFrameCount = 0;
+ return totalFramesRead;
+ }
+ pWav->ima.bytesRemainingInBlock -= 4;
+
+ for (iByte = 0; iByte < 4; ++iByte) {
+ drwav_uint8 nibble0 = ((nibbles[iByte] & 0x0F) >> 0);
+ drwav_uint8 nibble1 = ((nibbles[iByte] & 0xF0) >> 4);
+
+ drwav_int32 step = stepTable[pWav->ima.stepIndex[iChannel]];
+ drwav_int32 predictor = pWav->ima.predictor[iChannel];
+
+ drwav_int32 diff = step >> 3;
+ if (nibble0 & 1) diff += step >> 2;
+ if (nibble0 & 2) diff += step >> 1;
+ if (nibble0 & 4) diff += step;
+ if (nibble0 & 8) diff = -diff;
+
+ predictor = drwav_clamp(predictor + diff, -32768, 32767);
+ pWav->ima.predictor[iChannel] = predictor;
+ pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble0], 0, (drwav_int32)drwav_countof(stepTable)-1);
+ pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+0)*pWav->channels + iChannel] = predictor;
+
+
+ step = stepTable[pWav->ima.stepIndex[iChannel]];
+ predictor = pWav->ima.predictor[iChannel];
+
+ diff = step >> 3;
+ if (nibble1 & 1) diff += step >> 2;
+ if (nibble1 & 2) diff += step >> 1;
+ if (nibble1 & 4) diff += step;
+ if (nibble1 & 8) diff = -diff;
+
+ predictor = drwav_clamp(predictor + diff, -32768, 32767);
+ pWav->ima.predictor[iChannel] = predictor;
+ pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble1], 0, (drwav_int32)drwav_countof(stepTable)-1);
+ pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+1)*pWav->channels + iChannel] = predictor;
+ }
+ }
+ }
+ }
+ }
+
+ return totalFramesRead;
+}
+
+
+#ifndef DR_WAV_NO_CONVERSION_API
+static unsigned short g_drwavAlawTable[256] = {
+ 0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580,
+ 0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0,
+ 0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600,
+ 0xD500, 0xD700, 0xD100, 0xD300, 0xDD00, 0xDF00, 0xD900, 0xDB00, 0xC500, 0xC700, 0xC100, 0xC300, 0xCD00, 0xCF00, 0xC900, 0xCB00,
+ 0xFEA8, 0xFEB8, 0xFE88, 0xFE98, 0xFEE8, 0xFEF8, 0xFEC8, 0xFED8, 0xFE28, 0xFE38, 0xFE08, 0xFE18, 0xFE68, 0xFE78, 0xFE48, 0xFE58,
+ 0xFFA8, 0xFFB8, 0xFF88, 0xFF98, 0xFFE8, 0xFFF8, 0xFFC8, 0xFFD8, 0xFF28, 0xFF38, 0xFF08, 0xFF18, 0xFF68, 0xFF78, 0xFF48, 0xFF58,
+ 0xFAA0, 0xFAE0, 0xFA20, 0xFA60, 0xFBA0, 0xFBE0, 0xFB20, 0xFB60, 0xF8A0, 0xF8E0, 0xF820, 0xF860, 0xF9A0, 0xF9E0, 0xF920, 0xF960,
+ 0xFD50, 0xFD70, 0xFD10, 0xFD30, 0xFDD0, 0xFDF0, 0xFD90, 0xFDB0, 0xFC50, 0xFC70, 0xFC10, 0xFC30, 0xFCD0, 0xFCF0, 0xFC90, 0xFCB0,
+ 0x1580, 0x1480, 0x1780, 0x1680, 0x1180, 0x1080, 0x1380, 0x1280, 0x1D80, 0x1C80, 0x1F80, 0x1E80, 0x1980, 0x1880, 0x1B80, 0x1A80,
+ 0x0AC0, 0x0A40, 0x0BC0, 0x0B40, 0x08C0, 0x0840, 0x09C0, 0x0940, 0x0EC0, 0x0E40, 0x0FC0, 0x0F40, 0x0CC0, 0x0C40, 0x0DC0, 0x0D40,
+ 0x5600, 0x5200, 0x5E00, 0x5A00, 0x4600, 0x4200, 0x4E00, 0x4A00, 0x7600, 0x7200, 0x7E00, 0x7A00, 0x6600, 0x6200, 0x6E00, 0x6A00,
+ 0x2B00, 0x2900, 0x2F00, 0x2D00, 0x2300, 0x2100, 0x2700, 0x2500, 0x3B00, 0x3900, 0x3F00, 0x3D00, 0x3300, 0x3100, 0x3700, 0x3500,
+ 0x0158, 0x0148, 0x0178, 0x0168, 0x0118, 0x0108, 0x0138, 0x0128, 0x01D8, 0x01C8, 0x01F8, 0x01E8, 0x0198, 0x0188, 0x01B8, 0x01A8,
+ 0x0058, 0x0048, 0x0078, 0x0068, 0x0018, 0x0008, 0x0038, 0x0028, 0x00D8, 0x00C8, 0x00F8, 0x00E8, 0x0098, 0x0088, 0x00B8, 0x00A8,
+ 0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0,
+ 0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350
+};
+
+static unsigned short g_drwavMulawTable[256] = {
+ 0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84,
+ 0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84,
+ 0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004,
+ 0xF0C4, 0xF144, 0xF1C4, 0xF244, 0xF2C4, 0xF344, 0xF3C4, 0xF444, 0xF4C4, 0xF544, 0xF5C4, 0xF644, 0xF6C4, 0xF744, 0xF7C4, 0xF844,
+ 0xF8A4, 0xF8E4, 0xF924, 0xF964, 0xF9A4, 0xF9E4, 0xFA24, 0xFA64, 0xFAA4, 0xFAE4, 0xFB24, 0xFB64, 0xFBA4, 0xFBE4, 0xFC24, 0xFC64,
+ 0xFC94, 0xFCB4, 0xFCD4, 0xFCF4, 0xFD14, 0xFD34, 0xFD54, 0xFD74, 0xFD94, 0xFDB4, 0xFDD4, 0xFDF4, 0xFE14, 0xFE34, 0xFE54, 0xFE74,
+ 0xFE8C, 0xFE9C, 0xFEAC, 0xFEBC, 0xFECC, 0xFEDC, 0xFEEC, 0xFEFC, 0xFF0C, 0xFF1C, 0xFF2C, 0xFF3C, 0xFF4C, 0xFF5C, 0xFF6C, 0xFF7C,
+ 0xFF88, 0xFF90, 0xFF98, 0xFFA0, 0xFFA8, 0xFFB0, 0xFFB8, 0xFFC0, 0xFFC8, 0xFFD0, 0xFFD8, 0xFFE0, 0xFFE8, 0xFFF0, 0xFFF8, 0x0000,
+ 0x7D7C, 0x797C, 0x757C, 0x717C, 0x6D7C, 0x697C, 0x657C, 0x617C, 0x5D7C, 0x597C, 0x557C, 0x517C, 0x4D7C, 0x497C, 0x457C, 0x417C,
+ 0x3E7C, 0x3C7C, 0x3A7C, 0x387C, 0x367C, 0x347C, 0x327C, 0x307C, 0x2E7C, 0x2C7C, 0x2A7C, 0x287C, 0x267C, 0x247C, 0x227C, 0x207C,
+ 0x1EFC, 0x1DFC, 0x1CFC, 0x1BFC, 0x1AFC, 0x19FC, 0x18FC, 0x17FC, 0x16FC, 0x15FC, 0x14FC, 0x13FC, 0x12FC, 0x11FC, 0x10FC, 0x0FFC,
+ 0x0F3C, 0x0EBC, 0x0E3C, 0x0DBC, 0x0D3C, 0x0CBC, 0x0C3C, 0x0BBC, 0x0B3C, 0x0ABC, 0x0A3C, 0x09BC, 0x093C, 0x08BC, 0x083C, 0x07BC,
+ 0x075C, 0x071C, 0x06DC, 0x069C, 0x065C, 0x061C, 0x05DC, 0x059C, 0x055C, 0x051C, 0x04DC, 0x049C, 0x045C, 0x041C, 0x03DC, 0x039C,
+ 0x036C, 0x034C, 0x032C, 0x030C, 0x02EC, 0x02CC, 0x02AC, 0x028C, 0x026C, 0x024C, 0x022C, 0x020C, 0x01EC, 0x01CC, 0x01AC, 0x018C,
+ 0x0174, 0x0164, 0x0154, 0x0144, 0x0134, 0x0124, 0x0114, 0x0104, 0x00F4, 0x00E4, 0x00D4, 0x00C4, 0x00B4, 0x00A4, 0x0094, 0x0084,
+ 0x0078, 0x0070, 0x0068, 0x0060, 0x0058, 0x0050, 0x0048, 0x0040, 0x0038, 0x0030, 0x0028, 0x0020, 0x0018, 0x0010, 0x0008, 0x0000
+};
+
+static DRWAV_INLINE drwav_int16 drwav__alaw_to_s16(drwav_uint8 sampleIn)
+{
+ return (short)g_drwavAlawTable[sampleIn];
+}
+
+static DRWAV_INLINE drwav_int16 drwav__mulaw_to_s16(drwav_uint8 sampleIn)
+{
+ return (short)g_drwavMulawTable[sampleIn];
+}
+
+
+
+static void drwav__pcm_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample)
+{
+ unsigned int i;
+
+ /* Special case for 8-bit sample data because it's treated as unsigned. */
+ if (bytesPerSample == 1) {
+ drwav_u8_to_s16(pOut, pIn, totalSampleCount);
+ return;
+ }
+
+
+ /* Slightly more optimal implementation for common formats. */
+ if (bytesPerSample == 2) {
+ for (i = 0; i < totalSampleCount; ++i) {
+ *pOut++ = ((const drwav_int16*)pIn)[i];
+ }
+ return;
+ }
+ if (bytesPerSample == 3) {
+ drwav_s24_to_s16(pOut, pIn, totalSampleCount);
+ return;
+ }
+ if (bytesPerSample == 4) {
+ drwav_s32_to_s16(pOut, (const drwav_int32*)pIn, totalSampleCount);
+ return;
+ }
+
+
+ /* Anything more than 64 bits per sample is not supported. */
+ if (bytesPerSample > 8) {
+ DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut));
+ return;
+ }
+
+
+ /* Generic, slow converter. */
+ for (i = 0; i < totalSampleCount; ++i) {
+ drwav_uint64 sample = 0;
+ unsigned int shift = (8 - bytesPerSample) * 8;
+
+ unsigned int j;
+ for (j = 0; j < bytesPerSample; j += 1) {
+ DRWAV_ASSERT(j < 8);
+ sample |= (drwav_uint64)(pIn[j]) << shift;
+ shift += 8;
+ }
+
+ pIn += j;
+ *pOut++ = (drwav_int16)((drwav_int64)sample >> 48);
+ }
+}
+
+static void drwav__ieee_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample)
+{
+ if (bytesPerSample == 4) {
+ drwav_f32_to_s16(pOut, (const float*)pIn, totalSampleCount);
+ return;
+ } else if (bytesPerSample == 8) {
+ drwav_f64_to_s16(pOut, (const double*)pIn, totalSampleCount);
+ return;
+ } else {
+ /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */
+ DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut));
+ return;
+ }
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s16__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint32 bytesPerFrame;
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ /* Fast path. */
+ if ((pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 16) || pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__pcm_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s16__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame;
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__ieee_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s16__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame;
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_alaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s16__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame;
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_mulaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ if (pWav == NULL || framesToRead == 0) {
+ return 0;
+ }
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ /* Don't try to read more samples than can potentially fit in the output buffer. */
+ if (framesToRead * pWav->channels * sizeof(drwav_int16) > DRWAV_SIZE_MAX) {
+ framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int16) / pWav->channels;
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) {
+ return drwav_read_pcm_frames_s16__pcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) {
+ return drwav_read_pcm_frames_s16__ieee(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) {
+ return drwav_read_pcm_frames_s16__alaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) {
+ return drwav_read_pcm_frames_s16__mulaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ return drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ return drwav_read_pcm_frames_s16__ima(pWav, framesToRead, pBufferOut);
+ }
+
+ return 0;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) {
+ drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) {
+ drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+
+DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ int r;
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ int x = pIn[i];
+ r = x << 8;
+ r = r - 32768;
+ pOut[i] = (short)r;
+ }
+}
+
+DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ int r;
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ int x = ((int)(((unsigned int)(((const drwav_uint8*)pIn)[i*3+0]) << 8) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+1]) << 16) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+2])) << 24)) >> 8;
+ r = x >> 8;
+ pOut[i] = (short)r;
+ }
+}
+
+DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount)
+{
+ int r;
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ int x = pIn[i];
+ r = x >> 16;
+ pOut[i] = (short)r;
+ }
+}
+
+DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount)
+{
+ int r;
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ float x = pIn[i];
+ float c;
+ c = ((x < -1) ? -1 : ((x > 1) ? 1 : x));
+ c = c + 1;
+ r = (int)(c * 32767.5f);
+ r = r - 32768;
+ pOut[i] = (short)r;
+ }
+}
+
+DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount)
+{
+ int r;
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ double x = pIn[i];
+ double c;
+ c = ((x < -1) ? -1 : ((x > 1) ? 1 : x));
+ c = c + 1;
+ r = (int)(c * 32767.5);
+ r = r - 32768;
+ pOut[i] = (short)r;
+ }
+}
+
+DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ pOut[i] = drwav__alaw_to_s16(pIn[i]);
+ }
+}
+
+DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+ for (i = 0; i < sampleCount; ++i) {
+ pOut[i] = drwav__mulaw_to_s16(pIn[i]);
+ }
+}
+
+
+
+static void drwav__pcm_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample)
+{
+ unsigned int i;
+
+ /* Special case for 8-bit sample data because it's treated as unsigned. */
+ if (bytesPerSample == 1) {
+ drwav_u8_to_f32(pOut, pIn, sampleCount);
+ return;
+ }
+
+ /* Slightly more optimal implementation for common formats. */
+ if (bytesPerSample == 2) {
+ drwav_s16_to_f32(pOut, (const drwav_int16*)pIn, sampleCount);
+ return;
+ }
+ if (bytesPerSample == 3) {
+ drwav_s24_to_f32(pOut, pIn, sampleCount);
+ return;
+ }
+ if (bytesPerSample == 4) {
+ drwav_s32_to_f32(pOut, (const drwav_int32*)pIn, sampleCount);
+ return;
+ }
+
+
+ /* Anything more than 64 bits per sample is not supported. */
+ if (bytesPerSample > 8) {
+ DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut));
+ return;
+ }
+
+
+ /* Generic, slow converter. */
+ for (i = 0; i < sampleCount; ++i) {
+ drwav_uint64 sample = 0;
+ unsigned int shift = (8 - bytesPerSample) * 8;
+
+ unsigned int j;
+ for (j = 0; j < bytesPerSample; j += 1) {
+ DRWAV_ASSERT(j < 8);
+ sample |= (drwav_uint64)(pIn[j]) << shift;
+ shift += 8;
+ }
+
+ pIn += j;
+ *pOut++ = (float)((drwav_int64)sample / 9223372036854775807.0);
+ }
+}
+
+static void drwav__ieee_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample)
+{
+ if (bytesPerSample == 4) {
+ unsigned int i;
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = ((const float*)pIn)[i];
+ }
+ return;
+ } else if (bytesPerSample == 8) {
+ drwav_f64_to_f32(pOut, (const double*)pIn, sampleCount);
+ return;
+ } else {
+ /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */
+ DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut));
+ return;
+ }
+}
+
+
+static drwav_uint64 drwav_read_pcm_frames_f32__pcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__pcm_to_f32(pBufferOut, sampleData, (size_t)framesRead*pWav->channels, bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_f32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ /*
+ We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't
+ want to duplicate that code.
+ */
+ drwav_uint64 totalFramesRead = 0;
+ drwav_int16 samples16[2048];
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_f32__ima(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ /*
+ We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't
+ want to duplicate that code.
+ */
+ drwav_uint64 totalFramesRead = 0;
+ drwav_int16 samples16[2048];
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_f32__ieee(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame;
+
+ /* Fast path. */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT && pWav->bitsPerSample == 32) {
+ return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__ieee_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_f32__alaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_alaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_f32__mulaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_mulaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ if (pWav == NULL || framesToRead == 0) {
+ return 0;
+ }
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ /* Don't try to read more samples than can potentially fit in the output buffer. */
+ if (framesToRead * pWav->channels * sizeof(float) > DRWAV_SIZE_MAX) {
+ framesToRead = DRWAV_SIZE_MAX / sizeof(float) / pWav->channels;
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) {
+ return drwav_read_pcm_frames_f32__pcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ return drwav_read_pcm_frames_f32__msadpcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) {
+ return drwav_read_pcm_frames_f32__ieee(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) {
+ return drwav_read_pcm_frames_f32__alaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) {
+ return drwav_read_pcm_frames_f32__mulaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ return drwav_read_pcm_frames_f32__ima(pWav, framesToRead, pBufferOut);
+ }
+
+ return 0;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) {
+ drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) {
+ drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+
+DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+#ifdef DR_WAV_LIBSNDFILE_COMPAT
+ /*
+ It appears libsndfile uses slightly different logic for the u8 -> f32 conversion to dr_wav, which in my opinion is incorrect. It appears
+ libsndfile performs the conversion something like "f32 = (u8 / 256) * 2 - 1", however I think it should be "f32 = (u8 / 255) * 2 - 1" (note
+ the divisor of 256 vs 255). I use libsndfile as a benchmark for testing, so I'm therefore leaving this block here just for my automated
+ correctness testing. This is disabled by default.
+ */
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = (pIn[i] / 256.0f) * 2 - 1;
+ }
+#else
+ for (i = 0; i < sampleCount; ++i) {
+ float x = pIn[i];
+ x = x * 0.00784313725490196078f; /* 0..255 to 0..2 */
+ x = x - 1; /* 0..2 to -1..1 */
+
+ *pOut++ = x;
+ }
+#endif
+}
+
+DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = pIn[i] * 0.000030517578125f;
+ }
+}
+
+DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ double x;
+ drwav_uint32 a = ((drwav_uint32)(pIn[i*3+0]) << 8);
+ drwav_uint32 b = ((drwav_uint32)(pIn[i*3+1]) << 16);
+ drwav_uint32 c = ((drwav_uint32)(pIn[i*3+2]) << 24);
+
+ x = (double)((drwav_int32)(a | b | c) >> 8);
+ *pOut++ = (float)(x * 0.00000011920928955078125);
+ }
+}
+
+DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount)
+{
+ size_t i;
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = (float)(pIn[i] / 2147483648.0);
+ }
+}
+
+DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = (float)pIn[i];
+ }
+}
+
+DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = drwav__alaw_to_s16(pIn[i]) / 32768.0f;
+ }
+}
+
+DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = drwav__mulaw_to_s16(pIn[i]) / 32768.0f;
+ }
+}
+
+
+
+static void drwav__pcm_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample)
+{
+ unsigned int i;
+
+ /* Special case for 8-bit sample data because it's treated as unsigned. */
+ if (bytesPerSample == 1) {
+ drwav_u8_to_s32(pOut, pIn, totalSampleCount);
+ return;
+ }
+
+ /* Slightly more optimal implementation for common formats. */
+ if (bytesPerSample == 2) {
+ drwav_s16_to_s32(pOut, (const drwav_int16*)pIn, totalSampleCount);
+ return;
+ }
+ if (bytesPerSample == 3) {
+ drwav_s24_to_s32(pOut, pIn, totalSampleCount);
+ return;
+ }
+ if (bytesPerSample == 4) {
+ for (i = 0; i < totalSampleCount; ++i) {
+ *pOut++ = ((const drwav_int32*)pIn)[i];
+ }
+ return;
+ }
+
+
+ /* Anything more than 64 bits per sample is not supported. */
+ if (bytesPerSample > 8) {
+ DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut));
+ return;
+ }
+
+
+ /* Generic, slow converter. */
+ for (i = 0; i < totalSampleCount; ++i) {
+ drwav_uint64 sample = 0;
+ unsigned int shift = (8 - bytesPerSample) * 8;
+
+ unsigned int j;
+ for (j = 0; j < bytesPerSample; j += 1) {
+ DRWAV_ASSERT(j < 8);
+ sample |= (drwav_uint64)(pIn[j]) << shift;
+ shift += 8;
+ }
+
+ pIn += j;
+ *pOut++ = (drwav_int32)((drwav_int64)sample >> 32);
+ }
+}
+
+static void drwav__ieee_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample)
+{
+ if (bytesPerSample == 4) {
+ drwav_f32_to_s32(pOut, (const float*)pIn, totalSampleCount);
+ return;
+ } else if (bytesPerSample == 8) {
+ drwav_f64_to_s32(pOut, (const double*)pIn, totalSampleCount);
+ return;
+ } else {
+ /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */
+ DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut));
+ return;
+ }
+}
+
+
+static drwav_uint64 drwav_read_pcm_frames_s32__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+ drwav_uint32 bytesPerFrame;
+
+ /* Fast path. */
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 32) {
+ return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut);
+ }
+
+ bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__pcm_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ /*
+ We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't
+ want to duplicate that code.
+ */
+ drwav_uint64 totalFramesRead = 0;
+ drwav_int16 samples16[2048];
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s32__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ /*
+ We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't
+ want to duplicate that code.
+ */
+ drwav_uint64 totalFramesRead = 0;
+ drwav_int16 samples16[2048];
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s32__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav__ieee_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels);
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s32__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_alaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+static drwav_uint64 drwav_read_pcm_frames_s32__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 totalFramesRead;
+ drwav_uint8 sampleData[4096];
+
+ drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
+ if (bytesPerFrame == 0) {
+ return 0;
+ }
+
+ totalFramesRead = 0;
+
+ while (framesToRead > 0) {
+ drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
+ if (framesRead == 0) {
+ break;
+ }
+
+ drwav_mulaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels));
+
+ pBufferOut += framesRead*pWav->channels;
+ framesToRead -= framesRead;
+ totalFramesRead += framesRead;
+ }
+
+ return totalFramesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ if (pWav == NULL || framesToRead == 0) {
+ return 0;
+ }
+
+ if (pBufferOut == NULL) {
+ return drwav_read_pcm_frames(pWav, framesToRead, NULL);
+ }
+
+ /* Don't try to read more samples than can potentially fit in the output buffer. */
+ if (framesToRead * pWav->channels * sizeof(drwav_int32) > DRWAV_SIZE_MAX) {
+ framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int32) / pWav->channels;
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) {
+ return drwav_read_pcm_frames_s32__pcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
+ return drwav_read_pcm_frames_s32__msadpcm(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) {
+ return drwav_read_pcm_frames_s32__ieee(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) {
+ return drwav_read_pcm_frames_s32__alaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) {
+ return drwav_read_pcm_frames_s32__mulaw(pWav, framesToRead, pBufferOut);
+ }
+
+ if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
+ return drwav_read_pcm_frames_s32__ima(pWav, framesToRead, pBufferOut);
+ }
+
+ return 0;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) {
+ drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut)
+{
+ drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut);
+ if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) {
+ drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels);
+ }
+
+ return framesRead;
+}
+
+
+DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = ((int)pIn[i] - 128) << 24;
+ }
+}
+
+DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = pIn[i] << 16;
+ }
+}
+
+DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ unsigned int s0 = pIn[i*3 + 0];
+ unsigned int s1 = pIn[i*3 + 1];
+ unsigned int s2 = pIn[i*3 + 2];
+
+ drwav_int32 sample32 = (drwav_int32)((s0 << 8) | (s1 << 16) | (s2 << 24));
+ *pOut++ = sample32;
+ }
+}
+
+DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]);
+ }
+}
+
+DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]);
+ }
+}
+
+DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i = 0; i < sampleCount; ++i) {
+ *pOut++ = ((drwav_int32)drwav__alaw_to_s16(pIn[i])) << 16;
+ }
+}
+
+DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount)
+{
+ size_t i;
+
+ if (pOut == NULL || pIn == NULL) {
+ return;
+ }
+
+ for (i= 0; i < sampleCount; ++i) {
+ *pOut++ = ((drwav_int32)drwav__mulaw_to_s16(pIn[i])) << 16;
+ }
+}
+
+
+
+static drwav_int16* drwav__read_pcm_frames_and_close_s16(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount)
+{
+ drwav_uint64 sampleDataSize;
+ drwav_int16* pSampleData;
+ drwav_uint64 framesRead;
+
+ DRWAV_ASSERT(pWav != NULL);
+
+ sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int16);
+ if (sampleDataSize > DRWAV_SIZE_MAX) {
+ drwav_uninit(pWav);
+ return NULL; /* File's too big. */
+ }
+
+ pSampleData = (drwav_int16*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */
+ if (pSampleData == NULL) {
+ drwav_uninit(pWav);
+ return NULL; /* Failed to allocate memory. */
+ }
+
+ framesRead = drwav_read_pcm_frames_s16(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData);
+ if (framesRead != pWav->totalPCMFrameCount) {
+ drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks);
+ drwav_uninit(pWav);
+ return NULL; /* There was an error reading the samples. */
+ }
+
+ drwav_uninit(pWav);
+
+ if (sampleRate) {
+ *sampleRate = pWav->sampleRate;
+ }
+ if (channels) {
+ *channels = pWav->channels;
+ }
+ if (totalFrameCount) {
+ *totalFrameCount = pWav->totalPCMFrameCount;
+ }
+
+ return pSampleData;
+}
+
+static float* drwav__read_pcm_frames_and_close_f32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount)
+{
+ drwav_uint64 sampleDataSize;
+ float* pSampleData;
+ drwav_uint64 framesRead;
+
+ DRWAV_ASSERT(pWav != NULL);
+
+ sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float);
+ if (sampleDataSize > DRWAV_SIZE_MAX) {
+ drwav_uninit(pWav);
+ return NULL; /* File's too big. */
+ }
+
+ pSampleData = (float*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */
+ if (pSampleData == NULL) {
+ drwav_uninit(pWav);
+ return NULL; /* Failed to allocate memory. */
+ }
+
+ framesRead = drwav_read_pcm_frames_f32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData);
+ if (framesRead != pWav->totalPCMFrameCount) {
+ drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks);
+ drwav_uninit(pWav);
+ return NULL; /* There was an error reading the samples. */
+ }
+
+ drwav_uninit(pWav);
+
+ if (sampleRate) {
+ *sampleRate = pWav->sampleRate;
+ }
+ if (channels) {
+ *channels = pWav->channels;
+ }
+ if (totalFrameCount) {
+ *totalFrameCount = pWav->totalPCMFrameCount;
+ }
+
+ return pSampleData;
+}
+
+static drwav_int32* drwav__read_pcm_frames_and_close_s32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount)
+{
+ drwav_uint64 sampleDataSize;
+ drwav_int32* pSampleData;
+ drwav_uint64 framesRead;
+
+ DRWAV_ASSERT(pWav != NULL);
+
+ sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int32);
+ if (sampleDataSize > DRWAV_SIZE_MAX) {
+ drwav_uninit(pWav);
+ return NULL; /* File's too big. */
+ }
+
+ pSampleData = (drwav_int32*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */
+ if (pSampleData == NULL) {
+ drwav_uninit(pWav);
+ return NULL; /* Failed to allocate memory. */
+ }
+
+ framesRead = drwav_read_pcm_frames_s32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData);
+ if (framesRead != pWav->totalPCMFrameCount) {
+ drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks);
+ drwav_uninit(pWav);
+ return NULL; /* There was an error reading the samples. */
+ }
+
+ drwav_uninit(pWav);
+
+ if (sampleRate) {
+ *sampleRate = pWav->sampleRate;
+ }
+ if (channels) {
+ *channels = pWav->channels;
+ }
+ if (totalFrameCount) {
+ *totalFrameCount = pWav->totalPCMFrameCount;
+ }
+
+ return pSampleData;
+}
+
+
+
+DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+#ifndef DR_WAV_NO_STDIO
+DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+
+DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+#endif
+
+DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+
+DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ drwav wav;
+
+ if (channelsOut) {
+ *channelsOut = 0;
+ }
+ if (sampleRateOut) {
+ *sampleRateOut = 0;
+ }
+ if (totalFrameCountOut) {
+ *totalFrameCountOut = 0;
+ }
+
+ if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) {
+ return NULL;
+ }
+
+ return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut);
+}
+#endif /* DR_WAV_NO_CONVERSION_API */
+
+
+DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks)
+{
+ if (pAllocationCallbacks != NULL) {
+ drwav__free_from_callbacks(p, pAllocationCallbacks);
+ } else {
+ drwav__free_default(p, NULL);
+ }
+}
+
+DRWAV_API drwav_uint16 drwav_bytes_to_u16(const drwav_uint8* data)
+{
+ return drwav__bytes_to_u16(data);
+}
+
+DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data)
+{
+ return drwav__bytes_to_s16(data);
+}
+
+DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data)
+{
+ return drwav__bytes_to_u32(data);
+}
+
+DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data)
+{
+ return drwav__bytes_to_s32(data);
+}
+
+DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data)
+{
+ return drwav__bytes_to_u64(data);
+}
+
+DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data)
+{
+ return drwav__bytes_to_s64(data);
+}
+
+
+DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16])
+{
+ return drwav__guid_equal(a, b);
+}
+
+DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b)
+{
+ return drwav__fourcc_equal(a, b);
+}
+
+#endif /* dr_wav_c */
+#endif /* DR_WAV_IMPLEMENTATION */
+
+/*
+RELEASE NOTES - v0.11.0
+=======================
+Version 0.11.0 has breaking API changes.
+
+Improved Client-Defined Memory Allocation
+-----------------------------------------
+The main change with this release is the addition of a more flexible way of implementing custom memory allocation routines. The
+existing system of DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE are still in place and will be used by default when no custom
+allocation callbacks are specified.
+
+To use the new system, you pass in a pointer to a drwav_allocation_callbacks object to drwav_init() and family, like this:
+
+ void* my_malloc(size_t sz, void* pUserData)
+ {
+ return malloc(sz);
+ }
+ void* my_realloc(void* p, size_t sz, void* pUserData)
+ {
+ return realloc(p, sz);
+ }
+ void my_free(void* p, void* pUserData)
+ {
+ free(p);
+ }
+
+ ...
+
+ drwav_allocation_callbacks allocationCallbacks;
+ allocationCallbacks.pUserData = &myData;
+ allocationCallbacks.onMalloc = my_malloc;
+ allocationCallbacks.onRealloc = my_realloc;
+ allocationCallbacks.onFree = my_free;
+ drwav_init_file(&wav, "my_file.wav", &allocationCallbacks);
+
+The advantage of this new system is that it allows you to specify user data which will be passed in to the allocation routines.
+
+Passing in null for the allocation callbacks object will cause dr_wav to use defaults which is the same as DRWAV_MALLOC,
+DRWAV_REALLOC and DRWAV_FREE and the equivalent of how it worked in previous versions.
+
+Every API that opens a drwav object now takes this extra parameter. These include the following:
+
+ drwav_init()
+ drwav_init_ex()
+ drwav_init_file()
+ drwav_init_file_ex()
+ drwav_init_file_w()
+ drwav_init_file_w_ex()
+ drwav_init_memory()
+ drwav_init_memory_ex()
+ drwav_init_write()
+ drwav_init_write_sequential()
+ drwav_init_write_sequential_pcm_frames()
+ drwav_init_file_write()
+ drwav_init_file_write_sequential()
+ drwav_init_file_write_sequential_pcm_frames()
+ drwav_init_file_write_w()
+ drwav_init_file_write_sequential_w()
+ drwav_init_file_write_sequential_pcm_frames_w()
+ drwav_init_memory_write()
+ drwav_init_memory_write_sequential()
+ drwav_init_memory_write_sequential_pcm_frames()
+ drwav_open_and_read_pcm_frames_s16()
+ drwav_open_and_read_pcm_frames_f32()
+ drwav_open_and_read_pcm_frames_s32()
+ drwav_open_file_and_read_pcm_frames_s16()
+ drwav_open_file_and_read_pcm_frames_f32()
+ drwav_open_file_and_read_pcm_frames_s32()
+ drwav_open_file_and_read_pcm_frames_s16_w()
+ drwav_open_file_and_read_pcm_frames_f32_w()
+ drwav_open_file_and_read_pcm_frames_s32_w()
+ drwav_open_memory_and_read_pcm_frames_s16()
+ drwav_open_memory_and_read_pcm_frames_f32()
+ drwav_open_memory_and_read_pcm_frames_s32()
+
+Endian Improvements
+-------------------
+Previously, the following APIs returned little-endian audio data. These now return native-endian data. This improves compatibility
+on big-endian architectures.
+
+ drwav_read_pcm_frames()
+ drwav_read_pcm_frames_s16()
+ drwav_read_pcm_frames_s32()
+ drwav_read_pcm_frames_f32()
+ drwav_open_and_read_pcm_frames_s16()
+ drwav_open_and_read_pcm_frames_s32()
+ drwav_open_and_read_pcm_frames_f32()
+ drwav_open_file_and_read_pcm_frames_s16()
+ drwav_open_file_and_read_pcm_frames_s32()
+ drwav_open_file_and_read_pcm_frames_f32()
+ drwav_open_file_and_read_pcm_frames_s16_w()
+ drwav_open_file_and_read_pcm_frames_s32_w()
+ drwav_open_file_and_read_pcm_frames_f32_w()
+ drwav_open_memory_and_read_pcm_frames_s16()
+ drwav_open_memory_and_read_pcm_frames_s32()
+ drwav_open_memory_and_read_pcm_frames_f32()
+
+APIs have been added to give you explicit control over whether or not audio data is read or written in big- or little-endian byte
+order:
+
+ drwav_read_pcm_frames_le()
+ drwav_read_pcm_frames_be()
+ drwav_read_pcm_frames_s16le()
+ drwav_read_pcm_frames_s16be()
+ drwav_read_pcm_frames_f32le()
+ drwav_read_pcm_frames_f32be()
+ drwav_read_pcm_frames_s32le()
+ drwav_read_pcm_frames_s32be()
+ drwav_write_pcm_frames_le()
+ drwav_write_pcm_frames_be()
+
+Removed APIs
+------------
+The following APIs were deprecated in version 0.10.0 and have now been removed:
+
+ drwav_open()
+ drwav_open_ex()
+ drwav_open_write()
+ drwav_open_write_sequential()
+ drwav_open_file()
+ drwav_open_file_ex()
+ drwav_open_file_write()
+ drwav_open_file_write_sequential()
+ drwav_open_memory()
+ drwav_open_memory_ex()
+ drwav_open_memory_write()
+ drwav_open_memory_write_sequential()
+ drwav_close()
+
+
+
+RELEASE NOTES - v0.10.0
+=======================
+Version 0.10.0 has breaking API changes. There are no significant bug fixes in this release, so if you are affected you do
+not need to upgrade.
+
+Removed APIs
+------------
+The following APIs were deprecated in version 0.9.0 and have been completely removed in version 0.10.0:
+
+ drwav_read()
+ drwav_read_s16()
+ drwav_read_f32()
+ drwav_read_s32()
+ drwav_seek_to_sample()
+ drwav_write()
+ drwav_open_and_read_s16()
+ drwav_open_and_read_f32()
+ drwav_open_and_read_s32()
+ drwav_open_file_and_read_s16()
+ drwav_open_file_and_read_f32()
+ drwav_open_file_and_read_s32()
+ drwav_open_memory_and_read_s16()
+ drwav_open_memory_and_read_f32()
+ drwav_open_memory_and_read_s32()
+ drwav::totalSampleCount
+
+See release notes for version 0.9.0 at the bottom of this file for replacement APIs.
+
+Deprecated APIs
+---------------
+The following APIs have been deprecated. There is a confusing and completely arbitrary difference between drwav_init*() and
+drwav_open*(), where drwav_init*() initializes a pre-allocated drwav object, whereas drwav_open*() will first allocated a
+drwav object on the heap and then initialize it. drwav_open*() has been deprecated which means you must now use a pre-
+allocated drwav object with drwav_init*(). If you need the previous functionality, you can just do a malloc() followed by
+a called to one of the drwav_init*() APIs.
+
+ drwav_open()
+ drwav_open_ex()
+ drwav_open_write()
+ drwav_open_write_sequential()
+ drwav_open_file()
+ drwav_open_file_ex()
+ drwav_open_file_write()
+ drwav_open_file_write_sequential()
+ drwav_open_memory()
+ drwav_open_memory_ex()
+ drwav_open_memory_write()
+ drwav_open_memory_write_sequential()
+ drwav_close()
+
+These APIs will be removed completely in a future version. The rationale for this change is to remove confusion between the
+two different ways to initialize a drwav object.
+*/
+
+/*
+REVISION HISTORY
+================
+v0.12.16 - 2020-12-02
+ - Fix a bug when trying to read more bytes than can fit in a size_t.
+
+v0.12.15 - 2020-11-21
+ - Fix compilation with OpenWatcom.
+
+v0.12.14 - 2020-11-13
+ - Minor code clean up.
+
+v0.12.13 - 2020-11-01
+ - Improve compiler support for older versions of GCC.
+
+v0.12.12 - 2020-09-28
+ - Add support for RF64.
+ - Fix a bug in writing mode where the size of the RIFF chunk incorrectly includes the header section.
+
+v0.12.11 - 2020-09-08
+ - Fix a compilation error on older compilers.
+
+v0.12.10 - 2020-08-24
+ - Fix a bug when seeking with ADPCM formats.
+
+v0.12.9 - 2020-08-02
+ - Simplify sized types.
+
+v0.12.8 - 2020-07-25
+ - Fix a compilation warning.
+
+v0.12.7 - 2020-07-15
+ - Fix some bugs on big-endian architectures.
+ - Fix an error in s24 to f32 conversion.
+
+v0.12.6 - 2020-06-23
+ - Change drwav_read_*() to allow NULL to be passed in as the output buffer which is equivalent to a forward seek.
+ - Fix a buffer overflow when trying to decode invalid IMA-ADPCM files.
+ - Add include guard for the implementation section.
+
+v0.12.5 - 2020-05-27
+ - Minor documentation fix.
+
+v0.12.4 - 2020-05-16
+ - Replace assert() with DRWAV_ASSERT().
+ - Add compile-time and run-time version querying.
+ - DRWAV_VERSION_MINOR
+ - DRWAV_VERSION_MAJOR
+ - DRWAV_VERSION_REVISION
+ - DRWAV_VERSION_STRING
+ - drwav_version()
+ - drwav_version_string()
+
+v0.12.3 - 2020-04-30
+ - Fix compilation errors with VC6.
+
+v0.12.2 - 2020-04-21
+ - Fix a bug where drwav_init_file() does not close the file handle after attempting to load an erroneous file.
+
+v0.12.1 - 2020-04-13
+ - Fix some pedantic warnings.
+
+v0.12.0 - 2020-04-04
+ - API CHANGE: Add container and format parameters to the chunk callback.
+ - Minor documentation updates.
+
+v0.11.5 - 2020-03-07
+ - Fix compilation error with Visual Studio .NET 2003.
+
+v0.11.4 - 2020-01-29
+ - Fix some static analysis warnings.
+ - Fix a bug when reading f32 samples from an A-law encoded stream.
+
+v0.11.3 - 2020-01-12
+ - Minor changes to some f32 format conversion routines.
+ - Minor bug fix for ADPCM conversion when end of file is reached.
+
+v0.11.2 - 2019-12-02
+ - Fix a possible crash when using custom memory allocators without a custom realloc() implementation.
+ - Fix an integer overflow bug.
+ - Fix a null pointer dereference bug.
+ - Add limits to sample rate, channels and bits per sample to tighten up some validation.
+
+v0.11.1 - 2019-10-07
+ - Internal code clean up.
+
+v0.11.0 - 2019-10-06
+ - API CHANGE: Add support for user defined memory allocation routines. This system allows the program to specify their own memory allocation
+ routines with a user data pointer for client-specific contextual data. This adds an extra parameter to the end of the following APIs:
+ - drwav_init()
+ - drwav_init_ex()
+ - drwav_init_file()
+ - drwav_init_file_ex()
+ - drwav_init_file_w()
+ - drwav_init_file_w_ex()
+ - drwav_init_memory()
+ - drwav_init_memory_ex()
+ - drwav_init_write()
+ - drwav_init_write_sequential()
+ - drwav_init_write_sequential_pcm_frames()
+ - drwav_init_file_write()
+ - drwav_init_file_write_sequential()
+ - drwav_init_file_write_sequential_pcm_frames()
+ - drwav_init_file_write_w()
+ - drwav_init_file_write_sequential_w()
+ - drwav_init_file_write_sequential_pcm_frames_w()
+ - drwav_init_memory_write()
+ - drwav_init_memory_write_sequential()
+ - drwav_init_memory_write_sequential_pcm_frames()
+ - drwav_open_and_read_pcm_frames_s16()
+ - drwav_open_and_read_pcm_frames_f32()
+ - drwav_open_and_read_pcm_frames_s32()
+ - drwav_open_file_and_read_pcm_frames_s16()
+ - drwav_open_file_and_read_pcm_frames_f32()
+ - drwav_open_file_and_read_pcm_frames_s32()
+ - drwav_open_file_and_read_pcm_frames_s16_w()
+ - drwav_open_file_and_read_pcm_frames_f32_w()
+ - drwav_open_file_and_read_pcm_frames_s32_w()
+ - drwav_open_memory_and_read_pcm_frames_s16()
+ - drwav_open_memory_and_read_pcm_frames_f32()
+ - drwav_open_memory_and_read_pcm_frames_s32()
+ Set this extra parameter to NULL to use defaults which is the same as the previous behaviour. Setting this NULL will use
+ DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE.
+ - Add support for reading and writing PCM frames in an explicit endianness. New APIs:
+ - drwav_read_pcm_frames_le()
+ - drwav_read_pcm_frames_be()
+ - drwav_read_pcm_frames_s16le()
+ - drwav_read_pcm_frames_s16be()
+ - drwav_read_pcm_frames_f32le()
+ - drwav_read_pcm_frames_f32be()
+ - drwav_read_pcm_frames_s32le()
+ - drwav_read_pcm_frames_s32be()
+ - drwav_write_pcm_frames_le()
+ - drwav_write_pcm_frames_be()
+ - Remove deprecated APIs.
+ - API CHANGE: The following APIs now return native-endian data. Previously they returned little-endian data.
+ - drwav_read_pcm_frames()
+ - drwav_read_pcm_frames_s16()
+ - drwav_read_pcm_frames_s32()
+ - drwav_read_pcm_frames_f32()
+ - drwav_open_and_read_pcm_frames_s16()
+ - drwav_open_and_read_pcm_frames_s32()
+ - drwav_open_and_read_pcm_frames_f32()
+ - drwav_open_file_and_read_pcm_frames_s16()
+ - drwav_open_file_and_read_pcm_frames_s32()
+ - drwav_open_file_and_read_pcm_frames_f32()
+ - drwav_open_file_and_read_pcm_frames_s16_w()
+ - drwav_open_file_and_read_pcm_frames_s32_w()
+ - drwav_open_file_and_read_pcm_frames_f32_w()
+ - drwav_open_memory_and_read_pcm_frames_s16()
+ - drwav_open_memory_and_read_pcm_frames_s32()
+ - drwav_open_memory_and_read_pcm_frames_f32()
+
+v0.10.1 - 2019-08-31
+ - Correctly handle partial trailing ADPCM blocks.
+
+v0.10.0 - 2019-08-04
+ - Remove deprecated APIs.
+ - Add wchar_t variants for file loading APIs:
+ drwav_init_file_w()
+ drwav_init_file_ex_w()
+ drwav_init_file_write_w()
+ drwav_init_file_write_sequential_w()
+ - Add drwav_target_write_size_bytes() which calculates the total size in bytes of a WAV file given a format and sample count.
+ - Add APIs for specifying the PCM frame count instead of the sample count when opening in sequential write mode:
+ drwav_init_write_sequential_pcm_frames()
+ drwav_init_file_write_sequential_pcm_frames()
+ drwav_init_file_write_sequential_pcm_frames_w()
+ drwav_init_memory_write_sequential_pcm_frames()
+ - Deprecate drwav_open*() and drwav_close():
+ drwav_open()
+ drwav_open_ex()
+ drwav_open_write()
+ drwav_open_write_sequential()
+ drwav_open_file()
+ drwav_open_file_ex()
+ drwav_open_file_write()
+ drwav_open_file_write_sequential()
+ drwav_open_memory()
+ drwav_open_memory_ex()
+ drwav_open_memory_write()
+ drwav_open_memory_write_sequential()
+ drwav_close()
+ - Minor documentation updates.
+
+v0.9.2 - 2019-05-21
+ - Fix warnings.
+
+v0.9.1 - 2019-05-05
+ - Add support for C89.
+ - Change license to choice of public domain or MIT-0.
+
+v0.9.0 - 2018-12-16
+ - API CHANGE: Add new reading APIs for reading by PCM frames instead of samples. Old APIs have been deprecated and
+ will be removed in v0.10.0. Deprecated APIs and their replacements:
+ drwav_read() -> drwav_read_pcm_frames()
+ drwav_read_s16() -> drwav_read_pcm_frames_s16()
+ drwav_read_f32() -> drwav_read_pcm_frames_f32()
+ drwav_read_s32() -> drwav_read_pcm_frames_s32()
+ drwav_seek_to_sample() -> drwav_seek_to_pcm_frame()
+ drwav_write() -> drwav_write_pcm_frames()
+ drwav_open_and_read_s16() -> drwav_open_and_read_pcm_frames_s16()
+ drwav_open_and_read_f32() -> drwav_open_and_read_pcm_frames_f32()
+ drwav_open_and_read_s32() -> drwav_open_and_read_pcm_frames_s32()
+ drwav_open_file_and_read_s16() -> drwav_open_file_and_read_pcm_frames_s16()
+ drwav_open_file_and_read_f32() -> drwav_open_file_and_read_pcm_frames_f32()
+ drwav_open_file_and_read_s32() -> drwav_open_file_and_read_pcm_frames_s32()
+ drwav_open_memory_and_read_s16() -> drwav_open_memory_and_read_pcm_frames_s16()
+ drwav_open_memory_and_read_f32() -> drwav_open_memory_and_read_pcm_frames_f32()
+ drwav_open_memory_and_read_s32() -> drwav_open_memory_and_read_pcm_frames_s32()
+ drwav::totalSampleCount -> drwav::totalPCMFrameCount
+ - API CHANGE: Rename drwav_open_and_read_file_*() to drwav_open_file_and_read_*().
+ - API CHANGE: Rename drwav_open_and_read_memory_*() to drwav_open_memory_and_read_*().
+ - Add built-in support for smpl chunks.
+ - Add support for firing a callback for each chunk in the file at initialization time.
+ - This is enabled through the drwav_init_ex(), etc. family of APIs.
+ - Handle invalid FMT chunks more robustly.
+
+v0.8.5 - 2018-09-11
+ - Const correctness.
+ - Fix a potential stack overflow.
+
+v0.8.4 - 2018-08-07
+ - Improve 64-bit detection.
+
+v0.8.3 - 2018-08-05
+ - Fix C++ build on older versions of GCC.
+
+v0.8.2 - 2018-08-02
+ - Fix some big-endian bugs.
+
+v0.8.1 - 2018-06-29
+ - Add support for sequential writing APIs.
+ - Disable seeking in write mode.
+ - Fix bugs with Wave64.
+ - Fix typos.
+
+v0.8 - 2018-04-27
+ - Bug fix.
+ - Start using major.minor.revision versioning.
+
+v0.7f - 2018-02-05
+ - Restrict ADPCM formats to a maximum of 2 channels.
+
+v0.7e - 2018-02-02
+ - Fix a crash.
+
+v0.7d - 2018-02-01
+ - Fix a crash.
+
+v0.7c - 2018-02-01
+ - Set drwav.bytesPerSample to 0 for all compressed formats.
+ - Fix a crash when reading 16-bit floating point WAV files. In this case dr_wav will output silence for
+ all format conversion reading APIs (*_s16, *_s32, *_f32 APIs).
+ - Fix some divide-by-zero errors.
+
+v0.7b - 2018-01-22
+ - Fix errors with seeking of compressed formats.
+ - Fix compilation error when DR_WAV_NO_CONVERSION_API
+
+v0.7a - 2017-11-17
+ - Fix some GCC warnings.
+
+v0.7 - 2017-11-04
+ - Add writing APIs.
+
+v0.6 - 2017-08-16
+ - API CHANGE: Rename dr_* types to drwav_*.
+ - Add support for custom implementations of malloc(), realloc(), etc.
+ - Add support for Microsoft ADPCM.
+ - Add support for IMA ADPCM (DVI, format code 0x11).
+ - Optimizations to drwav_read_s16().
+ - Bug fixes.
+
+v0.5g - 2017-07-16
+ - Change underlying type for booleans to unsigned.
+
+v0.5f - 2017-04-04
+ - Fix a minor bug with drwav_open_and_read_s16() and family.
+
+v0.5e - 2016-12-29
+ - Added support for reading samples as signed 16-bit integers. Use the _s16() family of APIs for this.
+ - Minor fixes to documentation.
+
+v0.5d - 2016-12-28
+ - Use drwav_int* and drwav_uint* sized types to improve compiler support.
+
+v0.5c - 2016-11-11
+ - Properly handle JUNK chunks that come before the FMT chunk.
+
+v0.5b - 2016-10-23
+ - A minor change to drwav_bool8 and drwav_bool32 types.
+
+v0.5a - 2016-10-11
+ - Fixed a bug with drwav_open_and_read() and family due to incorrect argument ordering.
+ - Improve A-law and mu-law efficiency.
+
+v0.5 - 2016-09-29
+ - API CHANGE. Swap the order of "channels" and "sampleRate" parameters in drwav_open_and_read*(). Rationale for this is to
+ keep it consistent with dr_audio and dr_flac.
+
+v0.4b - 2016-09-18
+ - Fixed a typo in documentation.
+
+v0.4a - 2016-09-18
+ - Fixed a typo.
+ - Change date format to ISO 8601 (YYYY-MM-DD)
+
+v0.4 - 2016-07-13
+ - API CHANGE. Make onSeek consistent with dr_flac.
+ - API CHANGE. Rename drwav_seek() to drwav_seek_to_sample() for clarity and consistency with dr_flac.
+ - Added support for Sony Wave64.
+
+v0.3a - 2016-05-28
+ - API CHANGE. Return drwav_bool32 instead of int in onSeek callback.
+ - Fixed a memory leak.
+
+v0.3 - 2016-05-22
+ - Lots of API changes for consistency.
+
+v0.2a - 2016-05-16
+ - Fixed Linux/GCC build.
+
+v0.2 - 2016-05-11
+ - Added support for reading data as signed 32-bit PCM for consistency with dr_flac.
+
+v0.1a - 2016-05-07
+ - Fixed a bug in drwav_open_file() where the file handle would not be closed if the loader failed to initialize.
+
+v0.1 - 2016-05-04
+ - Initial versioned release.
+*/
+
+/*
+This software is available as a choice of the following licenses. Choose
+whichever you prefer.
+
+===============================================================================
+ALTERNATIVE 1 - Public Domain (www.unlicense.org)
+===============================================================================
+This is free and unencumbered software released into the public domain.
+
+Anyone is free to copy, modify, publish, use, compile, sell, or distribute this
+software, either in source code form or as a compiled binary, for any purpose,
+commercial or non-commercial, and by any means.
+
+In jurisdictions that recognize copyright laws, the author or authors of this
+software dedicate any and all copyright interest in the software to the public
+domain. We make this dedication for the benefit of the public at large and to
+the detriment of our heirs and successors. We intend this dedication to be an
+overt act of relinquishment in perpetuity of all present and future rights to
+this software under copyright law.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+For more information, please refer to <http://unlicense.org/>
+
+===============================================================================
+ALTERNATIVE 2 - MIT No Attribution
+===============================================================================
+Copyright 2020 David Reid
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+*/
diff --git a/Examples/OldMain/main.cpp b/Examples/OldMain/main.cpp
new file mode 100644
index 0000000..6e991b7
--- /dev/null
+++ b/Examples/OldMain/main.cpp
@@ -0,0 +1,684 @@
+#include "whisper.h"
+
+// third-party utilities
+// use your favorite implementations
+#define DR_WAV_IMPLEMENTATION
+#include "dr_wav.h"
+
+#include <cmath>
+#include <fstream>
+#include <cstdio>
+#include <string>
+#include <thread>
+#include <vector>
+
+// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
+// Lowest is red, middle is yellow, highest is green.
+const std::vector<std::string> k_colors = {
+ "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
+ "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
+};
+
+// 500 -> 00:05.000
+// 6000 -> 01:00.000
+std::string to_timestamp(int64_t t, bool comma = false) {
+ int64_t msec = t * 10;
+ int64_t hr = msec / (1000 * 60 * 60);
+ msec = msec - hr * (1000 * 60 * 60);
+ int64_t min = msec / (1000 * 60);
+ msec = msec - min * (1000 * 60);
+ int64_t sec = msec / 1000;
+ msec = msec - sec * 1000;
+
+ char buf[32];
+ snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
+
+ return std::string(buf);
+}
+
+int timestamp_to_sample(int64_t t, int n_samples) {
+ return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
+}
+
+// helper function to replace substrings
+void replace_all(std::string & s, const std::string & search, const std::string & replace) {
+ for (size_t pos = 0; ; pos += replace.length()) {
+ pos = s.find(search, pos);
+ if (pos == std::string::npos) break;
+ s.erase(pos, search.length());
+ s.insert(pos, replace);
+ }
+}
+
+// command-line parameters
+struct whisper_params {
+ int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+ int32_t n_processors = 1;
+ int32_t offset_t_ms = 0;
+ int32_t offset_n = 0;
+ int32_t duration_ms = 0;
+ int32_t max_context = -1;
+ int32_t max_len = 0;
+
+ float word_thold = 0.01f;
+
+ bool speed_up = false;
+ bool translate = false;
+ bool diarize = false;
+ bool output_txt = false;
+ bool output_vtt = false;
+ bool output_srt = false;
+ bool output_wts = false;
+ bool print_special = false;
+ bool print_colors = false;
+ bool print_progress = false;
+ bool no_timestamps = false;
+
+ std::string language = "en";
+ std::string prompt;
+ std::string model = "models/ggml-base.en.bin";
+
+ std::vector<std::string> fname_inp = {};
+};
+
+void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
+
+bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
+ for (int i = 1; i < argc; i++) {
+ std::string arg = argv[i];
+
+ if (arg[0] != '-') {
+ params.fname_inp.push_back(arg);
+ continue;
+ }
+
+ if (arg == "-h" || arg == "--help") {
+ whisper_print_usage(argc, argv, params);
+ exit(0);
+ }
+ else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
+ else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
+ else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
+ else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
+ else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
+ else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
+ else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
+ else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
+ else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
+ else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
+ else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
+ else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
+ else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
+ else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
+ else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
+ else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
+ else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
+ else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
+ else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
+ else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
+ else if ( arg == "--prompt") { params.prompt = argv[++i]; }
+ else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
+ else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
+ else {
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+ whisper_print_usage(argc, argv, params);
+ exit(0);
+ }
+ }
+
+ return true;
+}
+
+void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
+ fprintf(stderr, "\n");
+ fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
+ fprintf(stderr, "\n");
+ fprintf(stderr, "options:\n");
+ fprintf(stderr, " -h, --help [default] show this help message and exit\n");
+ fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
+ fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
+ fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
+ fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
+ fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
+ fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
+ fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
+ fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
+ fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
+ fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
+ fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
+ fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
+ fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
+ fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
+ fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
+ fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
+ fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
+ fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
+ fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
+ fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
+ fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
+ fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
+ fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
+ fprintf(stderr, "\n");
+}
+
+struct whisper_print_user_data {
+ const whisper_params * params;
+
+ const std::vector<std::vector<float>> * pcmf32s;
+};
+
+void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
+ const auto & params = *((whisper_print_user_data *) user_data)->params;
+ const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
+
+ const int n_segments = whisper_full_n_segments(ctx);
+
+ // print the last n_new segments
+ const int s0 = n_segments - n_new;
+ if (s0 == 0) {
+ printf("\n");
+ }
+
+ for (int i = s0; i < n_segments; i++) {
+ if (params.no_timestamps) {
+ if (params.print_colors) {
+ for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+ if (params.print_special == false) {
+ const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+ if (id >= whisper_token_eot(ctx)) {
+ continue;
+ }
+ }
+
+ const char * text = whisper_full_get_token_text(ctx, i, j);
+ const float p = whisper_full_get_token_p (ctx, i, j);
+
+ const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+ printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+ }
+ } else {
+ const char * text = whisper_full_get_segment_text(ctx, i);
+ printf("%s", text);
+ }
+ fflush(stdout);
+ } else {
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+ std::string speaker;
+
+ if (params.diarize && pcmf32s.size() == 2) {
+ const int64_t n_samples = pcmf32s[0].size();
+
+ const int64_t is0 = timestamp_to_sample(t0, n_samples);
+ const int64_t is1 = timestamp_to_sample(t1, n_samples);
+
+ double energy0 = 0.0f;
+ double energy1 = 0.0f;
+
+ for (int64_t j = is0; j < is1; j++) {
+ energy0 += fabs(pcmf32s[0][j]);
+ energy1 += fabs(pcmf32s[1][j]);
+ }
+
+ if (energy0 > 1.1*energy1) {
+ speaker = "(speaker 0)";
+ } else if (energy1 > 1.1*energy0) {
+ speaker = "(speaker 1)";
+ } else {
+ speaker = "(speaker ?)";
+ }
+
+ //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+ }
+
+ if (params.print_colors) {
+ printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+ for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+ if (params.print_special == false) {
+ const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+ if (id >= whisper_token_eot(ctx)) {
+ continue;
+ }
+ }
+
+ const char * text = whisper_full_get_token_text(ctx, i, j);
+ const float p = whisper_full_get_token_p (ctx, i, j);
+
+ const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+ printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
+ }
+ printf("\n");
+ } else {
+ const char * text = whisper_full_get_segment_text(ctx, i);
+
+ printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
+ }
+ }
+ }
+}
+
+bool output_txt(struct whisper_context * ctx, const char * fname) {
+ std::ofstream fout(fname);
+ if (!fout.is_open()) {
+ fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+ return false;
+ }
+
+ fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+ const int n_segments = whisper_full_n_segments(ctx);
+ for (int i = 0; i < n_segments; ++i) {
+ const char * text = whisper_full_get_segment_text(ctx, i);
+ fout << text << "\n";
+ }
+
+ return true;
+}
+
+bool output_vtt(struct whisper_context * ctx, const char * fname) {
+ std::ofstream fout(fname);
+ if (!fout.is_open()) {
+ fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+ return false;
+ }
+
+ fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+ fout << "WEBVTT\n\n";
+
+ const int n_segments = whisper_full_n_segments(ctx);
+ for (int i = 0; i < n_segments; ++i) {
+ const char * text = whisper_full_get_segment_text(ctx, i);
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+ fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
+ fout << text << "\n\n";
+ }
+
+ return true;
+}
+
+bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
+ std::ofstream fout(fname);
+ if (!fout.is_open()) {
+ fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
+ return false;
+ }
+
+ fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+ const int n_segments = whisper_full_n_segments(ctx);
+ for (int i = 0; i < n_segments; ++i) {
+ const char * text = whisper_full_get_segment_text(ctx, i);
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+ fout << i + 1 + params.offset_n << "\n";
+ fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
+ fout << text << "\n\n";
+ }
+
+ return true;
+}
+
+// karaoke video generation
+// outputs a bash script that uses ffmpeg to generate a video with the subtitles
+// TODO: font parameter adjustments
+bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & /*params*/, float t_sec) {
+ std::ofstream fout(fname);
+
+ fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
+
+ // TODO: become parameter
+ static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
+
+ fout << "#!/bin/bash" << "\n";
+ fout << "\n";
+
+ fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
+
+ for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+ const int n = whisper_full_n_tokens(ctx, i);
+
+ std::vector<whisper_token_data> tokens(n);
+ for (int j = 0; j < n; ++j) {
+ tokens[j] = whisper_full_get_token_data(ctx, i, j);
+ }
+
+ if (i > 0) {
+ fout << ",";
+ }
+
+ // background text
+ fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
+
+ bool is_first = true;
+
+ for (int j = 0; j < n; ++j) {
+ const auto & token = tokens[j];
+
+ if (tokens[j].id >= whisper_token_eot(ctx)) {
+ continue;
+ }
+
+ std::string txt_bg;
+ std::string txt_fg; // highlight token
+ std::string txt_ul; // underline
+
+ txt_bg = "> ";
+ txt_fg = "> ";
+ txt_ul = "\\ \\ ";
+
+ {
+ for (int k = 0; k < n; ++k) {
+ const auto & token2 = tokens[k];
+
+ if (tokens[k].id >= whisper_token_eot(ctx)) {
+ continue;
+ }
+
+ const std::string txt = whisper_token_to_str(ctx, token2.id);
+
+ txt_bg += txt;
+
+ if (k == j) {
+ for (int l = 0; l < (int) txt.size(); ++l) {
+ txt_fg += txt[l];
+ txt_ul += "_";
+ }
+ txt_fg += "|";
+ } else {
+ for (int l = 0; l < (int) txt.size(); ++l) {
+ txt_fg += "\\ ";
+ txt_ul += "\\ ";
+ }
+ }
+ }
+
+ ::replace_all(txt_bg, "'", "\u2019");
+ ::replace_all(txt_bg, "\"", "\\\"");
+ ::replace_all(txt_fg, "'", "\u2019");
+ ::replace_all(txt_fg, "\"", "\\\"");
+ }
+
+ if (is_first) {
+ // background text
+ fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'";
+ is_first = false;
+ }
+
+ // foreground text
+ fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
+
+ // underline
+ fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
+ }
+ }
+
+ fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n";
+
+ fout << "\n\n";
+ fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n";
+ fout << "\n";
+ fout << "echo \" ffplay " << fname_inp << ".mp4\"\n";
+ fout << "\n";
+
+ fout.close();
+
+ fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
+
+ return true;
+}
+
+int main(int argc, char ** argv) {
+ whisper_params params;
+
+ if (whisper_params_parse(argc, argv, params) == false) {
+ return 1;
+ }
+
+ if (params.fname_inp.empty()) {
+ fprintf(stderr, "error: no input files specified\n");
+ whisper_print_usage(argc, argv, params);
+ return 2;
+ }
+
+ if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
+ fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
+ whisper_print_usage(argc, argv, params);
+ exit(0);
+ }
+
+ // whisper init
+
+ struct whisper_context * ctx = whisper_init(params.model.c_str());
+
+ if (ctx == nullptr) {
+ fprintf(stderr, "error: failed to initialize whisper context\n");
+ return 3;
+ }
+
+ // initial prompt
+ std::vector<whisper_token> prompt_tokens;
+
+ if (!params.prompt.empty()) {
+ prompt_tokens.resize(1024);
+ prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
+
+ fprintf(stderr, "\n");
+ fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
+ fprintf(stderr, "initial tokens: [ ");
+ for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
+ fprintf(stderr, "%d ", prompt_tokens[i]);
+ }
+ fprintf(stderr, "]\n");
+ }
+
+ for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
+ const auto fname_inp = params.fname_inp[f];
+
+ std::vector<float> pcmf32; // mono-channel F32 PCM
+ std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
+
+ // WAV input
+ {
+ drwav wav;
+ std::vector<uint8_t> wav_data; // used for pipe input from stdin
+
+ if (fname_inp == "-") {
+ {
+ uint8_t buf[1024];
+ while (true)
+ {
+ const size_t n = fread(buf, 1, sizeof(buf), stdin);
+ if (n == 0) {
+ break;
+ }
+ wav_data.insert(wav_data.end(), buf, buf + n);
+ }
+ }
+
+ if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
+ fprintf(stderr, "error: failed to open WAV file from stdin\n");
+ return 4;
+ }
+
+ fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
+ }
+ else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
+ fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
+ return 5;
+ }
+
+ if (wav.channels != 1 && wav.channels != 2) {
+ fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
+ return 6;
+ }
+
+ if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
+ fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str());
+ return 6;
+ }
+
+ if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
+ fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
+ return 8;
+ }
+
+ if (wav.bitsPerSample != 16) {
+ fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
+ return 9;
+ }
+
+ const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
+
+ std::vector<int16_t> pcm16;
+ pcm16.resize(n*wav.channels);
+ drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
+ drwav_uninit(&wav);
+
+ // convert to mono, float
+ pcmf32.resize(n);
+ if (wav.channels == 1) {
+ for (uint64_t i = 0; i < n; i++) {
+ pcmf32[i] = float(pcm16[i])/32768.0f;
+ }
+ } else {
+ for (uint64_t i = 0; i < n; i++) {
+ pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+ }
+ }
+
+ if (params.diarize) {
+ // convert to stereo, float
+ pcmf32s.resize(2);
+
+ pcmf32s[0].resize(n);
+ pcmf32s[1].resize(n);
+ for (uint64_t i = 0; i < n; i++) {
+ pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
+ pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
+ }
+ }
+ }
+
+ // print system information
+ {
+ fprintf(stderr, "\n");
+ fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+ params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
+ }
+
+ // print some info about the processing
+ {
+ fprintf(stderr, "\n");
+ if (!whisper_is_multilingual(ctx)) {
+ if (params.language != "en" || params.translate) {
+ params.language = "en";
+ params.translate = false;
+ fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+ }
+ }
+ fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
+ __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
+ params.n_threads, params.n_processors,
+ params.language.c_str(),
+ params.translate ? "translate" : "transcribe",
+ params.no_timestamps ? 0 : 1);
+
+ fprintf(stderr, "\n");
+ }
+
+ // run the inference
+ {
+ whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+
+ wparams.print_realtime = false;
+ wparams.print_progress = params.print_progress;
+ wparams.print_timestamps = !params.no_timestamps;
+ wparams.print_special = params.print_special;
+ wparams.translate = params.translate;
+ wparams.language = params.language.c_str();
+ wparams.n_threads = params.n_threads;
+ wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
+ wparams.offset_ms = params.offset_t_ms;
+ wparams.duration_ms = params.duration_ms;
+
+ wparams.token_timestamps = params.output_wts || params.max_len > 0;
+ wparams.thold_pt = params.word_thold;
+ wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+
+ wparams.speed_up = params.speed_up;
+
+ wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
+ wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
+
+ whisper_print_user_data user_data = { &params, &pcmf32s };
+
+ // this callback is called on each new segment
+ if (!wparams.print_realtime) {
+ wparams.new_segment_callback = whisper_print_segment_callback;
+ wparams.new_segment_callback_user_data = &user_data;
+ }
+
+ // example for abort mechanism
+ // in this example, we do not abort the processing, but we could if the flag is set to true
+ // the callback is called before every encoder run - if it returns false, the processing is aborted
+ {
+ static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+
+ wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) {
+ bool is_aborted = *(bool*)user_data;
+ return !is_aborted;
+ };
+ wparams.encoder_begin_callback_user_data = &is_aborted;
+ }
+
+ if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
+ fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+ return 10;
+ }
+ }
+
+ // output stuff
+ {
+ printf("\n");
+
+ // output to text file
+ if (params.output_txt) {
+ const auto fname_txt = fname_inp + ".txt";
+ output_txt(ctx, fname_txt.c_str());
+ }
+
+ // output to VTT file
+ if (params.output_vtt) {
+ const auto fname_vtt = fname_inp + ".vtt";
+ output_vtt(ctx, fname_vtt.c_str());
+ }
+
+ // output to SRT file
+ if (params.output_srt) {
+ const auto fname_srt = fname_inp + ".srt";
+ output_srt(ctx, fname_srt.c_str(), params);
+ }
+
+ // output to WTS file
+ if (params.output_wts) {
+ const auto fname_wts = fname_inp + ".wts";
+ output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
+ }
+ }
+ }
+
+ whisper_print_timings(ctx);
+ whisper_free(ctx);
+
+ return 0;
+}
diff --git a/Examples/TranscribeCS/AnsiCodes.cs b/Examples/TranscribeCS/AnsiCodes.cs
new file mode 100644
index 0000000..be04ce3
--- /dev/null
+++ b/Examples/TranscribeCS/AnsiCodes.cs
@@ -0,0 +1,68 @@
+using System.Runtime.InteropServices;
+
+/// <summary>Utility class to setup console coloring with ANSI codes.</summary>
+/// <remarks>The feature requires Windows 10 or newer</remarks>
+static class AnsiCodes
+{
+ const string dll = "kernel32.dll";
+
+ [DllImport( dll, SetLastError = true )]
+ static extern IntPtr GetStdHandle( int nStdHandle );
+
+ const int STD_OUTPUT_HANDLE = -11;
+
+ [Flags]
+ enum ConsoleModes: uint
+ {
+ // Input flags
+ ENABLE_PROCESSED_INPUT = 0x0001,
+ ENABLE_LINE_INPUT = 0x0002,
+ ENABLE_ECHO_INPUT = 0x0004,
+ ENABLE_WINDOW_INPUT = 0x0008,
+ ENABLE_MOUSE_INPUT = 0x0010,
+ ENABLE_INSERT_MODE = 0x0020,
+ ENABLE_QUICK_EDIT_MODE = 0x0040,
+ ENABLE_EXTENDED_FLAGS = 0x0080,
+ ENABLE_AUTO_POSITION = 0x0100,
+ ENABLE_VIRTUAL_TERMINAL_INPUT = 0x0200,
+
+ // Output flags
+ ENABLE_PROCESSED_OUTPUT = 0x0001,
+ ENABLE_WRAP_AT_EOL_OUTPUT = 0x0002,
+ ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004,
+ DISABLE_NEWLINE_AUTO_RETURN = 0x0008,
+ ENABLE_LVB_GRID_WORLDWIDE = 0x0010
+ }
+
+ [DllImport( dll, SetLastError = true )]
+ static extern bool GetConsoleMode( IntPtr hConsoleHandle, out ConsoleModes mode );
+
+ [DllImport( dll, SetLastError = true )]
+ static extern bool SetConsoleMode( IntPtr hConsoleHandle, ConsoleModes mode );
+
+ static AnsiCodes()
+ {
+ IntPtr h = GetStdHandle( STD_OUTPUT_HANDLE );
+ IntPtr INVALID_HANDLE_VALUE = (IntPtr)( -1 );
+ if( h == INVALID_HANDLE_VALUE )
+ return;
+
+ if( !GetConsoleMode( h, out ConsoleModes mode ) )
+ return;
+
+ if( mode.HasFlag( ConsoleModes.ENABLE_VIRTUAL_TERMINAL_PROCESSING ) )
+ {
+ enabled = true;
+ return;
+ }
+
+ mode |= ConsoleModes.ENABLE_VIRTUAL_TERMINAL_PROCESSING;
+ if( SetConsoleMode( h, mode ) )
+ {
+ enabled = true;
+ return;
+ }
+ }
+
+ public static readonly bool enabled = false;
+} \ No newline at end of file
diff --git a/Examples/TranscribeCS/CommandLineArgs.cs b/Examples/TranscribeCS/CommandLineArgs.cs
new file mode 100644
index 0000000..4f9fb74
--- /dev/null
+++ b/Examples/TranscribeCS/CommandLineArgs.cs
@@ -0,0 +1,155 @@
+using System.Globalization;
+using System.Reflection;
+using Whisper;
+
+namespace TranscribeCS
+{
+ sealed record class CommandLineArgs
+ {
+ public int n_threads = Environment.ProcessorCount;
+ public int offset_t_ms = 0;
+ public int offset_n = 0;
+ public int duration_ms = 0;
+ public int max_context = -1;
+ public int max_len = 0;
+
+ public float word_thold = 0.01f;
+
+ public bool speed_up = false;
+ public bool translate = false;
+ public bool diarize = false;
+ public bool output_txt = false;
+ public bool output_vtt = false;
+ public bool output_srt = false;
+ public bool print_special = false;
+ public bool print_progress = false;
+ public bool print_colors = true;
+ public bool no_timestamps = false;
+ public int[]? prompt = null;
+
+ public eLanguage language = eLanguage.English;
+ public string model = string.Empty;
+ public readonly List<string> fileNames = new List<string>();
+
+ const bool output_wts = false;
+ public void apply( ref Parameters p )
+ {
+ p.setFlag( eFullParamsFlags.PrintRealtime, false );
+ p.setFlag( eFullParamsFlags.PrintProgress, print_progress );
+ p.setFlag( eFullParamsFlags.PrintTimestamps, !no_timestamps );
+ p.setFlag( eFullParamsFlags.PrintSpecial, print_special );
+ p.setFlag( eFullParamsFlags.Translate, translate );
+ p.language = language;
+ p.cpuThreads = n_threads;
+ if( max_context >= 0 )
+ p.n_max_text_ctx = max_context;
+ p.offset_ms = offset_t_ms;
+ p.duration_ms = duration_ms;
+ p.setFlag( eFullParamsFlags.TokenTimestamps, output_wts || max_len > 0 );
+ p.thold_pt = word_thold;
+ p.max_len = output_wts && max_len == 0 ? 60 : max_len;
+ p.setFlag( eFullParamsFlags.SpeedupAudio, speed_up );
+ }
+
+ public eResultFlags resultFlags()
+ {
+ eResultFlags flags = eResultFlags.None;
+ bool wts = output_wts || max_len > 0;
+ if( !no_timestamps || wts )
+ flags |= eResultFlags.Timestamps;
+ if( wts || print_colors )
+ flags |= eResultFlags.Tokens;
+ return flags;
+ }
+
+ static eLanguage parseLanguage( string lang ) =>
+ Library.languageFromCode( lang ) ?? throw new ArgumentException( $"Unknown language code \"{lang}\"" );
+
+ public CommandLineArgs( string[] argv )
+ {
+ for( int i = 0; i < argv.Length; i++ )
+ {
+ string arg = argv[ i ];
+ if( arg[ 0 ] != '-' )
+ {
+ fileNames.Add( arg );
+ continue;
+ }
+ if( arg == "-h" || arg == "--help" )
+ {
+ printUsage();
+ throw new OperationCanceledException();
+ }
+ else if( arg == "-t" || arg == "--threads" ) n_threads = int.Parse( argv[ ++i ] );
+ else if( arg == "-ot" || arg == "--offset-t" ) offset_t_ms = int.Parse( argv[ ++i ] );
+ else if( arg == "-on" || arg == "--offset-n" ) offset_n = int.Parse( argv[ ++i ] );
+ else if( arg == "-d" || arg == "--duration" ) duration_ms = int.Parse( argv[ ++i ] );
+ else if( arg == "-mc" || arg == "--max-context" ) max_context = int.Parse( argv[ ++i ] );
+ else if( arg == "-ml" || arg == "--max-len" ) max_len = int.Parse( argv[ ++i ] );
+ else if( arg == "-wt" || arg == "--word-thold" ) word_thold = float.Parse( argv[ ++i ], CultureInfo.InvariantCulture );
+ else if( arg == "-su" || arg == "--speed-up" ) speed_up = true;
+ else if( arg == "-tr" || arg == "--translate" ) translate = true;
+ else if( arg == "-di" || arg == "--diarize" ) diarize = true;
+ else if( arg == "-otxt" || arg == "--output-txt" ) output_txt = true;
+ else if( arg == "-ovtt" || arg == "--output-vtt" ) output_vtt = true;
+ else if( arg == "-osrt" || arg == "--output-srt" ) output_srt = true;
+ else if( arg == "-ps" || arg == "--print-special" ) print_special = true;
+ else if( arg == "-nc" || arg == "--no-colors" ) print_colors = false;
+ else if( arg == "-pp" || arg == "--print-progress" ) print_progress = true;
+ else if( arg == "-nt" || arg == "--no-timestamps" ) no_timestamps = true;
+ else if( arg == "-l" || arg == "--language" ) language = parseLanguage( argv[ ++i ] );
+ else if( arg == "--prompt" ) prompt = parsePrompt( argv[ ++i ] );
+ else if( arg == "-m" || arg == "--model" ) model = argv[ ++i ];
+ else if( arg == "-f" || arg == "--file" ) fileNames.Add( argv[ ++i ] );
+ else
+ throw new ArgumentException( $"Unknown argument: \"{arg}\"" );
+ }
+ if( string.IsNullOrWhiteSpace( model ) )
+ throw new ArgumentException( "The model file is not provided in the arguments" );
+ if( !File.Exists( model ) )
+ throw new FileNotFoundException( "Model not found", model );
+ if( fileNames.Count <= 0 )
+ throw new ArgumentException( "Please supply at least 1 input audio file to process" );
+ }
+
+ static string cstr( bool b ) => b.ToString();
+
+ static int[]? parsePrompt( string str )
+ {
+ if( string.IsNullOrWhiteSpace( str ) )
+ return null;
+ // TODO: expose whisper_tokenize function, as a method of iModel COM interface
+ throw new NotImplementedException();
+ }
+
+ void printUsage()
+ {
+ Console.WriteLine();
+
+ Console.WriteLine( "usage: {0} [options] file0.mp3 file1.wma ...", Path.GetFileName( Assembly.GetExecutingAssembly().Location ) );
+ Console.WriteLine();
+ Console.WriteLine( "options:" );
+ Console.WriteLine( " -h, --help [default] show this help message and exit" );
+ Console.WriteLine( " -t N, --threads N [{0,-7:D}] number of threads to use during computation", n_threads );
+ Console.WriteLine( " -ot N, --offset-t N [{0,-7:D}] time offset in milliseconds", offset_t_ms );
+ Console.WriteLine( " -on N, --offset-n N [{0,-7:D}] segment index offset", offset_n );
+ Console.WriteLine( " -d N, --duration N [{0,-7:D}] duration of audio to process in milliseconds", duration_ms );
+ Console.WriteLine( " -mc N, --max-context N [{0,-7:D}] maximum number of text context tokens to store", max_context );
+ Console.WriteLine( " -ml N, --max-len N [{0,-7:D}] maximum segment length in characters", max_len );
+ Console.WriteLine( " -wt N, --word-thold N [{0,-7:F2}] word timestamp probability threshold", word_thold );
+ Console.WriteLine( " -su, --speed-up [{0,-7}] speed up audio by x2 (reduced accuracy)", cstr( speed_up ) );
+ Console.WriteLine( " -tr, --translate [{0,-7}] translate from source language to english", cstr( translate ) );
+ Console.WriteLine( " -di, --diarize [{0,-7}] stereo audio diarization", cstr( diarize ) );
+ Console.WriteLine( " -otxt, --output-txt [{0,-7}] output result in a text file", cstr( output_txt ) );
+ Console.WriteLine( " -ovtt, --output-vtt [{0,-7}] output result in a vtt file", cstr( output_vtt ) );
+ Console.WriteLine( " -osrt, --output-srt [{0,-7}] output result in a srt file", cstr( output_srt ) );
+ Console.WriteLine( " -ps, --print-special [{0,-7}] print special tokens", cstr( print_special ) );
+ Console.WriteLine( " -nc, --no-colors [{0,-7}] do not print colors", cstr( !print_colors ) );
+ Console.WriteLine( " -nt, --no-timestamps [{0,-7}] do not print timestamps", cstr( no_timestamps ) );
+ Console.WriteLine( " -l LANG, --language LANG [{0,-7}] spoken language", language.getCode() );
+ Console.WriteLine( " --prompt PROMPT [ ] initial prompt" );
+ Console.WriteLine( " -m FNAME, --model FNAME [{0,-7}] model path", model );
+ Console.WriteLine( " -f FNAME, --file FNAME [{0,-7}] path of the input audio file", "" );
+ }
+ }
+} \ No newline at end of file
diff --git a/Examples/TranscribeCS/Transcribe.cs b/Examples/TranscribeCS/Transcribe.cs
new file mode 100644
index 0000000..6a1e500
--- /dev/null
+++ b/Examples/TranscribeCS/Transcribe.cs
@@ -0,0 +1,114 @@
+using System.Globalization;
+using Whisper;
+
+namespace TranscribeCS
+{
+ /// <summary>Implementation of Callbacks abstract class, to print these segments as soon as they’re produced by the library.</summary>
+ sealed class Transcribe: Callbacks
+ {
+ readonly CommandLineArgs args;
+ readonly eResultFlags resultFlags;
+
+ public Transcribe( CommandLineArgs args )
+ {
+ this.args = args;
+ resultFlags = args.resultFlags();
+ Console.OutputEncoding = System.Text.Encoding.UTF8;
+ }
+
+ // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
+ // Lowest is red, middle is yellow, highest is green.
+ readonly string[] k_colors = new string[]
+ {
+ "\x1B[38;5;196m", "\x1B[38;5;202m", "\x1B[38;5;208m", "\x1B[38;5;214m", "\x1B[38;5;220m",
+ "\x1B[38;5;226m", "\x1B[38;5;190m", "\x1B[38;5;154m", "\x1B[38;5;118m", "\x1B[38;5;82m"
+ };
+
+ int colorIndex( in sToken tok )
+ {
+ float p = tok.probability;
+ float p3 = p * p * p;
+ int col = (int)( p3 * k_colors.Length );
+ col = Math.Clamp( col, 0, k_colors.Length - 1 );
+ return col;
+ }
+
+ public static string printTime( TimeSpan ts ) =>
+ ts.ToString( "hh':'mm':'ss'.'fff", CultureInfo.InvariantCulture );
+ public static string printTimeWithComma( TimeSpan ts ) =>
+ ts.ToString( "hh':'mm':'ss','fff", CultureInfo.InvariantCulture );
+
+ protected override void onNewSegment( Context sender, int countNew )
+ {
+ TranscribeResult res = sender.results( resultFlags );
+ ReadOnlySpan<sToken> tokens = res.tokens;
+
+ int s0 = res.segments.Length - countNew;
+ if( s0 == 0 )
+ Console.WriteLine();
+
+ for( int i = s0; i < res.segments.Length; i++ )
+ {
+ sSegment seg = res.segments[ i ];
+
+ if( args.no_timestamps )
+ {
+ if( args.print_colors && AnsiCodes.enabled )
+ {
+ foreach( sToken tok in res.getTokens( seg ) )
+ {
+ if( !args.print_special && tok.hasFlag( eTokenFlags.Special ) )
+ continue;
+ Console.Write( "{0}{1}{2}", k_colors[ colorIndex( tok ) ], tok.text, "\x1B[0m" );
+ }
+ }
+ else
+ Console.Write( seg.text );
+ Console.Out.Flush();
+ continue;
+ }
+
+ string speaker = "";
+#if false
+ if( args.diarize && pcmf32s.size() == 2 )
+ {
+ const size_t n_samples = pcmf32s[ 0 ].size();
+ const int64_t is0 = SourceAudio::sampleFromTimestamp( seg.time.begin, n_samples );
+ const int64_t is1 = SourceAudio::sampleFromTimestamp( seg.time.end, n_samples );
+
+ double energy0 = 0.0f;
+ double energy1 = 0.0f;
+
+ for( int64_t j = is0; j < is1; j++ )
+ {
+ energy0 += fabs( pcmf32s[ 0 ][ j ] );
+ energy1 += fabs( pcmf32s[ 1 ][ j ] );
+ }
+
+ if( energy0 > 1.1 * energy1 )
+ speaker = "(speaker 0)";
+ else if( energy1 > 1.1 * energy0 )
+ speaker = "(speaker 1)";
+ else
+ speaker = "(speaker ?)";
+
+ //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+ }
+#endif
+ if( args.print_colors && AnsiCodes.enabled )
+ {
+ Console.Write( "[{0} --> {1}] ", printTime( seg.time.begin ), printTime( seg.time.end ) );
+ foreach( sToken tok in res.getTokens( seg ) )
+ {
+ if( !args.print_special && tok.hasFlag( eTokenFlags.Special ) )
+ continue;
+ Console.Write( "{0}{1}{2}{3}", speaker, k_colors[ colorIndex( tok ) ], tok.text, "\x1B[0m" );
+ }
+ Console.WriteLine();
+ }
+ else
+ Console.WriteLine( "[{0} --> {1}] {2}{3}", printTime( seg.time.begin ), printTime( seg.time.end ), speaker, seg.text );
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Examples/TranscribeCS/TranscribeCS.cs b/Examples/TranscribeCS/TranscribeCS.cs
new file mode 100644
index 0000000..9b828e3
--- /dev/null
+++ b/Examples/TranscribeCS/TranscribeCS.cs
@@ -0,0 +1,102 @@
+using Whisper;
+
+namespace TranscribeCS
+{
+ static class Program
+ {
+ static readonly bool streamAudio = true;
+
+ static int Main( string[] args )
+ {
+ try
+ {
+ CommandLineArgs cla;
+ try
+ {
+ cla = new CommandLineArgs( args );
+ }
+ catch( OperationCanceledException )
+ {
+ return 1;
+ }
+ const eLoggerFlags loggerFlags = eLoggerFlags.UseStandardError | eLoggerFlags.SkipFormatMessage;
+ Library.setLogSink( eLogLevel.Debug, loggerFlags );
+
+ using iModel model = Library.loadModel( cla.model );
+ using Context context = model.createContext();
+ cla.apply( ref context.parameters );
+ using iMediaFoundation mf = Library.initMediaFoundation();
+ Transcribe transcribe = new Transcribe( cla );
+
+ foreach( string audioFile in cla.fileNames )
+ {
+ if( streamAudio )
+ {
+ using iAudioReader reader = mf.openAudioFile( audioFile );
+ context.runFull( reader, transcribe, null, cla.prompt );
+ }
+ else
+ {
+ using iAudioBuffer buffer = mf.loadAudioFile( audioFile );
+ context.runFull( buffer, transcribe, cla.prompt );
+ }
+ // When asked to, produce these text files
+ if( cla.output_txt )
+ writeTextFile( context, audioFile );
+ if( cla.output_srt )
+ writeSubRip( context, audioFile, cla );
+ if( cla.output_vtt )
+ writeWebVTT( context, audioFile );
+ }
+
+ context.timingsPrint();
+ return 0;
+ }
+ catch( Exception ex )
+ {
+ Console.WriteLine( ex.Message );
+ return ex.HResult;
+ }
+ }
+
+ static void writeTextFile( Context context, string audioPath )
+ {
+ using var stream = File.CreateText( Path.ChangeExtension( audioPath, ".txt" ) );
+ foreach( sSegment seg in context.results().segments )
+ stream.WriteLine( seg.text );
+ }
+
+ static void writeSubRip( Context context, string audioPath, CommandLineArgs cliArgs )
+ {
+ using var stream = File.CreateText( Path.ChangeExtension( audioPath, ".srt" ) );
+ var segments = context.results( eResultFlags.Timestamps ).segments;
+
+ for( int i = 0; i < segments.Length; i++ )
+ {
+ stream.WriteLine( i + 1 + cliArgs.offset_n );
+ sSegment seg = segments[ i ];
+ string begin = Transcribe.printTimeWithComma( seg.time.begin );
+ string end = Transcribe.printTimeWithComma( seg.time.end );
+ stream.WriteLine( "{0} --> {1}", begin, end );
+ stream.WriteLine( seg.text );
+ stream.WriteLine();
+ }
+ }
+
+ static void writeWebVTT( Context context, string audioPath )
+ {
+ using var stream = File.CreateText( Path.ChangeExtension( audioPath, ".vtt" ) );
+ stream.WriteLine( "WEBVTT" );
+ stream.WriteLine();
+
+ foreach( sSegment seg in context.results( eResultFlags.Timestamps ).segments )
+ {
+ string begin = Transcribe.printTime( seg.time.begin );
+ string end = Transcribe.printTime( seg.time.end );
+ stream.WriteLine( "{0} --> {1}", begin, end );
+ stream.WriteLine( seg.text );
+ stream.WriteLine();
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Examples/TranscribeCS/TranscribeCS.csproj b/Examples/TranscribeCS/TranscribeCS.csproj
new file mode 100644
index 0000000..e9b8d0f
--- /dev/null
+++ b/Examples/TranscribeCS/TranscribeCS.csproj
@@ -0,0 +1,19 @@
+<Project Sdk="Microsoft.NET.Sdk">
+ <PropertyGroup>
+ <OutputType>Exe</OutputType>
+ <TargetFramework>net6.0-windows</TargetFramework>
+ <ImplicitUsings>enable</ImplicitUsings>
+ <Nullable>enable</Nullable>
+ <CheckForOverflowUnderflow>true</CheckForOverflowUnderflow>
+ <AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath>
+ <Platforms>x64</Platforms>
+ </PropertyGroup>
+ <ItemGroup>
+ <Content Include="..\..\x64\$(Configuration)\Whisper.dll" Link="Whisper.dll">
+ <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+ </Content>
+ </ItemGroup>
+ <ItemGroup>
+ <ProjectReference Include="..\..\WhisperNet\WhisperNet.csproj" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/Examples/WhisperDesktop/AppState.cpp b/Examples/WhisperDesktop/AppState.cpp
new file mode 100644
index 0000000..6697e6a
--- /dev/null
+++ b/Examples/WhisperDesktop/AppState.cpp
@@ -0,0 +1,192 @@
+#include "stdafx.h"
+#include "AppState.h"
+#include "Utils/miscUtils.h"
+#include <commctrl.h>
+#pragma comment(lib, "Comctl32.lib")
+// #pragma comment(linker,"/manifestdependency:\"type='win32' name='Microsoft.Windows.Common-Controls' version='6.0.0.0' processorArchitecture='amd64' publicKeyToken='6595b64144ccf1df' language='*'\"")
+#include "CircleIndicator.h"
+
+namespace
+{
+ static const HKEY regKeyRoot = HKEY_CURRENT_USER;
+ const LPCTSTR regKey = LR"(SOFTWARE\const.me\WhisperDesktop)";
+ const LPCTSTR regValPath = L"modelPath";
+ const LPCTSTR regValImpl = L"modelImpl";
+ const LPCTSTR regValLang = L"language";
+ const LPCTSTR regValLastScreen = L"screen";
+
+ static HRESULT readString( CRegKey& k, LPCTSTR name, CString& rdi )
+ {
+ ULONG nChars = 0;
+ LSTATUS lss = k.QueryStringValue( name, nullptr, &nChars );
+ if( lss != ERROR_SUCCESS )
+ return HRESULT_FROM_WIN32( lss );
+ if( nChars == 0 )
+ {
+ rdi = L"";
+ return S_FALSE;
+ }
+
+ lss = k.QueryStringValue( name, rdi.GetBufferSetLength( nChars ), &nChars );
+ rdi.ReleaseBuffer();
+ if( lss != ERROR_SUCCESS )
+ return HRESULT_FROM_WIN32( lss );
+
+ return S_OK;
+ }
+
+ using Whisper::eModelImplementation;
+}
+
+HRESULT AppState::startup()
+{
+ HRESULT hr = CoInitializeEx( nullptr, COINIT_MULTITHREADED );
+ if( FAILED( hr ) )
+ {
+ reportFatalError( "CoInitializeEx failed", hr );
+ return hr;
+ }
+ coInit = true;
+
+ LSTATUS lss = registryKey.Create( regKeyRoot, regKey );
+ if( lss != ERROR_SUCCESS )
+ {
+ hr = HRESULT_FROM_WIN32( lss );
+ reportFatalError( "Unable to open the registry key", hr );
+ return hr;
+ }
+
+ INITCOMMONCONTROLSEX init;
+ init.dwSize = sizeof( init );
+ init.dwICC = ICC_LINK_CLASS | ICC_PROGRESS_CLASS | ICC_STANDARD_CLASSES | ICC_TAB_CLASSES;
+ const BOOL icc = InitCommonControlsEx( &init );
+ if( !icc )
+ {
+ reportFatalError( "InitCommonControlsEx failed", HRESULT_FROM_WIN32( GetLastError() ) );
+ return E_FAIL;
+ }
+
+ hr = initMediaFoundation( &mediaFoundation );
+ if( FAILED( hr ) )
+ {
+ reportFatalError( "Unable to initialize Media Foundation runtime", hr );
+ return hr;
+ }
+
+ hr = console.initialize();
+ if( FAILED( hr ) )
+ {
+ reportFatalError( "Unable to initialize logging", hr );
+ return hr;
+ }
+
+ hr = CircleIndicator::registerClass();
+ if( FAILED( hr ) )
+ {
+ reportFatalError( "Unable to register custom controls", hr );
+ return hr;
+ }
+ appIcon.LoadIcon( IDI_WHISPERDESKTOP );
+ return S_OK;
+}
+
+AppState::~AppState()
+{
+ if( coInit )
+ {
+ CoUninitialize();
+ coInit = false;
+ }
+}
+
+HRESULT AppState::findModelSource()
+{
+ CHECK( readString( registryKey, regValPath, source.path ) );
+
+ {
+ CAtlFile file;
+ CHECK( file.Create( source.path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING ) );
+ ULONGLONG len;
+ CHECK( file.GetSize( len ) );
+ source.sizeInBytes = len;
+ }
+
+ CString impl;
+ CHECK( readString( registryKey, regValImpl, impl ) );
+ CHECK( implParse( impl, source.impl ) );
+ source.found = true;
+ return S_OK;
+}
+
+HRESULT AppState::saveModelSource()
+{
+ LSTATUS lss = registryKey.SetStringValue( regValPath, source.path );
+ if( lss != ERROR_SUCCESS )
+ return HRESULT_FROM_WIN32( lss );
+
+ LPCTSTR impl = implString( source.impl );
+ if( nullptr == impl )
+ return E_INVALIDARG;
+ lss = registryKey.SetStringValue( regValImpl, impl );
+ if( lss != ERROR_SUCCESS )
+ return HRESULT_FROM_WIN32( lss );
+
+ return S_OK;
+}
+
+uint32_t AppState::languageRead()
+{
+ DWORD dw;
+ LSTATUS lss = registryKey.QueryDWORDValue( regValLang, dw );
+ if( lss == ERROR_SUCCESS )
+ return dw;
+ return UINT_MAX;
+}
+
+void AppState::languageWrite( uint32_t key )
+{
+ registryKey.SetDWORDValue( regValLang, key );
+}
+
+CString AppState::stringLoad( LPCTSTR name )
+{
+ CString res;
+ readString( registryKey, name, res );
+ return res;
+}
+void AppState::stringStore( LPCTSTR name, LPCTSTR value )
+{
+ registryKey.SetStringValue( name, value );
+}
+uint32_t AppState::dwordLoad( LPCTSTR name, uint32_t fallback )
+{
+ DWORD dw;
+ LSTATUS lss = registryKey.QueryDWORDValue( name, dw );
+ if( lss == ERROR_SUCCESS )
+ return dw;
+ return fallback;
+}
+void AppState::dwordStore( LPCTSTR name, uint32_t value )
+{
+ registryKey.SetDWORDValue( name, value );
+}
+
+void AppState::lastScreenSave( HRESULT code )
+{
+ dwordStore( regValLastScreen, (uint32_t)code );
+}
+
+HRESULT AppState::lastScreenLoad()
+{
+ return (HRESULT)dwordLoad( regValLastScreen, SCREEN_TRANSCRIBE );
+}
+
+void AppState::setupIcon( CWindow* wnd )
+{
+ HICON ic = appIcon;
+ if( nullptr != ic )
+ {
+ wnd->SendMessage( WM_SETICON, ICON_SMALL, (LPARAM)ic );
+ wnd->SendMessage( WM_SETICON, ICON_BIG, (LPARAM)ic );
+ }
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/AppState.h b/Examples/WhisperDesktop/AppState.h
new file mode 100644
index 0000000..0a45a21
--- /dev/null
+++ b/Examples/WhisperDesktop/AppState.h
@@ -0,0 +1,51 @@
+#pragma once
+#include "Utils/DebugConsole.h"
+
+class AppState
+{
+ bool coInit = false;
+ CRegKey registryKey;
+ CIcon appIcon;
+public:
+
+ struct ModelSource
+ {
+ CString path;
+ bool found = false;
+ Whisper::eModelImplementation impl = (Whisper::eModelImplementation)0;
+ uint64_t sizeInBytes = 0;
+ };
+ ModelSource source;
+
+ DebugConsole console;
+ CComPtr<Whisper::iMediaFoundation> mediaFoundation;
+ CComPtr<Whisper::iModel> model;
+
+ ~AppState();
+
+ // Setup the initial things
+ HRESULT startup();
+
+ HRESULT findModelSource();
+
+ HRESULT saveModelSource();
+
+ uint32_t languageRead();
+ void languageWrite( uint32_t key );
+
+ CString stringLoad( LPCTSTR name );
+ void stringStore( LPCTSTR name, LPCTSTR value );
+ uint32_t dwordLoad( LPCTSTR name, uint32_t fallback );
+ void dwordStore( LPCTSTR name, uint32_t value );
+
+ bool automaticallyLoadModel = true;
+
+ void lastScreenSave( HRESULT code );
+ HRESULT lastScreenLoad();
+
+ void setupIcon( CWindow* wnd );
+};
+
+constexpr HRESULT SCREEN_MODEL = 1;
+constexpr HRESULT SCREEN_TRANSCRIBE = 2;
+constexpr HRESULT SCREEN_CAPTURE = 3; \ No newline at end of file
diff --git a/Examples/WhisperDesktop/CaptureDlg.cpp b/Examples/WhisperDesktop/CaptureDlg.cpp
new file mode 100644
index 0000000..f1030dd
--- /dev/null
+++ b/Examples/WhisperDesktop/CaptureDlg.cpp
@@ -0,0 +1,505 @@
+#include "stdafx.h"
+#include "CaptureDlg.h"
+
+HRESULT CaptureDlg::show()
+{
+ auto res = DoModal( nullptr );
+ if( res == -1 )
+ return HRESULT_FROM_WIN32( GetLastError() );
+ switch( res )
+ {
+ case IDC_BACK:
+ return SCREEN_MODEL;
+ case IDC_TRANSCRIBE:
+ return SCREEN_TRANSCRIBE;
+ }
+ return S_OK;
+}
+
+static const LPCTSTR regValDevice = L"captureDevice";
+static const LPCTSTR regValOutPath = L"captureTextFile";
+static const LPCTSTR regValOutFormat = L"captureTextFlags";
+
+enum struct CaptureDlg::eTextFlags : uint32_t
+{
+ Save = 1,
+ Append = 2,
+ Timestamps = 4,
+};
+
+LRESULT CaptureDlg::OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled )
+{
+ // First DDX call, hooks up variables to controls.
+ DoDataExchange( false );
+
+ languageSelector.initialize( m_hWnd, IDC_LANGUAGE, appState );
+ cbTranslate.initialize( m_hWnd, IDC_TRANSLATE, appState );
+ cbConsole.initialize( m_hWnd, IDC_CONSOLE, appState );
+
+ pendingState.initialize(
+ // Controls to disable while pending, re-enable afterwards
+ {
+ languageSelector,
+ cbCaptureDevice,
+ checkSave, checkAppend, checkTimestamps, transcribeOutputPath, transcribeOutputBrowse,
+ GetDlgItem( IDC_DEV_REFRESH ),
+ GetDlgItem( IDC_BACK ),
+ GetDlgItem( IDC_TRANSCRIBE ),
+ GetDlgItem( IDCANCEL ),
+ },
+ // Controls to show while pending, hide afterwards
+ {
+ voiceActivity, GetDlgItem( IDC_VOICE_ACTIVITY_LBL ),
+ transcribeActivity, GetDlgItem( IDC_TRANS_LBL ),
+ stalled, GetDlgItem( IDC_STALL_LBL ),
+ progressBar,
+ } );
+
+ stalled.setActiveColor( flipRgb( 0xffcc33 ) );
+
+ HRESULT hr = work.create( this );
+ if( FAILED( hr ) )
+ {
+ reportError( m_hWnd, L"CreateThreadpoolWork failed", nullptr, hr );
+ EndDialog( IDCANCEL );
+ }
+
+ listDevices();
+ selectDevice( appState.stringLoad( regValDevice ) );
+
+ constexpr uint32_t defaultFlags = (uint32_t)eTextFlags::Append;
+ uint32_t flags = appState.dwordLoad( regValOutFormat, defaultFlags );
+ if( flags & (uint32_t)eTextFlags::Save )
+ checkSave.SetCheck( BST_CHECKED );
+ if( flags & (uint32_t)eTextFlags::Append )
+ checkAppend.SetCheck( BST_CHECKED );
+ if( flags & (uint32_t)eTextFlags::Timestamps )
+ checkTimestamps.SetCheck( BST_CHECKED );
+
+ transcribeOutputPath.SetWindowText( appState.stringLoad( regValOutPath ) );
+ onSaveTextCheckbox();
+
+ appState.lastScreenSave( SCREEN_CAPTURE );
+ appState.setupIcon( this );
+ ATLVERIFY( CenterWindow() );
+ return 0;
+}
+
+HRESULT __stdcall CaptureDlg::listDevicesCallback( int len, const Whisper::sCaptureDevice* buffer, void* pv ) noexcept
+{
+ std::vector<sCaptureDevice>& devices = *( std::vector<sCaptureDevice> * )pv;
+ devices.resize( len );
+ for( int i = 0; i < len; i++ )
+ {
+ devices[ i ].displayName = buffer[ i ].displayName;
+ devices[ i ].endpoint = buffer[ i ].endpoint;
+ }
+ return S_OK;
+}
+
+bool CaptureDlg::listDevices()
+{
+ appState.mediaFoundation->listCaptureDevices( &listDevicesCallback, &devices );
+ cbCaptureDevice.ResetContent();
+ for( const auto& dev : devices )
+ cbCaptureDevice.AddString( dev.displayName );
+ return !devices.empty();
+}
+
+void CaptureDlg::onDeviceRefresh()
+{
+ // Save the current selection
+ const int curSel = cbCaptureDevice.GetCurSel();
+ CString str;
+ if( curSel >= 0 && curSel < (int)devices.size() )
+ str = std::move( devices[ curSel ].endpoint );
+
+ // Refresh
+ listDevices();
+
+ // Restore the selection
+ selectDevice( str );
+
+ const size_t len = devices.size();
+ if( len == 0 )
+ {
+ MessageBox( L"No capture devices found on this computer.\nIf you have a USB microphone, connect it to this PC,\nand press “refresh” button.",
+ L"Capture Devices", MB_OK | MB_ICONWARNING );
+ }
+ else
+ {
+ const char* suffix = ( len != 1 ) ? "s" : "";
+ str.Format( L"Detected %zu audio capture device%S.", len, suffix );
+ MessageBox( str, L"Capture Devices", MB_OK | MB_ICONINFORMATION );
+ }
+}
+
+bool CaptureDlg::selectDevice( LPCTSTR endpoint )
+{
+ if( nullptr != endpoint && 0 != *endpoint )
+ {
+ for( size_t i = 0; i < devices.size(); i++ )
+ {
+ if( devices[ i ].endpoint == endpoint )
+ {
+ cbCaptureDevice.SetCurSel( (int)i );
+ return true;
+ }
+ }
+ }
+
+ if( !devices.empty() )
+ cbCaptureDevice.SetCurSel( 0 );
+ return false;
+}
+
+void CaptureDlg::onSaveTextCheckbox()
+{
+ const BOOL enabled = ( checkSave.GetCheck() == BST_CHECKED );
+ std::array<HWND, 4> controls = { checkAppend, checkTimestamps, transcribeOutputPath, transcribeOutputBrowse };
+ for( HWND w : controls )
+ ::EnableWindow( w, enabled );
+}
+
+void CaptureDlg::onBrowseResult()
+{
+ LPCTSTR title = L"Output Text File";
+ LPCTSTR outputFilters = L"Text files (*.txt)\0*.txt\0\0";
+ CString path;
+ transcribeOutputPath.GetWindowText( path );
+ if( !getSaveFileName( m_hWnd, title, outputFilters, path ) )
+ return;
+
+ LPCTSTR ext = PathFindExtension( path );
+ if( 0 == *ext )
+ {
+ wchar_t* const buffer = path.GetBufferSetLength( path.GetLength() + 5 );
+ PathRenameExtension( buffer, L".txt" );
+ path.ReleaseBuffer();
+ }
+
+ transcribeOutputPath.SetWindowText( path );
+}
+
+CaptureDlg::eTextFlags CaptureDlg::getOutputFlags()
+{
+ uint32_t flags = 0;
+ if( checkSave.GetCheck() == BST_CHECKED )
+ flags |= (uint32_t)eTextFlags::Save;
+ if( checkAppend.GetCheck() == BST_CHECKED )
+ flags |= (uint32_t)eTextFlags::Append;
+ if( checkTimestamps.GetCheck() == BST_CHECKED )
+ flags |= (uint32_t)eTextFlags::Timestamps;
+ return (eTextFlags)flags;
+}
+
+void CaptureDlg::setPending( bool nowPending )
+{
+ pendingState.setPending( nowPending );
+ if( nowPending )
+ {
+ progressBar.SetMarquee( TRUE, 0 );
+ btnRunCapture.SetWindowText( L"Stop" );
+ }
+ else
+ {
+ progressBar.SetMarquee( FALSE, 0 );
+ btnRunCapture.SetWindowText( L"Capture" );
+ btnRunCapture.EnableWindow( TRUE );
+ captureRunning = false;
+ }
+}
+
+void CaptureDlg::onRunCapture()
+{
+ if( captureRunning )
+ {
+ threadState.stopRequested = true;
+ btnRunCapture.EnableWindow( FALSE );
+ return;
+ }
+
+ int dev = cbCaptureDevice.GetCurSel();
+ if( dev < 0 || dev >= (int)devices.size() )
+ {
+ showError( L"Please select a capture device", S_FALSE );
+ return;
+ }
+ threadState.endpoint = devices[ dev ].endpoint;
+ threadState.language = languageSelector.selectedLanguage();
+ threadState.translate = cbTranslate.checked();
+ if( isInvalidTranslate( m_hWnd, threadState.language, threadState.translate ) )
+ return;
+
+ threadState.flags = getOutputFlags();
+ if( (uint32_t)threadState.flags & (uint32_t)eTextFlags::Save )
+ {
+ transcribeOutputPath.GetWindowText( threadState.textOutputPath );
+ if( threadState.textOutputPath.GetLength() <= 0 )
+ {
+ showError( L"Please specify the output text file", S_FALSE );
+ return;
+ }
+ appState.stringStore( regValOutPath, threadState.textOutputPath );
+ }
+ else
+ cbConsole.ensureChecked();
+
+ languageSelector.saveSelection( appState );
+ cbTranslate.saveSelection( appState );
+ appState.stringStore( regValDevice, threadState.endpoint );
+ appState.dwordStore( regValOutFormat, (uint32_t)threadState.flags );
+
+ captureRunning = true;
+ threadState.errorMessage = L"";
+ threadState.stopRequested = false;
+ threadState.captureParams.minDuration = 7;
+ threadState.captureParams.maxDuration = 11;
+ setPending( true );
+ work.post();
+}
+
+void __declspec( noinline ) CaptureDlg::getThreadError()
+{
+ getLastError( threadState.errorMessage );
+}
+
+#define CHECK_EX( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) { getThreadError(); return __hr; } }
+
+static HRESULT appendDate( CString& str, const SYSTEMTIME& time )
+{
+ constexpr DWORD dateFlags = DATE_LONGDATE;
+ int cc = GetDateFormatEx( LOCALE_NAME_USER_DEFAULT, dateFlags, &time, nullptr, nullptr, 0, nullptr );
+ if( 0 == cc )
+ return getLastHr();
+
+ const int oldLength = str.GetLength();
+ wchar_t* const buffer = str.GetBufferSetLength( oldLength + cc );
+ cc = GetDateFormatEx( LOCALE_NAME_USER_DEFAULT, dateFlags, &time, nullptr, buffer + oldLength, cc, nullptr );
+ if( 0 != cc )
+ {
+ str.ReleaseBuffer();
+ return S_OK;
+ }
+ HRESULT hr = getLastHr();
+ str.ReleaseBuffer();
+ return hr;
+}
+
+static HRESULT appendTime( CString& str, const SYSTEMTIME& time )
+{
+ constexpr DWORD timeFlags = 0;
+ int cc = GetTimeFormatEx( LOCALE_NAME_USER_DEFAULT, timeFlags, &time, nullptr, nullptr, 0 );
+ if( 0 == cc )
+ return getLastHr();
+
+ const int oldLength = str.GetLength();
+ wchar_t* const buffer = str.GetBufferSetLength( oldLength + cc );
+ cc = GetTimeFormatEx( LOCALE_NAME_USER_DEFAULT, timeFlags, &time, nullptr, buffer + oldLength, cc );
+ if( 0 != cc )
+ {
+ str.ReleaseBuffer();
+ return S_OK;
+ }
+ HRESULT hr = getLastHr();
+ str.ReleaseBuffer();
+ return hr;
+}
+
+static HRESULT printDateTime( CAtlFile& file )
+{
+ SYSTEMTIME time;
+ GetLocalTime( &time );
+
+ CString str;
+ str = L"==== Captured on ";
+ CHECK( appendDate( str, time ) );
+ str += L", ";
+ CHECK( appendTime( str, time ) );
+ str += L" ====\r\n";
+
+ CStringA u8;
+ makeUtf8( u8, str );
+ return file.Write( cstr( u8 ), (DWORD)u8.GetLength() );
+}
+
+inline HRESULT CaptureDlg::runCapture()
+{
+ clearLastError();
+ using namespace Whisper;
+ CComPtr<iAudioCapture> capture;
+ CHECK_EX( appState.mediaFoundation->openCaptureDevice( threadState.endpoint, threadState.captureParams, &capture ) );
+
+ HRESULT hr;
+ CAtlFile file;
+ const uint32_t flags = (uint32_t)threadState.flags;
+ if( flags & (uint32_t)eTextFlags::Save )
+ {
+ const bool append = 0 != ( flags & (uint32_t)eTextFlags::Append );
+ const DWORD creation = append ? OPEN_ALWAYS : CREATE_ALWAYS;
+ hr = file.Create( threadState.textOutputPath, GENERIC_WRITE, FILE_SHARE_READ, creation );
+ if( FAILED( hr ) )
+ {
+ threadState.errorMessage = L"Unable to create the output text file";
+ return hr;
+ }
+ if( append )
+ {
+ ULONGLONG size;
+ CHECK( file.GetSize( size ) );
+ if( size == 0 )
+ CHECK( writeUtf8Bom( file ) )
+ else
+ CHECK( file.Seek( 0, SEEK_END ) );
+ }
+ else
+ {
+ CHECK( writeUtf8Bom( file ) );
+ }
+
+ if( flags & (uint32_t)eTextFlags::Timestamps )
+ CHECK( printDateTime( file ) );
+
+ threadState.file = &file;
+ }
+ else
+ threadState.file = nullptr;
+
+ CComPtr<iContext> context;
+ CHECK_EX( appState.model->createContext( &context ) );
+
+ sFullParams fullParams;
+ CHECK_EX( context->fullDefaultParams( eSamplingStrategy::Greedy, &fullParams ) );
+ fullParams.language = threadState.language;
+ fullParams.setFlag( eFullParamsFlags::Translate, threadState.translate );
+ fullParams.resetFlag( eFullParamsFlags::PrintRealtime );
+ fullParams.new_segment_callback = &newSegmentCallback;
+ fullParams.new_segment_callback_user_data = this;
+
+ sCaptureCallbacks callbacks;
+ callbacks.shouldCancel = &cbCancel;
+ callbacks.captureStatus = &cbStatus;
+ callbacks.pv = this;
+
+ CHECK_EX( context->runCapture( fullParams, callbacks, capture ) );
+ threadState.file = nullptr;
+
+ context->timingsPrint();
+ return S_OK;
+}
+
+void __stdcall CaptureDlg::poolCallback() noexcept
+{
+ const HRESULT hr = runCapture();
+ PostMessage( WM_CALLBACK_COMPLETION, hr );
+}
+
+void CaptureDlg::showError( LPCTSTR text, HRESULT hr )
+{
+ reportError( m_hWnd, text, L"Capture failed", hr );
+}
+
+LRESULT CaptureDlg::onThreadQuit( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled )
+{
+ setPending( false );
+
+ const HRESULT hr = (HRESULT)wParam;
+ if( FAILED( hr ) )
+ {
+ LPCTSTR failMessage = L"Capture failed";
+
+ if( threadState.errorMessage.GetLength() > 0 )
+ {
+ CString tmp = failMessage;
+ tmp += L"\n";
+ tmp += threadState.errorMessage;
+ showError( tmp, hr );
+ }
+ else
+ showError( failMessage, hr );
+
+ return 0;
+ }
+ else
+ {
+ if( (uint32_t)threadState.flags & (uint32_t)eTextFlags::Save )
+ ShellExecute( NULL, L"open", threadState.textOutputPath, NULL, NULL, SW_SHOW );
+ }
+
+ return 0;
+}
+
+LRESULT CaptureDlg::onThreadStatus( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled )
+{
+ using namespace Whisper;
+ const uint8_t newStatus = (uint8_t)wParam;
+ // Update the GUI
+ voiceActivity.setActive( 0 != ( newStatus & (uint8_t)eCaptureStatus::Voice ) );
+ transcribeActivity.setActive( 0 != ( newStatus & (uint8_t)eCaptureStatus::Transcribing ) );
+ stalled.setActive( 0 != ( newStatus & (uint8_t)eCaptureStatus::Stalled ) );
+ return 0;
+}
+
+HRESULT __stdcall CaptureDlg::cbCancel( void* pv ) noexcept
+{
+ CaptureDlg& dialog = *(CaptureDlg*)pv;
+ return dialog.threadState.stopRequested ? S_OK : S_FALSE;
+}
+
+HRESULT __stdcall CaptureDlg::cbStatus( void* pv, Whisper::eCaptureStatus status ) noexcept
+{
+ CaptureDlg& dialog = *(CaptureDlg*)pv;
+ if( dialog.PostMessage( WM_CALLBACK_STATUS, (uint8_t)status ) )
+ return S_OK;
+ return getLastHr();
+}
+
+HRESULT __cdecl CaptureDlg::newSegmentCallback( Whisper::iContext* ctx, uint32_t n_new, void* user_data ) noexcept
+{
+ using namespace Whisper;
+ CComPtr<iTranscribeResult> result;
+ const eResultFlags flags = eResultFlags::Timestamps | eResultFlags::Tokens;
+ CHECK( ctx->getResults( flags, &result ) );
+ CHECK( logNewSegments( result, n_new ) );
+
+ CaptureDlg& dialog = *(CaptureDlg*)user_data;
+ return dialog.appendTextFile( result, n_new );
+}
+
+HRESULT CaptureDlg::appendTextFile( Whisper::iTranscribeResult* results, uint32_t newSegments )
+{
+ if( nullptr == threadState.file || 0 == newSegments )
+ return S_OK;
+
+ using namespace Whisper;
+ sTranscribeLength length;
+ CHECK( results->getSize( length ) );
+
+ const size_t len = length.countSegments;
+ size_t i = len - newSegments;
+
+ const sSegment* const segments = results->getSegments();
+ CStringA str;
+ for( ; i < len; i++ )
+ {
+ const sSegment& seg = segments[ i ];
+ if( 0 != ( (uint32_t)threadState.flags & (uint32_t)eTextFlags::Timestamps ) )
+ {
+ str = "[";
+ printTimeStamp( str, seg.time.begin );
+ str += " --> ";
+ printTimeStamp( str, seg.time.end );
+ str += "] ";
+ }
+ else
+ str = "";
+
+ str += seg.text;
+ str += "\r\n";
+
+ CHECK( threadState.file->Write( cstr( str ), (DWORD)str.GetLength() ) );
+ }
+
+ CHECK( threadState.file->Flush() );
+ return S_OK;
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/CaptureDlg.h b/Examples/WhisperDesktop/CaptureDlg.h
new file mode 100644
index 0000000..c66b10d
--- /dev/null
+++ b/Examples/WhisperDesktop/CaptureDlg.h
@@ -0,0 +1,143 @@
+#pragma once
+#include "AppState.h"
+#include "Utils/WTL/atlddx.h"
+#include "Utils/miscUtils.h"
+#include "Utils/LanguageDropdown.h"
+#include "Utils/TranslateCheckbox.h"
+#include "Utils/PendingState.h"
+#include "CircleIndicator.h"
+
+class CaptureDlg :
+ public CDialogImpl<CaptureDlg>,
+ public CWinDataExchange<CaptureDlg>,
+ public iThreadPoolCallback
+{
+ AppState& appState;
+
+public:
+ static constexpr UINT IDD = IDD_CAPTURE_DIALOG;
+ static constexpr UINT WM_CALLBACK_COMPLETION = WM_APP + 1;
+ static constexpr UINT WM_CALLBACK_STATUS = WM_APP + 2;
+
+ CaptureDlg( AppState& app ) : appState( app ) { }
+
+ HRESULT show();
+
+ BEGIN_MSG_MAP( CaptureDlg )
+ MESSAGE_HANDLER( WM_INITDIALOG, OnInitDialog )
+ ON_BUTTON_CLICK( IDC_CONSOLE, cbConsole.click )
+ ON_BUTTON_CLICK( IDC_DEV_REFRESH, onDeviceRefresh );
+ ON_BUTTON_CLICK( IDC_BROWSE_RESULT, onBrowseResult );
+ ON_BUTTON_CLICK( IDC_SAVE_TEXT, onSaveTextCheckbox );
+ ON_BUTTON_CLICK( IDC_RUN_CAPTURE, onRunCapture );
+
+ ON_BUTTON_CLICK( IDCANCEL, onClose )
+ ON_BUTTON_CLICK( IDC_BACK, onBack )
+ ON_BUTTON_CLICK( IDC_TRANSCRIBE, onTranscribe );
+
+ MESSAGE_HANDLER( WM_CALLBACK_COMPLETION, onThreadQuit );
+ MESSAGE_HANDLER( WM_CALLBACK_STATUS, onThreadStatus );
+ END_MSG_MAP()
+
+ BEGIN_DDX_MAP( CaptureDlg )
+ DDX_CONTROL_HANDLE( IDC_DEVICE, cbCaptureDevice )
+ DDX_CONTROL_HANDLE( IDC_RUN_CAPTURE, btnRunCapture );
+ DDX_CONTROL_HANDLE( IDC_TRANSCRIBE_PROGRESS, progressBar );
+ DDX_CONTROL_HANDLE( IDC_SAVE_TEXT, checkSave )
+ DDX_CONTROL_HANDLE( IDC_SAVE_APPEND, checkAppend )
+ DDX_CONTROL_HANDLE( IDC_SAVE_TIMESTAMPS, checkTimestamps )
+ DDX_CONTROL_HANDLE( IDC_PATH_RESULT, transcribeOutputPath )
+ DDX_CONTROL_HANDLE( IDC_BROWSE_RESULT, transcribeOutputBrowse );
+
+ DDX_CONTROL( IDC_VOICE_ACTIVITY, voiceActivity );
+ DDX_CONTROL( IDC_TRANS_STATUS, transcribeActivity );
+ DDX_CONTROL( IDC_STALL_STATUS, stalled );
+
+ END_DDX_MAP()
+
+private:
+ PendingState pendingState;
+ void setPending( bool nowPending );
+
+ LRESULT OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled );
+
+ void onClose()
+ {
+ ATLVERIFY( EndDialog( IDCANCEL ) );
+ }
+ void onBack()
+ {
+ ATLVERIFY( EndDialog( IDC_BACK ) );
+ }
+ void onTranscribe()
+ {
+ ATLVERIFY( EndDialog( IDC_TRANSCRIBE ) );
+ }
+
+ // List capture devices, and populate the combobox
+ bool listDevices();
+ void onDeviceRefresh();
+ bool selectDevice( LPCTSTR endpoint );
+
+ static HRESULT __stdcall listDevicesCallback( int len, const Whisper::sCaptureDevice* buffer, void* pv ) noexcept;
+ ConsoleCheckbox cbConsole;
+ LanguageDropdown languageSelector;
+ TranslateCheckbox cbTranslate;
+ CComboBox cbCaptureDevice;
+
+ void onBrowseResult();
+
+ enum struct eTextFlags : uint32_t;
+ CButton checkSave, checkAppend, checkTimestamps;
+ CEdit transcribeOutputPath;
+ CButton transcribeOutputBrowse;
+ void onSaveTextCheckbox();
+ eTextFlags getOutputFlags();
+
+ CButton btnRunCapture;
+ CProgressBarCtrl progressBar;
+ ThreadPoolWork work;
+
+ struct sCaptureDevice
+ {
+ CString displayName;
+ CString endpoint;
+ };
+ std::vector<sCaptureDevice> devices;
+
+ void showError( LPCTSTR text, HRESULT hr );
+
+ CircleIndicator voiceActivity;
+ CircleIndicator transcribeActivity;
+ CircleIndicator stalled;
+
+ struct sThreadState
+ {
+ volatile bool stopRequested;
+ bool translate;
+ eTextFlags flags;
+ CAtlFile* file;
+ uint32_t language;
+ Whisper::sCaptureParams captureParams;
+ CString endpoint;
+ CString textOutputPath;
+ CString errorMessage;
+ };
+ sThreadState threadState;
+ bool captureRunning = false;
+
+ void getThreadError();
+ void onRunCapture();
+ HRESULT runCapture();
+ void __stdcall poolCallback() noexcept override final;
+
+ LRESULT onThreadQuit( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled );
+ LRESULT onThreadStatus( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled );
+
+ static HRESULT __stdcall cbCancel( void* pv ) noexcept;
+ static HRESULT __stdcall cbStatus( void* pv, Whisper::eCaptureStatus status ) noexcept;
+
+ static HRESULT __cdecl newSegmentCallback( Whisper::iContext* ctx, uint32_t n_new, void* user_data ) noexcept;
+
+ HRESULT appendTextFile( Whisper::iTranscribeResult* results, uint32_t newSegments );
+}; \ No newline at end of file
diff --git a/Examples/WhisperDesktop/CircleIndicator.cpp b/Examples/WhisperDesktop/CircleIndicator.cpp
new file mode 100644
index 0000000..593bd4a
--- /dev/null
+++ b/Examples/WhisperDesktop/CircleIndicator.cpp
@@ -0,0 +1,118 @@
+#include "stdafx.h"
+#include "CircleIndicator.h"
+#include <atltypes.h>
+#include "AppState.h"
+
+namespace
+{
+ static const LPCTSTR windowClassName = L"CircleIndicator";
+
+ // Font with these symbols, shipped with Windows since forever:
+ // https://learn.microsoft.com/en-us/typography/font-list/segoe-ui-symbol
+ static const LPCTSTR fontName = L"Segoe UI Symbol";
+
+ // Outlined circle
+ static const LPCTSTR circleOutline = L"⚪";
+ // Filled circle
+ static const LPCTSTR circleFilled = L"âš«";
+
+ // Font size for that symbol font, in points
+ constexpr int fontSizePoints = 14;
+
+ // Default active color for the indicator
+ constexpr uint32_t defaultActiveColor = 0x7FFF7F;
+}
+
+CircleIndicator::CircleIndicator() :
+ activeColor( defaultActiveColor )
+{ }
+
+ATL::CWndClassInfo& CircleIndicator::GetWndClassInfo()
+{
+ // Use custom class style with CS_PARENTDC, and COLOR_3DFACE for the background
+ static ATL::CWndClassInfo wc =
+ {
+ { sizeof( WNDCLASSEX ),
+ CS_HREDRAW | CS_VREDRAW | CS_PARENTDC,
+ StartWindowProc,
+ 0, 0, NULL, NULL, NULL, (HBRUSH)( COLOR_3DFACE + 1 ), NULL, windowClassName, NULL },
+ NULL, NULL, IDC_ARROW, TRUE, 0, _T( "" )
+ };
+ return wc;
+}
+
+// Class registration
+HRESULT CircleIndicator::registerClass()
+{
+ WNDPROC pUnusedWndSuperProc = nullptr;
+ ATOM a = GetWndClassInfo().Register( &pUnusedWndSuperProc );
+ if( 0 != a )
+ return S_OK;
+ return getLastHr();
+}
+
+HRESULT CircleIndicator::createFont( int height )
+{
+ LOGFONT logFont;
+ memset( &logFont, 0, sizeof( logFont ) );
+ logFont.lfHeight = height;
+ logFont.lfCharSet = ANSI_CHARSET;
+ logFont.lfOutPrecision = OUT_TT_PRECIS;
+ logFont.lfClipPrecision = CLIP_DEFAULT_PRECIS;
+ wcsncpy_s( logFont.lfFaceName, fontName, _TRUNCATE );
+ font.CreateFontIndirect( &logFont );
+ if( font )
+ return S_OK;
+ return E_FAIL;
+}
+
+void CircleIndicator::onDestroy()
+{
+ if( font )
+ font.DeleteObject();
+}
+
+void CircleIndicator::onPaint( CDCHandle dc )
+{
+ CRect rectInt32;
+ GetClientRect( &rectInt32 );
+
+ CPaintDC pdc( m_hWnd );
+
+ const int logPixels = pdc.GetDeviceCaps( LOGPIXELSY );
+ int fontSize = -MulDiv( fontSizePoints, logPixels, 72 );
+ if( !font || fontHeight != fontSize )
+ {
+ if( font )
+ font.DeleteObject();
+ HRESULT hr = createFont( fontSize );
+ if( FAILED( hr ) )
+ return;
+ fontHeight = fontSize;
+ }
+
+ pdc.SetBkColor( GetSysColor( COLOR_3DFACE ) );
+ pdc.SelectFont( font );
+ pdc.SetBkMode( OPAQUE );
+ constexpr UINT textFormat = DT_CENTER | DT_VCENTER;
+
+ if( isActive )
+ {
+ pdc.SetTextColor( activeColor );
+ pdc.DrawText( circleFilled, 1, rectInt32, textFormat );
+ pdc.SetBkMode( TRANSPARENT );
+ }
+
+ pdc.SetTextColor( 0 );
+ pdc.DrawText( circleOutline, 1, rectInt32, textFormat );
+}
+
+void CircleIndicator::setActive( bool nowActive )
+{
+ if( nowActive == isActive )
+ return;
+
+ // Repaint the control
+ isActive = nowActive;
+ InvalidateRect( nullptr );
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/CircleIndicator.h b/Examples/WhisperDesktop/CircleIndicator.h
new file mode 100644
index 0000000..492524d
--- /dev/null
+++ b/Examples/WhisperDesktop/CircleIndicator.h
@@ -0,0 +1,36 @@
+#pragma once
+#include "Utils/miscUtils.h"
+#include "Utils/WTL/atlcrack.h"
+
+// This control renders a black circle, and in the active state, the circle is filled with a bright color.
+class CircleIndicator: public CWindowImpl<CircleIndicator>
+{
+public:
+ static ATL::CWndClassInfo& GetWndClassInfo();
+
+ BEGIN_MSG_MAP( CircleIndicator )
+ MSG_WM_PAINT( onPaint )
+ MSG_WM_DESTROY( onDestroy )
+ END_MSG_MAP()
+
+ // Class registration
+ static HRESULT registerClass();
+
+ void setActive( bool nowActive );
+
+ void setActiveColor( uint32_t col )
+ {
+ activeColor = col;
+ }
+ CircleIndicator();
+
+private:
+ bool isActive = false;
+ uint32_t activeColor;
+ int fontHeight = 0;
+ CFont font;
+ HRESULT createFont( int height );
+
+ void onDestroy();
+ void onPaint( CDCHandle dc );
+}; \ No newline at end of file
diff --git a/Examples/WhisperDesktop/LoadModelDlg.cpp b/Examples/WhisperDesktop/LoadModelDlg.cpp
new file mode 100644
index 0000000..1b2bf03
--- /dev/null
+++ b/Examples/WhisperDesktop/LoadModelDlg.cpp
@@ -0,0 +1,206 @@
+#include "stdafx.h"
+#include "LoadModelDlg.h"
+#include "Utils/miscUtils.h"
+#include "Utils/logger.h"
+
+constexpr int progressMaxInteger = 1024 * 8;
+
+HRESULT LoadModelDlg::show()
+{
+ auto res = DoModal( nullptr );
+ if( res == -1 )
+ return HRESULT_FROM_WIN32( GetLastError() );
+ if( res == IDOK )
+ {
+ HRESULT hr = appState.lastScreenLoad();
+ switch( hr )
+ {
+ case SCREEN_TRANSCRIBE:
+ case SCREEN_CAPTURE:
+ return hr;
+ default:
+ return SCREEN_TRANSCRIBE;
+ }
+ }
+ return S_OK;
+}
+
+LRESULT LoadModelDlg::OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled )
+{
+ // First DDX call, hooks up variables to controls.
+ DoDataExchange( false );
+
+ cbConsole.initialize( m_hWnd, IDC_CONSOLE, appState );
+ implPopulateCombobox( cbModelType, appState.source.impl );
+ modelPath.SetWindowTextW( appState.source.path );
+
+ HRESULT hr = work.create( this );
+ if( FAILED( hr ) )
+ {
+ CString text = L"CreateThreadpoolWork failed\n";
+ text += formatErrorMessage( hr );
+ ::MessageBox( m_hWnd, text, L"Unable to load the model", MB_OK | MB_ICONWARNING );
+ return TRUE;
+ }
+
+ editorsWindows.reserve( 5 );
+ editorsWindows = { modelPath, cbModelType, GetDlgItem( IDC_BROWSE ), GetDlgItem( IDOK ), GetDlgItem( IDCANCEL ) };
+ pendingWindows.reserve( 2 );
+ pendingWindows = { GetDlgItem( IDC_PENDING_TEXT ), progressBar };
+
+ progressBar.SetRange32( 0, progressMaxInteger );
+ progressBar.SetStep( 1 );
+
+ appState.setupIcon( this );
+ ATLVERIFY( CenterWindow() );
+ if( !appState.source.found || !appState.automaticallyLoadModel )
+ return 0;
+
+ // AppState.findModelSource() method has located model parameters in registry;
+ // Post a notification identical to the "OK" button click event.
+ PostMessage( WM_COMMAND, IDOK, (LPARAM)( GetDlgItem( IDOK ).m_hWnd ) );
+
+ return 0;
+}
+
+LRESULT LoadModelDlg::OnBrowse( UINT, INT, HWND, BOOL& bHandled )
+{
+ bHandled = TRUE;
+
+ CString path;
+ modelPath.GetWindowText( path );
+ if( !getOpenFileName( m_hWnd, L"Select a GGML Model File", L"Binary files (*.bin)\0*.bin\0\0", path ) )
+ return 0;
+
+ modelPath.SetWindowText( path );
+ appState.source.path = path;
+ return 0;
+}
+
+LRESULT LoadModelDlg::validationError( LPCTSTR message )
+{
+ reportError( m_hWnd, message, L"Unable to load the model" );
+ return 0;
+}
+
+LRESULT LoadModelDlg::validationError( LPCTSTR message, HRESULT hr )
+{
+ reportError( m_hWnd, message, L"Unable to load the model", hr );
+ return 0;
+}
+
+void LoadModelDlg::setPending( bool nowPending )
+{
+ const BOOL enable = nowPending ? FALSE : TRUE;
+ for( HWND w : editorsWindows )
+ ::EnableWindow( w, enable );
+
+ const int show = nowPending ? SW_NORMAL : SW_HIDE;
+ for( HWND w : pendingWindows )
+ ::ShowWindow( w, show );
+
+ if( nowPending )
+ progressBar.SetMarquee( TRUE, 0 );
+ else
+ progressBar.SetMarquee( FALSE, 0 );
+}
+
+LRESULT LoadModelDlg::OnOk( UINT, INT, HWND, BOOL& bHandled )
+{
+ modelPath.GetWindowText( path );
+ if( path.GetLength() <= 0 )
+ return validationError( L"Please select a model GGML file" );
+
+ {
+ CAtlFile file;
+ HRESULT hr = file.Create( path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING );
+ if( FAILED( hr ) )
+ return validationError( L"Unable to open the model file", hr );
+
+ ULONGLONG cb = 0;
+ file.GetSize( cb );
+ appState.source.sizeInBytes = cb;
+ }
+
+ impl = implGetValue( cbModelType );
+ if( impl == (Whisper::eModelImplementation)0 )
+ return validationError( L"Please select a model type" );
+
+ setPending( true );
+ work.post();
+ return 0;
+}
+
+void __stdcall LoadModelDlg::poolCallback() noexcept
+{
+ CComPtr<Whisper::iModel> model;
+ clearLastError();
+ loadError = L"";
+ Whisper::sLoadModelCallbacks lmcb;
+ lmcb.cancel = nullptr;
+ lmcb.progress = &LoadModelDlg::progressCallback;
+ lmcb.pv = this;
+ HRESULT hr = Whisper::loadModel( path, impl, &lmcb, &model );
+ if( SUCCEEDED( hr ) )
+ appState.model = model;
+ else
+ getLastError( loadError );
+
+ this->PostMessage( WM_CALLBACK_STATUS, (WPARAM)hr );
+}
+
+HRESULT __stdcall LoadModelDlg::progressCallback( double val, void* pv ) noexcept
+{
+ LoadModelDlg& dialog = *(LoadModelDlg*)pv;
+ constexpr double mul = progressMaxInteger;
+ int pos = lround( mul * val );
+ dialog.progressBar.PostMessage( PBM_SETPOS, pos, 0 );
+ return S_OK;
+}
+
+LRESULT LoadModelDlg::OnCallbackStatus( UINT, WPARAM wParam, LPARAM, BOOL& bHandled )
+{
+ setPending( false );
+
+ bHandled = TRUE;
+ const HRESULT hr = (HRESULT)wParam;
+ if( FAILED( hr ) )
+ {
+ LPCTSTR failMessage = L"Error loading the model";
+ if( loadError.GetLength() > 0 )
+ {
+ CString tmp = failMessage;
+ tmp += L"\n";
+ tmp += loadError;
+ return validationError( tmp, hr );
+ }
+ else
+ return validationError( failMessage, hr );
+ }
+
+ appState.source.path = path;
+ appState.source.impl = impl;
+ appState.saveModelSource();
+
+ EndDialog( IDOK );
+ return 0;
+}
+
+LRESULT LoadModelDlg::OnHyperlink( int idCtrl, LPNMHDR pnmh, BOOL& bHandled )
+{
+ const UINT code = pnmh->code;
+ switch( code )
+ {
+ case NM_CLICK:
+ case NM_RETURN:
+ break;
+ default:
+ return 0;
+ }
+
+ PNMLINK pNMLink = (PNMLINK)pnmh;
+ LPCTSTR url = pNMLink->item.szUrl;
+ ShellExecute( NULL, L"open", url, NULL, NULL, SW_SHOW );
+ bHandled = TRUE;
+ return 0;
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/LoadModelDlg.h b/Examples/WhisperDesktop/LoadModelDlg.h
new file mode 100644
index 0000000..a8d7aea
--- /dev/null
+++ b/Examples/WhisperDesktop/LoadModelDlg.h
@@ -0,0 +1,69 @@
+#pragma once
+#include "AppState.h"
+#include "Utils/WTL/atlddx.h"
+#include "Utils/miscUtils.h"
+
+class LoadModelDlg:
+ public CDialogImpl<LoadModelDlg>,
+ public CWinDataExchange<LoadModelDlg>,
+ public iThreadPoolCallback
+{
+ AppState& appState;
+public:
+ static constexpr UINT IDD = IDD_OPEN_MODEL;
+ static constexpr UINT WM_CALLBACK_STATUS = WM_APP + 1;
+
+ LoadModelDlg( AppState& app ) : appState( app ) { }
+
+ HRESULT show();
+
+ BEGIN_MSG_MAP( LoadModelDlg )
+ MESSAGE_HANDLER( WM_INITDIALOG, OnInitDialog )
+ ON_BUTTON_CLICK( IDC_CONSOLE, cbConsole.click )
+ COMMAND_ID_HANDLER( IDCANCEL, OnCommand )
+ COMMAND_ID_HANDLER( IDOK, OnOk )
+ COMMAND_ID_HANDLER( IDC_BROWSE, OnBrowse )
+ MESSAGE_HANDLER( WM_CALLBACK_STATUS, OnCallbackStatus )
+ NOTIFY_ID_HANDLER( IDC_HYPERLINKS, OnHyperlink )
+ END_MSG_MAP()
+
+ BEGIN_DDX_MAP( LoadModelDlg )
+ DDX_CONTROL_HANDLE( IDC_PATH, modelPath )
+ DDX_CONTROL_HANDLE( IDC_MODEL_TYPE, cbModelType )
+ DDX_CONTROL_HANDLE( IDC_PROGRESS, progressBar );
+ END_DDX_MAP()
+
+private:
+ std::vector<HWND> editorsWindows;
+ std::vector<HWND> pendingWindows;
+ void setPending( bool nowPending );
+
+ LRESULT OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled );
+ LRESULT OnCallbackStatus( UINT, WPARAM wParam, LPARAM, BOOL& bHandled );
+
+ LRESULT OnCommand( UINT, INT nIdentifier, HWND, BOOL& bHandled )
+ {
+ ATLVERIFY( EndDialog( nIdentifier ) );
+ return 0;
+ }
+ LRESULT OnBrowse( UINT, INT, HWND, BOOL& bHandled );
+ LRESULT OnOk( UINT, INT, HWND, BOOL& bHandled );
+
+ ConsoleCheckbox cbConsole;
+ CComboBox cbModelType;
+ CEdit modelPath;
+ CProgressBarCtrl progressBar;
+
+ LRESULT validationError( LPCTSTR message );
+ LRESULT validationError( LPCTSTR message, HRESULT hr );
+
+ ThreadPoolWork work;
+ CString path;
+ Whisper::eModelImplementation impl;
+ CString loadError;
+ void __stdcall poolCallback() noexcept override final;
+
+ LRESULT OnHyperlink( int idCtrl, LPNMHDR pnmh, BOOL& bHandled );
+
+ static HRESULT __stdcall progressCallback( double val, void* pv ) noexcept;
+}; \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Resource.h b/Examples/WhisperDesktop/Resource.h
new file mode 100644
index 0000000..f4a9af5
--- /dev/null
+++ b/Examples/WhisperDesktop/Resource.h
@@ -0,0 +1,61 @@
+//{{NO_DEPENDENCIES}}
+// Microsoft Visual C++ generated include file.
+// Used by WhisperDesktop.rc
+//
+#define IDC_MYICON 2
+#define IDD_WHISPERDESKTOP_DIALOG 102
+#define IDD_ABOUTBOX 103
+#define IDM_ABOUT 104
+#define IDI_WHISPERDESKTOP 107
+#define IDI_SMALL 108
+#define IDR_MAINFRAME 128
+#define IDD_OPEN_MODEL 129
+#define IDD_MAIN_DIALOG 130
+#define IDD_TRANSCRIBE_DIALOG 130
+#define IDD_CAPTURE_DIALOG 131
+#define IDC_PATH 1000
+#define IDC_BROWSE 1001
+#define IDC_MODEL_TYPE 1002
+#define IDC_PATH_RESULT 1002
+#define IDC_PROGRESS 1003
+#define IDC_BROWSE_RESULT 1003
+#define IDC_SYSLINK1 1004
+#define IDC_HYPERLINKS 1004
+#define IDC_TRANSCRIBE_PROGRESS 1004
+#define IDC_PENDING_TEXT 1005
+#define IDC_MODEL_DESC 1006
+#define IDC_LANGUAGE 1007
+#define IDC_OUTPUT_FORMAT 1008
+#define IDC_PATH_MEDIA 1009
+#define IDC_DEVICE 1009
+#define IDC_BROWSE_MEDIA 1010
+#define IDC_TRANSCRIBE 1011
+#define IDC_BACK 1012
+#define IDC_CHECK1 1013
+#define IDC_CONSOLE 1013
+#define IDC_CAPTURE 1014
+#define IDC_DEV_REFRESH 1015
+#define IDC_SAVE_TEXT 1016
+#define IDC_SAVE_APPEND 1017
+#define IDC_SAVE_TIMESTAMPS 1018
+#define IDC_RUN_CAPTURE 1019
+#define IDC_VOICE_ACTIVITY 1020
+#define IDC_VOICE_ACTIVITY_LBL 1021
+#define IDC_TRANS_STATUS 1022
+#define IDC_TRANS_LBL 1023
+#define IDC_STALL_STATUS 1024
+#define IDC_STALL_LBL 1025
+#define IDC_TRANSLATE 1026
+#define IDC_STATIC -1
+
+// Next default values for new objects
+//
+#ifdef APSTUDIO_INVOKED
+#ifndef APSTUDIO_READONLY_SYMBOLS
+#define _APS_NO_MFC 1
+#define _APS_NEXT_RESOURCE_VALUE 131
+#define _APS_NEXT_COMMAND_VALUE 32771
+#define _APS_NEXT_CONTROL_VALUE 1027
+#define _APS_NEXT_SYMED_VALUE 110
+#endif
+#endif
diff --git a/Examples/WhisperDesktop/TranscribeDlg.cpp b/Examples/WhisperDesktop/TranscribeDlg.cpp
new file mode 100644
index 0000000..14bec05
--- /dev/null
+++ b/Examples/WhisperDesktop/TranscribeDlg.cpp
@@ -0,0 +1,493 @@
+#include "stdafx.h"
+#include "TranscribeDlg.h"
+#include "Utils/logger.h"
+
+HRESULT TranscribeDlg::show()
+{
+ auto res = DoModal( nullptr );
+ if( res == -1 )
+ return HRESULT_FROM_WIN32( GetLastError() );
+ switch( res )
+ {
+ case IDC_BACK:
+ return SCREEN_MODEL;
+ case IDC_CAPTURE:
+ return SCREEN_CAPTURE;
+ }
+ return S_OK;
+}
+
+constexpr int progressMaxInteger = 1024 * 8;
+
+static const LPCTSTR regValInput = L"sourceMedia";
+static const LPCTSTR regValOutFormat = L"resultFormat";
+static const LPCTSTR regValOutPath = L"resultPath";
+
+LRESULT TranscribeDlg::OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled )
+{
+ // First DDX call, hooks up variables to controls.
+ DoDataExchange( false );
+ printModelDescription();
+ languageSelector.initialize( m_hWnd, IDC_LANGUAGE, appState );
+ cbConsole.initialize( m_hWnd, IDC_CONSOLE, appState );
+ cbTranslate.initialize( m_hWnd, IDC_TRANSLATE, appState );
+ populateOutputFormats();
+
+ pendingState.initialize(
+ {
+ languageSelector,
+ sourceMediaPath, GetDlgItem( IDC_BROWSE_MEDIA ),
+ transcribeOutFormat,
+ transcribeOutputPath, GetDlgItem( IDC_BROWSE_RESULT ),
+ GetDlgItem( IDC_TRANSCRIBE ),
+ GetDlgItem( IDCANCEL ),
+ GetDlgItem( IDC_BACK ),
+ GetDlgItem( IDC_CAPTURE )
+ },
+ {
+ progressBar, GetDlgItem( IDC_PENDING_TEXT )
+ } );
+
+ HRESULT hr = work.create( this );
+ if( FAILED( hr ) )
+ {
+ reportError( m_hWnd, L"CreateThreadpoolWork failed", nullptr, hr );
+ EndDialog( IDCANCEL );
+ }
+
+ progressBar.SetRange32( 0, progressMaxInteger );
+ progressBar.SetStep( 1 );
+
+ sourceMediaPath.SetWindowText( appState.stringLoad( regValInput ) );
+ transcribeOutFormat.SetCurSel( (int)appState.dwordLoad( regValOutFormat, 0 ) );
+ transcribeOutputPath.SetWindowText( appState.stringLoad( regValOutPath ) );
+ BOOL unused;
+ OnOutFormatChange( 0, 0, nullptr, unused );
+
+ appState.lastScreenSave( SCREEN_TRANSCRIBE );
+ appState.setupIcon( this );
+ ATLVERIFY( CenterWindow() );
+ return 0;
+}
+
+void TranscribeDlg::printModelDescription()
+{
+ CString text;
+ if( S_OK == appState.model->isMultilingual() )
+ text = L"Multilingual";
+ else
+ text = L"Single-language";
+ text += L" model \"";
+ LPCTSTR path = appState.source.path;
+ path = ::PathFindFileName( path );
+ text += path;
+ text += L"\", ";
+ const int64_t cb = appState.source.sizeInBytes;
+ if( cb < 1 << 30 )
+ {
+ constexpr double mul = 1.0 / ( 1 << 20 );
+ double mb = (double)cb * mul;
+ text.AppendFormat( L"%.1f MB", mb );
+ }
+ else
+ {
+ constexpr double mul = 1.0 / ( 1 << 30 );
+ double gb = (double)cb * mul;
+ text.AppendFormat( L"%.2f GB", gb );
+ }
+ text += L" on disk, ";
+ text += implString( appState.source.impl );
+ text += L" implementation";
+
+ modelDesc.SetWindowText( text );
+}
+
+void TranscribeDlg::populateOutputFormats()
+{
+ transcribeOutFormat.AddString( L"None" );
+ transcribeOutFormat.AddString( L"Text File" );
+ transcribeOutFormat.AddString( L"SubRip subtitles" );
+ transcribeOutFormat.AddString( L"WebVTT subtitles" );
+}
+
+enum struct TranscribeDlg::eOutputFormat : uint8_t
+{
+ None = 0,
+ Text = 1,
+ SubRip = 2,
+ WebVTT = 3
+};
+
+LRESULT TranscribeDlg::OnOutFormatChange( UINT, INT, HWND, BOOL& bHandled )
+{
+ const BOOL enabled = transcribeOutFormat.GetCurSel() != 0;
+ transcribeOutputPath.EnableWindow( enabled );
+ transcribeOutputBrowse.EnableWindow( enabled );
+ return 0;
+}
+
+void TranscribeDlg::onBrowseMedia()
+{
+ LPCTSTR title = L"Input audio file to transcribe";
+ LPCTSTR filters = L"Multimedia Files\0*.wav;*.wave;*.mp3;*.wma;*.mp4;*.mpeg4;*.mkv\0\0";
+
+ CString path;
+ sourceMediaPath.GetWindowText( path );
+ if( getOpenFileName( m_hWnd, title, filters, path ) )
+ sourceMediaPath.SetWindowText( path );
+}
+
+static const LPCTSTR outputFilters = L"Text files (*.txt)\0*.txt\0SubRip subtitles (*.srt)\0*.srt\0WebVTT subtitles (*.vtt)\0*.vtt\0\0";
+static const std::array<LPCTSTR, 3> outputExtensions =
+{
+ L".txt", L".srt", L".vtt"
+};
+
+void TranscribeDlg::onBrowseOutput()
+{
+ const DWORD origFilterIndex = (DWORD)transcribeOutFormat.GetCurSel() - 1;
+
+ LPCTSTR title = L"Output Text File";
+ CString path;
+ transcribeOutputPath.GetWindowText( path );
+ DWORD filterIndex = origFilterIndex;
+ if( !getSaveFileName( m_hWnd, title, outputFilters, path, &filterIndex ) )
+ return;
+
+ LPCTSTR ext = PathFindExtension( path );
+ if( 0 == *ext && filterIndex < outputExtensions.size() )
+ {
+ wchar_t* const buffer = path.GetBufferSetLength( path.GetLength() + 5 );
+ PathRenameExtension( buffer, outputExtensions[ filterIndex ] );
+ path.ReleaseBuffer();
+ }
+
+ transcribeOutputPath.SetWindowText( path );
+ if( filterIndex != origFilterIndex )
+ transcribeOutFormat.SetCurSel( filterIndex + 1 );
+}
+
+void TranscribeDlg::setPending( bool nowPending )
+{
+ pendingState.setPending( nowPending );
+}
+
+void TranscribeDlg::transcribeError( LPCTSTR text, HRESULT hr )
+{
+ reportError( m_hWnd, text, L"Unable to transcribe audio", hr );
+}
+
+void TranscribeDlg::onTranscribe()
+{
+ // Validate input
+ sourceMediaPath.GetWindowText( transcribeArgs.pathMedia );
+ if( transcribeArgs.pathMedia.GetLength() <= 0 )
+ {
+ transcribeError( L"Please select an input audio file" );
+ return;
+ }
+
+ if( !PathFileExists( transcribeArgs.pathMedia ) )
+ {
+ transcribeError( L"Input audio file does not exist", HRESULT_FROM_WIN32( ERROR_FILE_NOT_FOUND ) );
+ return;
+ }
+
+ transcribeArgs.language = languageSelector.selectedLanguage();
+ transcribeArgs.translate = cbTranslate.checked();
+ if( isInvalidTranslate( m_hWnd, transcribeArgs.language, transcribeArgs.translate ) )
+ return;
+
+ transcribeArgs.format = (eOutputFormat)(uint8_t)transcribeOutFormat.GetCurSel();
+ if( transcribeArgs.format != eOutputFormat::None )
+ {
+ transcribeOutputPath.GetWindowText( transcribeArgs.pathOutput );
+ if( transcribeArgs.pathOutput.GetLength() <= 0 )
+ {
+ transcribeError( L"Please select an output text file" );
+ return;
+ }
+ appState.stringStore( regValOutPath, transcribeArgs.pathOutput );
+ }
+ else
+ cbConsole.ensureChecked();
+
+ appState.dwordStore( regValOutFormat, (uint32_t)(int)transcribeArgs.format );
+ languageSelector.saveSelection( appState );
+ cbTranslate.saveSelection( appState );
+ appState.stringStore( regValInput, transcribeArgs.pathMedia );
+
+ setPending( true );
+
+ work.post();
+}
+
+void __stdcall TranscribeDlg::poolCallback() noexcept
+{
+ HRESULT hr = transcribe();
+ PostMessage( WM_CALLBACK_STATUS, (WPARAM)hr );
+}
+
+static void printTime( CString& rdi, int64_t ticks )
+{
+ const Whisper::sTimeSpan ts{ (uint64_t)ticks };
+ const Whisper::sTimeSpanFields fields = ts;
+
+ if( fields.days != 0 )
+ {
+ rdi.AppendFormat( L"%i days, %i hours", fields.days, (int)fields.hours );
+ return;
+ }
+ if( ( fields.hours | fields.minutes ) != 0 )
+ {
+ rdi.AppendFormat( L"%02d:%02d:%02d", (int)fields.hours, (int)fields.minutes, (int)fields.seconds );
+ return;
+ }
+ rdi.AppendFormat( L"%.3f seconds", (double)ticks / 1E7 );
+}
+
+LRESULT TranscribeDlg::onCallbackStatus( UINT, WPARAM wParam, LPARAM, BOOL& bHandled )
+{
+ setPending( false );
+ const HRESULT hr = (HRESULT)wParam;
+ if( FAILED( hr ) )
+ {
+ LPCTSTR failMessage = L"Transcribe failed";
+
+ if( transcribeArgs.errorMessage.GetLength() > 0 )
+ {
+ CString tmp = failMessage;
+ tmp += L"\n";
+ tmp += transcribeArgs.errorMessage;
+ transcribeError( tmp, hr );
+ }
+ else
+ transcribeError( failMessage, hr );
+
+ return 0;
+ }
+
+ const int64_t elapsed = ( GetTickCount64() - transcribeArgs.startTime ) * 10'000;
+ const int64_t media = transcribeArgs.mediaDuration;
+ CString message = L"Transcribed the audio\nMedia duration: ";
+ printTime( message, media );
+ message += L"\nProcessing time: ";
+ printTime( message, elapsed );
+ message += L"\nRelative processing speed: ";
+ double mul = (double)media / (double)elapsed;
+ message.AppendFormat( L"%g", mul );
+
+ MessageBox( message, L"Transcribe Completed", MB_OK | MB_ICONINFORMATION );
+ return 0;
+}
+
+void TranscribeDlg::getThreadError()
+{
+ getLastError( transcribeArgs.errorMessage );
+}
+
+#define CHECK_EX( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) { getThreadError(); return __hr; } }
+
+HRESULT TranscribeDlg::transcribe()
+{
+ transcribeArgs.startTime = GetTickCount64();
+ clearLastError();
+ transcribeArgs.errorMessage = L"";
+
+ using namespace Whisper;
+ CComPtr<iAudioReader> reader;
+
+ CHECK_EX( appState.mediaFoundation->openAudioFile( transcribeArgs.pathMedia, false, &reader ) );
+ CHECK_EX( reader->getDuration( transcribeArgs.mediaDuration ) );
+
+ const eOutputFormat format = transcribeArgs.format;
+ CAtlFile outputFile;
+ if( format != eOutputFormat::None )
+ CHECK( outputFile.Create( transcribeArgs.pathOutput, GENERIC_WRITE, 0, CREATE_ALWAYS ) );
+
+ transcribeArgs.resultFlags = eResultFlags::Timestamps | eResultFlags::Tokens;
+
+ CComPtr<iContext> context;
+ CHECK_EX( appState.model->createContext( &context ) );
+
+ sFullParams fullParams;
+ CHECK_EX( context->fullDefaultParams( eSamplingStrategy::Greedy, &fullParams ) );
+ fullParams.language = transcribeArgs.language;
+ fullParams.setFlag( eFullParamsFlags::Translate, transcribeArgs.translate );
+ fullParams.resetFlag( eFullParamsFlags::PrintRealtime );
+
+ fullParams.new_segment_callback_user_data = this;
+ fullParams.new_segment_callback = &newSegmentCallbackStatic;
+
+ // Setup the progress indication sink
+ sProgressSink progressSink{ &progressCallbackStatic, this };
+ // Run the transcribe
+ CHECK_EX( context->runStreamed( fullParams, progressSink, reader ) );
+
+ context->timingsPrint();
+
+ if( format == eOutputFormat::None )
+ return S_OK;
+
+ CComPtr<iTranscribeResult> result;
+ CHECK_EX( context->getResults( transcribeArgs.resultFlags, &result ) );
+
+ sTranscribeLength len;
+ CHECK_EX( result->getSize( len ) );
+ const sSegment* const segments = result->getSegments();
+
+ switch( format )
+ {
+ case eOutputFormat::Text:
+ return writeTextFile( segments, len.countSegments, outputFile );
+ case eOutputFormat::SubRip:
+ return writeSubRip( segments, len.countSegments, outputFile );
+ case eOutputFormat::WebVTT:
+ return writeWebVTT( segments, len.countSegments, outputFile );
+ default:
+ return E_FAIL;
+ }
+}
+
+#undef CHECK_EX
+
+inline HRESULT TranscribeDlg::progressCallback( double p ) noexcept
+{
+ constexpr double mul = progressMaxInteger;
+ int pos = lround( mul * p );
+ progressBar.PostMessage( PBM_SETPOS, pos, 0 );
+ return S_OK;
+}
+
+HRESULT __cdecl TranscribeDlg::progressCallbackStatic( double p, Whisper::iContext* ctx, void* pv ) noexcept
+{
+ TranscribeDlg* dlg = (TranscribeDlg*)pv;
+ return dlg->progressCallback( p );
+}
+
+namespace
+{
+ HRESULT write( CAtlFile& file, const CStringA& line )
+ {
+ if( line.GetLength() > 0 )
+ CHECK( file.Write( cstr( line ), (DWORD)line.GetLength() ) );
+ return S_OK;
+ }
+
+ void printTime( CStringA& rdi, Whisper::sTimeSpan time, bool comma )
+ {
+ Whisper::sTimeSpanFields fields = time;
+ const char separator = comma ? ',' : '.';
+ rdi.AppendFormat( "%02d:%02d:%02d%c%03d",
+ (int)fields.hours,
+ (int)fields.minutes,
+ (int)fields.seconds,
+ separator,
+ fields.ticks / 10'000 );
+ }
+
+ const char* skipBlank( const char* rsi )
+ {
+ while( true )
+ {
+ const char c = *rsi;
+ if( c == ' ' || c == '\t' )
+ {
+ rsi++;
+ continue;
+ }
+ return rsi;
+ }
+ }
+}
+
+using Whisper::sSegment;
+
+
+HRESULT TranscribeDlg::writeTextFile( const sSegment* const segments, const size_t length, CAtlFile& file )
+{
+ using namespace Whisper;
+ CHECK( writeUtf8Bom( file ) );
+ CStringA line;
+ for( size_t i = 0; i < length; i++ )
+ {
+ line = skipBlank( segments[ i ].text );
+ line += "\r\n";
+ CHECK( write( file, line ) );
+ }
+ return S_OK;
+}
+
+HRESULT TranscribeDlg::writeSubRip( const sSegment* const segments, const size_t length, CAtlFile& file )
+{
+ CHECK( writeUtf8Bom( file ) );
+ CStringA line;
+ for( size_t i = 0; i < length; i++ )
+ {
+ const sSegment& seg = segments[ i ];
+
+ line.Format( "%zu\r\n", i + 1 );
+ printTime( line, seg.time.begin, true );
+ line += " --> ";
+ printTime( line, seg.time.end, true );
+ line += "\r\n";
+ line += skipBlank( seg.text );
+ line += "\r\n\r\n";
+ CHECK( write( file, line ) );
+ }
+ return S_OK;
+}
+
+HRESULT TranscribeDlg::writeWebVTT( const sSegment* const segments, const size_t length, CAtlFile& file )
+{
+ CHECK( writeUtf8Bom( file ) );
+ CStringA line;
+ line = "WEBVTT\r\n\r\n";
+ CHECK( write( file, line ) );
+
+ for( size_t i = 0; i < length; i++ )
+ {
+ const sSegment& seg = segments[ i ];
+ line = "";
+
+ printTime( line, seg.time.begin, false );
+ line += " --> ";
+ printTime( line, seg.time.end, false );
+ line += "\r\n";
+ line += skipBlank( seg.text );
+ line += "\r\n\r\n";
+ CHECK( write( file, line ) );
+ }
+ return S_OK;
+}
+
+inline HRESULT TranscribeDlg::newSegmentCallback( Whisper::iContext* ctx, uint32_t n_new )
+{
+ using namespace Whisper;
+ CComPtr<iTranscribeResult> result;
+ CHECK( ctx->getResults( transcribeArgs.resultFlags, &result ) );
+ return logNewSegments( result, n_new );
+}
+
+HRESULT __cdecl TranscribeDlg::newSegmentCallbackStatic( Whisper::iContext* ctx, uint32_t n_new, void* user_data ) noexcept
+{
+ TranscribeDlg* dlg = (TranscribeDlg*)user_data;
+ return dlg->newSegmentCallback( ctx, n_new );
+}
+
+void TranscribeDlg::onWmClose()
+{
+ if( GetDlgItem( IDCANCEL ).IsWindowEnabled() )
+ {
+ EndDialog( IDCANCEL );
+ return;
+ }
+
+ constexpr UINT flags = MB_YESNO | MB_ICONQUESTION | MB_DEFBUTTON2;
+ const int res = this->MessageBox( L"Transcribe is in progress.\nDo you want to quit anyway?", L"Confirm exit", flags );
+ if( res != IDYES )
+ return;
+
+ // TODO: instead of ExitProcess(), implement another callback in the DLL API, for proper cancellation of the background task
+ ExitProcess( 1 );
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/TranscribeDlg.h b/Examples/WhisperDesktop/TranscribeDlg.h
new file mode 100644
index 0000000..354e002
--- /dev/null
+++ b/Examples/WhisperDesktop/TranscribeDlg.h
@@ -0,0 +1,124 @@
+#pragma once
+#include "AppState.h"
+#include "Utils/WTL/atlddx.h"
+#include "Utils/WTL/atlcrack.h"
+#include "Utils/miscUtils.h"
+#include "Utils/LanguageDropdown.h"
+#include "Utils/TranslateCheckbox.h"
+#include "Utils/PendingState.h"
+
+class TranscribeDlg :
+ public CDialogImpl<TranscribeDlg>,
+ public CWinDataExchange<TranscribeDlg>,
+ public iThreadPoolCallback
+{
+ AppState& appState;
+
+public:
+ static constexpr UINT IDD = IDD_TRANSCRIBE_DIALOG;
+ static constexpr UINT WM_CALLBACK_STATUS = WM_APP + 1;
+
+ TranscribeDlg( AppState& app ) : appState( app ) { }
+
+ // Show this dialog modally, without parent.
+ HRESULT show();
+
+ BEGIN_MSG_MAP( LoadModelDlg )
+ MESSAGE_HANDLER( WM_INITDIALOG, OnInitDialog )
+ ON_BUTTON_CLICK( IDC_CONSOLE, cbConsole.click )
+ ON_BUTTON_CLICK( IDCANCEL, onClose )
+ ON_BUTTON_CLICK( IDC_BACK, onBack )
+ ON_BUTTON_CLICK( IDC_BROWSE_MEDIA, onBrowseMedia )
+ ON_BUTTON_CLICK( IDC_BROWSE_RESULT, onBrowseOutput )
+ ON_BUTTON_CLICK( IDC_TRANSCRIBE, onTranscribe )
+ ON_BUTTON_CLICK( IDC_CAPTURE, onCapture );
+ COMMAND_HANDLER( IDC_OUTPUT_FORMAT, CBN_SELCHANGE, OnOutFormatChange )
+ MESSAGE_HANDLER( WM_CALLBACK_STATUS, onCallbackStatus )
+ MSG_WM_CLOSE( onWmClose )
+ END_MSG_MAP()
+
+ BEGIN_DDX_MAP( LoadModelDlg )
+ DDX_CONTROL_HANDLE( IDC_MODEL_DESC, modelDesc )
+ DDX_CONTROL_HANDLE( IDC_PATH_MEDIA, sourceMediaPath )
+ DDX_CONTROL_HANDLE( IDC_OUTPUT_FORMAT, transcribeOutFormat )
+ DDX_CONTROL_HANDLE( IDC_PATH_RESULT, transcribeOutputPath )
+ DDX_CONTROL_HANDLE( IDC_BROWSE_RESULT, transcribeOutputBrowse );
+ DDX_CONTROL_HANDLE( IDC_TRANSCRIBE_PROGRESS, progressBar );
+ END_DDX_MAP()
+
+private:
+ PendingState pendingState;
+ void setPending( bool nowPending );
+ void transcribeError( LPCTSTR text, HRESULT hr = S_FALSE );
+
+ LRESULT OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled );
+
+ void onClose()
+ {
+ ATLVERIFY( EndDialog( IDCANCEL ) );
+ }
+ void onBack()
+ {
+ ATLVERIFY( EndDialog( IDC_BACK ) );
+ }
+
+ void printModelDescription();
+ CStatic modelDesc;
+ ConsoleCheckbox cbConsole;
+
+ LanguageDropdown languageSelector;
+ TranslateCheckbox cbTranslate;
+
+ CEdit sourceMediaPath;
+ CEdit transcribeOutputPath;
+ CButton transcribeOutputBrowse;
+ CComboBox transcribeOutFormat;
+ CProgressBarCtrl progressBar;
+ void populateOutputFormats();
+
+ LRESULT OnOutFormatChange( UINT, INT, HWND, BOOL& bHandled );
+ void onBrowseMedia();
+ void onBrowseOutput();
+ void onTranscribe();
+ void onCapture()
+ {
+ EndDialog( IDC_CAPTURE );
+ }
+
+ ThreadPoolWork work;
+
+ enum struct eOutputFormat : uint8_t;
+
+ struct TranscribeArgs
+ {
+ CString pathMedia;
+ CString pathOutput;
+ uint32_t language;
+ bool translate;
+ eOutputFormat format;
+ Whisper::eResultFlags resultFlags;
+ uint64_t startTime;
+ int64_t mediaDuration;
+ CString errorMessage;
+ };
+ TranscribeArgs transcribeArgs;
+
+ void __stdcall poolCallback() noexcept override final;
+
+ LRESULT onCallbackStatus( UINT, WPARAM wParam, LPARAM, BOOL& bHandled );
+
+ HRESULT transcribe();
+ void getThreadError();
+
+ static HRESULT writeTextFile( const Whisper::sSegment* const segments, const size_t length, CAtlFile& file );
+ static HRESULT writeSubRip( const Whisper::sSegment* const segments, const size_t length, CAtlFile& file );
+ static HRESULT writeWebVTT( const Whisper::sSegment* const segments, const size_t length, CAtlFile& file );
+
+ static HRESULT __cdecl newSegmentCallbackStatic( Whisper::iContext* ctx, uint32_t n_new, void* user_data ) noexcept;
+ HRESULT newSegmentCallback( Whisper::iContext* ctx, uint32_t n_new );
+
+ static HRESULT __cdecl progressCallbackStatic( double p, Whisper::iContext* ctx, void* pv ) noexcept;
+ HRESULT progressCallback( double p ) noexcept;
+
+ void onWmClose();
+}; \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/DebugConsole.cpp b/Examples/WhisperDesktop/Utils/DebugConsole.cpp
new file mode 100644
index 0000000..640efb0
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/DebugConsole.cpp
@@ -0,0 +1,289 @@
+// https://github.com/Const-me/vis_avs_dx/blob/master/avs_dx/DxVisuals/Interop/ConsoleLogger.cpp
+#include "stdafx.h"
+#include "DebugConsole.h"
+#include "miscUtils.h"
+#include "../AppState.h"
+#include "logger.h"
+
+namespace
+{
+ using Whisper::eLogLevel;
+
+ constexpr uint16_t defaultAttributes = FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE;
+
+ inline uint16_t textAttributes( eLogLevel lvl )
+ {
+ switch( lvl )
+ {
+ case eLogLevel::Error:
+ return FOREGROUND_RED | FOREGROUND_INTENSITY;
+ case eLogLevel::Warning:
+ return FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_INTENSITY;
+ case eLogLevel::Info:
+ return FOREGROUND_GREEN | FOREGROUND_INTENSITY;
+ case eLogLevel::Debug:
+ return FOREGROUND_BLUE | FOREGROUND_INTENSITY;
+ }
+ return defaultAttributes;
+ }
+
+ // Background stuff: accumulate messages in a small buffer, in case user will want to see them in the console.
+ // Ideally, we should accumulate them in a more efficient data structure, maybe a circular buffer.
+ // However, we don't have that many messages per second, this simple solution that uses std::deque is probably good enough for the job.
+ static constexpr uint16_t bufferSize = 64;
+
+ using Lock = CComCritSecLock<CComAutoCriticalSection>;
+#define LOCK() Lock __lock{ critSec }
+
+ thread_local CStringA threadError;
+}
+
+HRESULT DebugConsole::Entry::print( HANDLE hConsole, CString& tempString ) const
+{
+ if( !SetConsoleTextAttribute( hConsole, textAttributes( level ) ) )
+ return getLastHr();
+
+ makeUtf16( tempString, message );
+ tempString += L"\r\n";
+ if( !WriteConsoleW( hConsole, tempString, (DWORD)tempString.GetLength(), nullptr, nullptr ) )
+ return getLastHr();
+ return S_OK;
+}
+
+void clearLastError()
+{
+ threadError = "";
+}
+
+bool getLastError( CString& rdi )
+{
+ if( threadError.GetLength() <= 0 )
+ {
+ rdi = L"";
+ return false;
+ }
+ else
+ {
+ makeUtf16( rdi, threadError );
+ threadError = "";
+ return true;
+ }
+}
+
+inline void DebugConsole::logSink( eLogLevel lvl, const char* message )
+{
+ LOCK();
+
+ // Add to the buffer
+ while( buffer.size() >= bufferSize )
+ buffer.pop_front();
+ buffer.emplace_back( Entry{ lvl, message } );
+
+ // If the console window is shown, print there, too.
+ if( output )
+ buffer.rbegin()->print( output, tempString );
+}
+
+void __stdcall DebugConsole::logSinkStatic( void* context, eLogLevel lvl, const char* message )
+{
+ if( lvl == eLogLevel::Error )
+ threadError = message;
+
+ DebugConsole* con = (DebugConsole*)context;
+ con->logSink( lvl, message );
+}
+
+HRESULT DebugConsole::initialize( Whisper::eLogLevel level )
+{
+ if( nullptr != pGlobalInstance )
+ return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
+ pGlobalInstance = this;
+
+ Whisper::sLoggerSetup setup;
+ setup.sink = &logSinkStatic;
+ setup.context = this;
+ setup.level = level;
+ setup.flags = Whisper::eLoggerFlags::SkipFormatMessage;
+ return Whisper::setupLogger( setup );
+}
+
+DebugConsole::~DebugConsole()
+{
+ hide();
+
+ Whisper::sLoggerSetup setup;
+ memset( &setup, 0, sizeof( setup ) );
+ Whisper::setupLogger( setup );
+
+ pGlobalInstance = nullptr;
+}
+
+DebugConsole* DebugConsole::pGlobalInstance = nullptr;
+
+void DebugConsole::windowClosed()
+{
+ LOCK();
+ if( FreeConsole() )
+ {
+ // Apparently, FreeConsole already closes that handle: https://stackoverflow.com/q/12676312/126995
+ output.Detach();
+ }
+ output.Close();
+
+ for( CButton* b : checkboxes )
+ {
+ if( !*b )
+ continue;
+ if( !b->IsWindow() )
+ continue;
+ PostMessage( *b, BM_SETCHECK, BST_UNCHECKED, 0 );
+ }
+}
+
+BOOL __stdcall DebugConsole::consoleHandlerRoutine( DWORD dwCtrlType )
+{
+ switch( dwCtrlType )
+ {
+ case CTRL_CLOSE_EVENT:
+ case CTRL_C_EVENT:
+ case CTRL_BREAK_EVENT:
+ pGlobalInstance->windowClosed();
+ return TRUE;
+ }
+ return TRUE;
+}
+
+HRESULT DebugConsole::show()
+{
+ HWND hWnd = GetConsoleWindow();
+ if( nullptr != hWnd )
+ {
+ ShowWindow( hWnd, SW_RESTORE );
+ SetForegroundWindow( hWnd );
+ return S_FALSE;
+ }
+
+ if( !AllocConsole() )
+ return getLastHr();
+
+ output.Close();
+ output.Attach( GetStdHandle( STD_OUTPUT_HANDLE ) );
+ if( !output )
+ return getLastHr();
+
+ constexpr UINT cp = CP_UTF8;
+ if( IsValidCodePage( cp ) )
+ SetConsoleOutputCP( cp );
+
+ // Enable ANSI color coding
+ DWORD mode = 0;
+ if( !GetConsoleMode( output, &mode ) )
+ return getLastHr();
+ if( 0 == ( mode & ENABLE_VIRTUAL_TERMINAL_PROCESSING ) )
+ {
+ mode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING;
+ if( !SetConsoleMode( output, mode ) )
+ return getLastHr();
+ }
+
+ SetConsoleTitle( L"Whisper Desktop Debug Console" );
+
+ SetConsoleCtrlHandler( &consoleHandlerRoutine, TRUE );
+
+ // Disable close command in the sys.menu of the new console, otherwise the whole process will quit: https://stackoverflow.com/a/12015131/126995
+ HWND hwnd = ::GetConsoleWindow();
+ if( hwnd != nullptr )
+ {
+ HMENU hMenu = ::GetSystemMenu( hwnd, FALSE );
+ if( hMenu != NULL )
+ DeleteMenu( hMenu, SC_CLOSE, MF_BYCOMMAND );
+ }
+
+ // Print old log entries
+ for( const auto& e : buffer )
+ CHECK( e.print( output, tempString ) );
+
+ const CStringA msg = "Press Control+C or Control+Break to close this window\r\n";
+ if( !SetConsoleTextAttribute( output, FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE | FOREGROUND_INTENSITY ) )
+ return getLastHr();
+ if( !WriteConsoleA( output, cstr( msg ), msg.GetLength(), nullptr, nullptr ) )
+ return getLastHr();
+
+ return S_OK;
+}
+
+HRESULT DebugConsole::hide()
+{
+ LOCK();
+ if( !output )
+ return S_FALSE;
+ windowClosed();
+ return S_OK;
+}
+
+void DebugConsole::addCheckbox( CButton& cb )
+{
+ checkboxes.emplace( &cb );
+}
+void DebugConsole::removeCheckbox( CButton& cb )
+{
+ checkboxes.erase( &cb );
+}
+
+HRESULT ConsoleCheckbox::initialize( HWND dialog, int idc, AppState& state )
+{
+ control = GetDlgItem( dialog, idc );
+ assert( control );
+
+ console = &state.console;
+ if( state.console.isVisible() )
+ control.SetCheck( BST_CHECKED );
+
+ state.console.addCheckbox( control );
+ return S_OK;
+}
+
+void ConsoleCheckbox::click()
+{
+ const int state = control.GetCheck();
+ if( state == BST_CHECKED )
+ console->show();
+ else
+ console->hide();
+}
+
+void ConsoleCheckbox::ensureChecked()
+{
+ const int state = control.GetCheck();
+ if( state == BST_CHECKED )
+ return;
+ control.SetCheck( BST_CHECKED );
+ console->show();
+}
+
+void DebugConsole::log( eLogLevel lvl, const char* pszFormat, va_list args )
+{
+ LOCK();
+ // Add to the buffer
+ while( buffer.size() >= bufferSize )
+ buffer.pop_front();
+
+ tempStringA.FormatV( pszFormat, args );
+ buffer.emplace_back( Entry{ lvl, tempStringA } );
+
+ // If the console window is shown, print there, too.
+ if( output )
+ buffer.rbegin()->print( output, tempString );
+}
+
+void DebugConsole::logMessage( eLogLevel lvl, const char* pszFormat, va_list args )
+{
+ if( nullptr == pGlobalInstance )
+ return;
+ pGlobalInstance->log( lvl, pszFormat, args );
+}
+
+void logMessage( Whisper::eLogLevel lvl, const char8_t* pczFormat, va_list args )
+{
+ DebugConsole::logMessage( lvl, (const char*)pczFormat, args );
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/DebugConsole.h b/Examples/WhisperDesktop/Utils/DebugConsole.h
new file mode 100644
index 0000000..a9ee8f2
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/DebugConsole.h
@@ -0,0 +1,64 @@
+#pragma once
+#include <whisperWindows.h>
+#include <deque>
+#include <unordered_set>
+
+class AppState;
+class DebugConsole
+{
+ using eLogLevel = Whisper::eLogLevel;
+
+ struct Entry
+ {
+ eLogLevel level;
+ CStringA message;
+ HRESULT print( HANDLE hConsole, CString& tempString ) const;
+ };
+
+ CComAutoCriticalSection critSec;
+ std::deque<Entry> buffer;
+ CString tempString;
+ CHandle output;
+
+ inline void logSink( eLogLevel lvl, const char* message );
+ static void __stdcall logSinkStatic( void* context, eLogLevel lvl, const char* message );
+
+ static BOOL __stdcall consoleHandlerRoutine( DWORD dwCtrlType );
+
+ static DebugConsole* pGlobalInstance;
+ void windowClosed();
+
+ std::unordered_set<CButton*> checkboxes;
+
+ CStringA tempStringA;
+ void log( eLogLevel lvl, const char* pszFormat, va_list args );
+
+public:
+ HRESULT initialize( eLogLevel level = eLogLevel::Debug );
+ ~DebugConsole();
+
+ HRESULT show();
+ HRESULT hide();
+ bool isVisible() const { return output; }
+
+ void addCheckbox( CButton& cb );
+ void removeCheckbox( CButton& cb );
+
+ static void logMessage( eLogLevel lvl, const char* pszFormat, va_list args );
+};
+
+class ConsoleCheckbox
+{
+ CButton control;
+ DebugConsole* console = nullptr;
+
+public:
+ HRESULT initialize( HWND dialog, int idc, AppState& state );
+ void click();
+ ~ConsoleCheckbox()
+ {
+ if( nullptr != console )
+ console->removeCheckbox( control );
+ }
+ void ensureChecked();
+}; \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/LanguageDropdown.cpp b/Examples/WhisperDesktop/Utils/LanguageDropdown.cpp
new file mode 100644
index 0000000..36cd1a8
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/LanguageDropdown.cpp
@@ -0,0 +1,87 @@
+#include "stdafx.h"
+#include "LanguageDropdown.h"
+#include "miscUtils.h"
+
+namespace
+{
+ inline wchar_t toUpper( wchar_t c )
+ {
+ size_t st = (uint16_t)c;
+ st = reinterpret_cast<size_t>( CharUpperW( reinterpret_cast<LPWSTR>( st ) ) );
+ return (wchar_t)(uint16_t)st;
+ }
+
+ void makeTitleCase( CString& s )
+ {
+ bool cap = true;
+ for( int i = 0; i < s.GetLength(); i++ )
+ {
+ wchar_t c = s[ i ];
+ if( cap )
+ {
+ c = toUpper( c );
+ s.SetAt( i, c );
+ }
+ cap = false;
+ if( c == ' ' )
+ cap = true;
+ }
+ }
+}
+
+int LanguageDropdown::getInitialSelection( AppState& state ) const
+{
+ constexpr uint32_t english = 0x6E65;
+
+ // Load preference from the registry
+ uint32_t id = state.languageRead();
+ if( id == UINT_MAX )
+ id = english;
+
+ auto it = std::find( keys.begin(), keys.end(), id );
+ if( it == keys.end() )
+ {
+ id = english;
+ it = std::find( keys.begin(), keys.end(), id );
+ assert( it != keys.end() );
+ }
+
+ ptrdiff_t idx = it - keys.begin();
+ return (int)idx;
+}
+
+void LanguageDropdown::initialize( HWND owner, int idc, AppState& state )
+{
+ m_hWnd = GetDlgItem( owner, idc );
+ assert( nullptr != m_hWnd );
+
+ Whisper::sLanguageList list;
+ Whisper::getSupportedLanguages( list );
+
+ const size_t length = list.length;
+ keys.resize( length );
+ CString utf16;
+ for( size_t i = 0; i < length; i++ )
+ {
+ keys[ i ] = list.pointer[ i ].key;
+ makeUtf16( utf16, list.pointer[ i ].name );
+ makeTitleCase( utf16 );
+ SendMessage( m_hWnd, CB_ADDSTRING, 0, (LPARAM)cstr( utf16 ) );
+ }
+
+ const int curSel = getInitialSelection( state );
+ SendMessage( m_hWnd, CB_SETCURSEL, curSel, 0 );
+}
+
+uint32_t LanguageDropdown::selectedLanguage()
+{
+ const int cs = (int)SendMessage( m_hWnd, CB_GETCURSEL, 0, 0 );
+ if( cs < 0 || cs >= keys.size() )
+ return UINT_MAX;
+ return keys[ cs ];
+}
+
+void LanguageDropdown::saveSelection( AppState& state )
+{
+ state.languageWrite( selectedLanguage() );
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/LanguageDropdown.h b/Examples/WhisperDesktop/Utils/LanguageDropdown.h
new file mode 100644
index 0000000..640b81e
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/LanguageDropdown.h
@@ -0,0 +1,26 @@
+#pragma once
+#include "../AppState.h"
+
+// Dropdown list which implements language selector control
+class LanguageDropdown
+{
+ HWND m_hWnd = nullptr;
+ std::vector<uint32_t> keys;
+ int getInitialSelection( AppState& state ) const;
+
+public:
+ operator HWND() const
+ {
+ return m_hWnd;
+ }
+
+ // Query language list form the native library, populate the combo box
+ // Then load the last saved language selection from registry, and preselect an item.
+ void initialize( HWND owner, int idc, AppState& state );
+
+ // Get the ID of the currently selected language, or UINT_MAX if nothing's selected
+ uint32_t selectedLanguage();
+
+ // Get the ID of the currently selected language, and store in registry
+ void saveSelection( AppState& state );
+}; \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/PendingState.cpp b/Examples/WhisperDesktop/Utils/PendingState.cpp
new file mode 100644
index 0000000..404ae4e
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/PendingState.cpp
@@ -0,0 +1,40 @@
+#include "stdafx.h"
+#include "PendingState.h"
+
+void PendingState::initialize( std::initializer_list<HWND> editors, std::initializer_list<HWND> pending )
+{
+ editorsWindows = editors;
+ wasEnabled.resize( editorsWindows.size() );
+ pendingWindows = pending;
+}
+
+void PendingState::setPending( bool nowPending )
+{
+ if( nowPending )
+ {
+ for( size_t i = 0; i < editorsWindows.size(); i++ )
+ {
+ BOOL e = IsWindowEnabled( editorsWindows[ i ] );
+ if( e )
+ {
+ wasEnabled[ i ] = true;
+ EnableWindow( editorsWindows[ i ], FALSE );
+ }
+ else
+ wasEnabled[ i ] = false;
+ }
+ }
+ else
+ {
+ for( size_t i = 0; i < editorsWindows.size(); i++ )
+ {
+ if( !wasEnabled[ i ] )
+ continue;
+ EnableWindow( editorsWindows[ i ], TRUE );
+ }
+ }
+
+ const int show = nowPending ? SW_NORMAL : SW_HIDE;
+ for( HWND w : pendingWindows )
+ ::ShowWindow( w, show );
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/PendingState.h b/Examples/WhisperDesktop/Utils/PendingState.h
new file mode 100644
index 0000000..0b34e13
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/PendingState.h
@@ -0,0 +1,12 @@
+#pragma once
+
+// Utility class to switch visual state of dialog controls between idle and pending
+class PendingState
+{
+ std::vector<HWND> editorsWindows;
+ std::vector<bool> wasEnabled;
+ std::vector<HWND> pendingWindows;
+public:
+ void initialize( std::initializer_list<HWND> editors, std::initializer_list<HWND> pending );
+ void setPending( bool nowPending );
+}; \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/TranslateCheckbox.cpp b/Examples/WhisperDesktop/Utils/TranslateCheckbox.cpp
new file mode 100644
index 0000000..c5e6ac0
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/TranslateCheckbox.cpp
@@ -0,0 +1,25 @@
+#include "stdafx.h"
+#include "TranslateCheckbox.h"
+
+static const LPCTSTR regValTranslate = L"translate";
+
+void TranslateCheckbox::initialize( HWND owner, int idc, AppState& state )
+{
+ m_hWnd = GetDlgItem( owner, idc );
+ assert( nullptr != m_hWnd );
+
+ if( state.dwordLoad( regValTranslate, 0 ) != 0 )
+ ::SendMessage( m_hWnd, BM_SETCHECK, BST_CHECKED, 0L );
+}
+
+bool TranslateCheckbox::checked()
+{
+ assert( nullptr != m_hWnd );
+ const int state = ( int )::SendMessage( m_hWnd, BM_GETCHECK, 0, 0 );
+ return state == BST_CHECKED;
+}
+
+void TranslateCheckbox::saveSelection( AppState& state )
+{
+ state.dwordStore( regValTranslate, checked() ? TRUE : FALSE );
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/TranslateCheckbox.h b/Examples/WhisperDesktop/Utils/TranslateCheckbox.h
new file mode 100644
index 0000000..2b1db12
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/TranslateCheckbox.h
@@ -0,0 +1,18 @@
+#pragma once
+#include "../AppState.h"
+
+class TranslateCheckbox
+{
+ HWND m_hWnd = nullptr;
+public:
+ operator HWND() const
+ {
+ return m_hWnd;
+ }
+
+ void initialize( HWND owner, int idc, AppState& state );
+
+ bool checked();
+
+ void saveSelection( AppState& state );
+}; \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/WTL/atlapp.h b/Examples/WhisperDesktop/Utils/WTL/atlapp.h
new file mode 100644
index 0000000..8be6edc
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/WTL/atlapp.h
@@ -0,0 +1,1225 @@
+// Windows Template Library - WTL version 10.0
+// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved.
+//
+// This file is a part of the Windows Template Library.
+// The use and distribution terms for this software are covered by the
+// Microsoft Public License (http://opensource.org/licenses/MS-PL)
+// which can be found in the file MS-PL.txt at the root folder.
+
+#ifndef __ATLAPP_H__
+#define __ATLAPP_H__
+
+#pragma once
+
+#ifndef __cplusplus
+ #error WTL requires C++ compilation (use a .cpp suffix)
+#endif
+
+#ifndef __ATLBASE_H__
+ #error atlapp.h requires atlbase.h to be included first
+#endif
+
+#ifdef _WIN32_WCE
+ #error WTL10 doesn't support Windows CE
+#endif
+
+#ifdef _ATL_NO_COMMODULE
+ #error WTL doesn't support _ATL_NO_COMMODULE
+#endif
+
+#ifdef _ATL_NO_WIN_SUPPORT
+ #error WTL doesn't support _ATL_NO_WIN_SUPPORT
+#endif
+
+#if (_MSC_VER < 1400)
+ #error WTL10 requires C++ compiler version 14 (Visual C++ 2005) or higher
+#endif
+
+#if (WINVER < 0x0501)
+ #error WTL requires WINVER >= 0x0501
+#endif
+
+#if (_WIN32_WINNT < 0x0501)
+ #error WTL requires _WIN32_WINNT >= 0x0501
+#endif
+
+#if (_WIN32_IE < 0x0600)
+ #error WTL requires _WIN32_IE >= 0x0600
+#endif
+
+#if (_ATL_VER < 0x0800)
+ #error WTL10 requires ATL version 8 or higher
+#endif
+
+#ifdef _ATL_MIN_CRT
+ #error WTL10 doesn't support _ATL_MIN_CRT
+#endif
+
+#ifdef _ATL_NO_MSIMG
+ #error WTL10 doesn't support _ATL_NO_MSIMG
+#endif
+
+#include <limits.h>
+#ifdef _MT
+ #include <process.h> // for _beginthreadex
+#endif
+
+#include <commctrl.h>
+#pragma comment(lib, "comctl32.lib")
+
+#include <commdlg.h>
+#include <shellapi.h>
+
+// Check for VS2005 without newer WinSDK
+#if (_MSC_VER == 1400) && !defined(RB_GETEXTENDEDSTYLE)
+ #error WTL10 requires WinSDK 6.0 ot higher
+#endif
+
+#include <uxtheme.h>
+#pragma comment(lib, "uxtheme.lib")
+
+#if defined(_SYSINFOAPI_H_) && defined(NOT_BUILD_WINDOWS_DEPRECATE)
+ #include <VersionHelpers.h>
+#endif
+
+#include "atlres.h"
+
+
+///////////////////////////////////////////////////////////////////////////////
+// WTL version number
+
+#define _WTL_VER 0x1000 // version 10.0
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Classes in this file:
+//
+// CMessageFilter
+// CIdleHandler
+// CMessageLoop
+//
+// CAppModule
+// CServerAppModule
+//
+// Global functions:
+// AtlInitCommonControls()
+// AtlGetDefaultGuiFont()
+// AtlCreateControlFont()
+// AtlCreateBoldFont()
+// AtlGetStringPtr()
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Miscellaneous global support
+
+// define useful macros from winuser.h
+#ifndef IS_INTRESOURCE
+ #define IS_INTRESOURCE(_r) (((ULONG_PTR)(_r) >> 16) == 0)
+#endif // IS_INTRESOURCE
+
+// protect template members from windowsx.h macros
+#ifdef _INC_WINDOWSX
+ #undef SubclassWindow
+#endif // _INC_WINDOWSX
+
+// define useful macros from windowsx.h
+#ifndef GET_X_LPARAM
+ #define GET_X_LPARAM(lParam) ((int)(short)LOWORD(lParam))
+#endif
+#ifndef GET_Y_LPARAM
+ #define GET_Y_LPARAM(lParam) ((int)(short)HIWORD(lParam))
+#endif
+
+// Dummy structs for compiling with /CLR
+#ifdef _MANAGED
+ __if_not_exists(_IMAGELIST::_IMAGELIST) { struct _IMAGELIST { }; }
+ __if_not_exists(_TREEITEM::_TREEITEM) { struct _TREEITEM { }; }
+ __if_not_exists(_PSP::_PSP) { struct _PSP { }; }
+#endif
+
+// Forward declaration for ATL11 fix
+#if (_ATL_VER >= 0x0B00)
+ namespace ATL { HRESULT AtlGetCommCtrlVersion(LPDWORD pdwMajor, LPDWORD pdwMinor); }
+#endif
+
+#ifndef WM_MOUSEHWHEEL
+ #define WM_MOUSEHWHEEL 0x020E
+#endif
+
+// Used for stack allocations with ATL::CTempBuffer
+#ifndef _WTL_STACK_ALLOC_THRESHOLD
+ #define _WTL_STACK_ALLOC_THRESHOLD 512
+#endif
+
+
+namespace WTL
+{
+
+DECLARE_TRACE_CATEGORY(atlTraceUI)
+#ifdef _DEBUG
+ __declspec(selectany) ATL::CTraceCategory atlTraceUI(_T("atlTraceUI"));
+#endif // _DEBUG
+
+// Common Controls initialization helper
+inline BOOL AtlInitCommonControls(DWORD dwFlags)
+{
+ INITCOMMONCONTROLSEX iccx = { sizeof(INITCOMMONCONTROLSEX), dwFlags };
+ BOOL bRet = ::InitCommonControlsEx(&iccx);
+ ATLASSERT(bRet);
+ return bRet;
+}
+
+// Default GUI font helper - "MS Shell Dlg" stock font
+inline HFONT AtlGetDefaultGuiFont()
+{
+ return (HFONT)::GetStockObject(DEFAULT_GUI_FONT);
+}
+
+// Control font helper - default font for controls not in a dialog
+// (NOTE: Caller owns the font, and should destroy it when it's no longer needed)
+inline HFONT AtlCreateControlFont()
+{
+ LOGFONT lf = {};
+ ATLVERIFY(::SystemParametersInfo(SPI_GETICONTITLELOGFONT, sizeof(LOGFONT), &lf, 0) != FALSE);
+ HFONT hFont = ::CreateFontIndirect(&lf);
+ ATLASSERT(hFont != NULL);
+ return hFont;
+}
+
+// Bold font helper
+// (NOTE: Caller owns the font, and should destroy it when it's no longer needed)
+inline HFONT AtlCreateBoldFont(HFONT hFont = NULL)
+{
+ LOGFONT lf = {};
+ if(hFont == NULL)
+ ATLVERIFY(::SystemParametersInfo(SPI_GETICONTITLELOGFONT, sizeof(LOGFONT), &lf, 0) != FALSE);
+ else
+ ATLVERIFY(::GetObject(hFont, sizeof(LOGFONT), &lf) == sizeof(LOGFONT));
+ lf.lfWeight = FW_BOLD;
+ HFONT hFontBold = ::CreateFontIndirect(&lf);
+ ATLASSERT(hFontBold != NULL);
+ return hFontBold;
+}
+
+// Resource string pointer
+inline LPCWSTR AtlGetStringPtr(UINT uID, int* pch = NULL)
+{
+ LPCWSTR lpstr = NULL;
+ int nRet = ::LoadStringW(ATL::_AtlBaseModule.GetResourceInstance(), uID, (LPWSTR)&lpstr, 0);
+ if(pch != NULL)
+ *pch = nRet;
+ return lpstr;
+}
+
+
+///////////////////////////////////////////////////////////////////////////////
+// RunTimeHelper - helper functions for Windows version and structure sizes
+
+#ifndef _WTL_NO_RUNTIME_STRUCT_SIZE
+
+#ifndef _SIZEOF_STRUCT
+ #define _SIZEOF_STRUCT(structname, member) (((int)((LPBYTE)(&((structname*)0)->member) - ((LPBYTE)((structname*)0)))) + sizeof(((structname*)0)->member))
+#endif
+
+#if (_WIN32_WINNT >= 0x0600) && !defined(REBARBANDINFO_V6_SIZE)
+ #define REBARBANDINFO_V6_SIZE _SIZEOF_STRUCT(REBARBANDINFO, cxHeader)
+#endif // (_WIN32_WINNT >= 0x0600) && !defined(REBARBANDINFO_V6_SIZE)
+
+#if (_WIN32_WINNT >= 0x0600) && !defined(LVGROUP_V5_SIZE)
+ #define LVGROUP_V5_SIZE _SIZEOF_STRUCT(LVGROUP, uAlign)
+#endif // (_WIN32_WINNT >= 0x0600) && !defined(LVGROUP_V5_SIZE)
+
+#if (_WIN32_WINNT >= 0x0600) && !defined(LVTILEINFO_V5_SIZE)
+ #define LVTILEINFO_V5_SIZE _SIZEOF_STRUCT(LVTILEINFO, puColumns)
+#endif // (_WIN32_WINNT >= 0x0600) && !defined(LVTILEINFO_V5_SIZE)
+
+#if defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) && !defined(MCHITTESTINFO_V1_SIZE)
+ #define MCHITTESTINFO_V1_SIZE _SIZEOF_STRUCT(MCHITTESTINFO, st)
+#endif // defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN) && !defined(MCHITTESTINFO_V1_SIZE)
+
+#if (WINVER >= 0x0600) && !defined(NONCLIENTMETRICS_V1_SIZE)
+ #define NONCLIENTMETRICS_V1_SIZE _SIZEOF_STRUCT(NONCLIENTMETRICS, lfMessageFont)
+#endif // (WINVER >= 0x0600) && !defined(NONCLIENTMETRICS_V1_SIZE)
+
+#ifndef TTTOOLINFO_V2_SIZE
+ #define TTTOOLINFO_V2_SIZE _SIZEOF_STRUCT(TTTOOLINFO, lParam)
+#endif
+
+#endif // !_WTL_NO_RUNTIME_STRUCT_SIZE
+
+namespace RunTimeHelper
+{
+ inline bool IsCommCtrl6()
+ {
+ DWORD dwMajor = 0, dwMinor = 0;
+ HRESULT hRet = ATL::AtlGetCommCtrlVersion(&dwMajor, &dwMinor);
+ return (SUCCEEDED(hRet) && (dwMajor >= 6));
+ }
+
+ inline bool IsVista()
+ {
+#ifdef _versionhelpers_H_INCLUDED_
+ return ::IsWindowsVistaOrGreater();
+#else // !_versionhelpers_H_INCLUDED_
+ OSVERSIONINFO ovi = { sizeof(OSVERSIONINFO) };
+ BOOL bRet = ::GetVersionEx(&ovi);
+ return ((bRet != FALSE) && (ovi.dwMajorVersion >= 6));
+#endif // _versionhelpers_H_INCLUDED_
+ }
+
+ inline bool IsThemeAvailable()
+ {
+ return IsCommCtrl6() && (::IsThemeActive() != FALSE) && (::IsAppThemed() != FALSE);
+ }
+
+ inline bool IsWin7()
+ {
+#ifdef _versionhelpers_H_INCLUDED_
+ return ::IsWindows7OrGreater();
+#else // !_versionhelpers_H_INCLUDED_
+ OSVERSIONINFO ovi = { sizeof(OSVERSIONINFO) };
+ BOOL bRet = ::GetVersionEx(&ovi);
+ return ((bRet != FALSE) && ((ovi.dwMajorVersion > 6) || ((ovi.dwMajorVersion == 6) && (ovi.dwMinorVersion >= 1))));
+#endif // _versionhelpers_H_INCLUDED_
+ }
+
+ inline bool IsRibbonUIAvailable()
+ {
+ static INT iRibbonUI = -1;
+
+#if defined(NTDDI_WIN7) && (NTDDI_VERSION >= NTDDI_WIN7)
+ if (iRibbonUI == -1)
+ {
+ HMODULE hRibbonDLL = ::LoadLibrary(_T("propsys.dll"));
+ if (hRibbonDLL != NULL)
+ {
+ const GUID CLSID_UIRibbonFramework = { 0x926749fa, 0x2615, 0x4987, { 0x88, 0x45, 0xc3, 0x3e, 0x65, 0xf2, 0xb9, 0x57 } };
+ // block - create instance
+ {
+ ATL::CComPtr<IUnknown> pIUIFramework;
+ iRibbonUI = SUCCEEDED(pIUIFramework.CoCreateInstance(CLSID_UIRibbonFramework)) ? 1 : 0;
+ }
+ ::FreeLibrary(hRibbonDLL);
+ }
+ else
+ {
+ iRibbonUI = 0;
+ }
+ }
+#endif // defined(NTDDI_WIN7) && (NTDDI_VERSION >= NTDDI_WIN7)
+
+ return (iRibbonUI == 1);
+ }
+
+ inline UINT SizeOf_REBARBANDINFO()
+ {
+ UINT uSize = sizeof(REBARBANDINFO);
+#if !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600)
+ if(!(IsVista() && IsCommCtrl6()))
+ uSize = REBARBANDINFO_V6_SIZE;
+#endif // !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600)
+ return uSize;
+ }
+
+ inline UINT SizeOf_LVGROUP()
+ {
+ UINT uSize = sizeof(LVGROUP);
+#if !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600)
+ if(!IsVista())
+ uSize = LVGROUP_V5_SIZE;
+#endif // !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600)
+ return uSize;
+ }
+
+ inline UINT SizeOf_LVTILEINFO()
+ {
+ UINT uSize = sizeof(LVTILEINFO);
+#if !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600)
+ if(!IsVista())
+ uSize = LVTILEINFO_V5_SIZE;
+#endif // !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (_WIN32_WINNT >= 0x0600)
+ return uSize;
+ }
+
+ inline UINT SizeOf_MCHITTESTINFO()
+ {
+ UINT uSize = sizeof(MCHITTESTINFO);
+#if !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN)
+ if(!(IsVista() && IsCommCtrl6()))
+ uSize = MCHITTESTINFO_V1_SIZE;
+#endif // !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN)
+ return uSize;
+ }
+
+ inline UINT SizeOf_NONCLIENTMETRICS()
+ {
+ UINT uSize = sizeof(NONCLIENTMETRICS);
+#if !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (WINVER >= 0x0600)
+ if(!IsVista())
+ uSize = NONCLIENTMETRICS_V1_SIZE;
+#endif // !defined(_WTL_NO_RUNTIME_STRUCT_SIZE) && (WINVER >= 0x0600)
+ return uSize;
+ }
+
+ inline UINT SizeOf_TOOLINFO()
+ {
+ UINT uSize = sizeof(TOOLINFO);
+#ifndef _WTL_NO_RUNTIME_STRUCT_SIZE
+ if(!IsVista())
+ uSize = TTTOOLINFO_V2_SIZE;
+#endif
+ return uSize;
+ }
+} // namespace RunTimeHelper
+
+
+///////////////////////////////////////////////////////////////////////////////
+// ModuleHelper - helper functions for ATL (deprecated)
+
+namespace ModuleHelper
+{
+ inline HINSTANCE GetModuleInstance()
+ {
+ return ATL::_AtlBaseModule.GetModuleInstance();
+ }
+
+ inline HINSTANCE GetResourceInstance()
+ {
+ return ATL::_AtlBaseModule.GetResourceInstance();
+ }
+
+ inline void AddCreateWndData(ATL::_AtlCreateWndData* pData, void* pObject)
+ {
+ ATL::_AtlWinModule.AddCreateWndData(pData, pObject);
+ }
+
+ inline void* ExtractCreateWndData()
+ {
+ return ATL::_AtlWinModule.ExtractCreateWndData();
+ }
+} // namespace ModuleHelper
+
+
+///////////////////////////////////////////////////////////////////////////////
+// SecureHelper - WTL10 requires use of secure functions
+// these are here only for compatibility with existing projects
+
+namespace SecureHelper
+{
+ inline void strcpyA_x(char* lpstrDest, size_t cchDest, const char* lpstrSrc)
+ {
+ ATL::Checked::strcpy_s(lpstrDest, cchDest, lpstrSrc);
+ }
+
+ inline void strcpyW_x(wchar_t* lpstrDest, size_t cchDest, const wchar_t* lpstrSrc)
+ {
+ ATL::Checked::wcscpy_s(lpstrDest, cchDest, lpstrSrc);
+ }
+
+ inline void strcpy_x(LPTSTR lpstrDest, size_t cchDest, LPCTSTR lpstrSrc)
+ {
+#ifdef _UNICODE
+ strcpyW_x(lpstrDest, cchDest, lpstrSrc);
+#else
+ strcpyA_x(lpstrDest, cchDest, lpstrSrc);
+#endif
+ }
+
+ inline errno_t strncpyA_x(char* lpstrDest, size_t cchDest, const char* lpstrSrc, size_t cchCount)
+ {
+ return ATL::Checked::strncpy_s(lpstrDest, cchDest, lpstrSrc, cchCount);
+ }
+
+ inline errno_t strncpyW_x(wchar_t* lpstrDest, size_t cchDest, const wchar_t* lpstrSrc, size_t cchCount)
+ {
+ return ATL::Checked::wcsncpy_s(lpstrDest, cchDest, lpstrSrc, cchCount);
+ }
+
+ inline errno_t strncpy_x(LPTSTR lpstrDest, size_t cchDest, LPCTSTR lpstrSrc, size_t cchCount)
+ {
+#ifdef _UNICODE
+ return strncpyW_x(lpstrDest, cchDest, lpstrSrc, cchCount);
+#else
+ return strncpyA_x(lpstrDest, cchDest, lpstrSrc, cchCount);
+#endif
+ }
+
+ inline void strcatA_x(char* lpstrDest, size_t cchDest, const char* lpstrSrc)
+ {
+ ATL::Checked::strcat_s(lpstrDest, cchDest, lpstrSrc);
+ }
+
+ inline void strcatW_x(wchar_t* lpstrDest, size_t cchDest, const wchar_t* lpstrSrc)
+ {
+ ATL::Checked::wcscat_s(lpstrDest, cchDest, lpstrSrc);
+ }
+
+ inline void strcat_x(LPTSTR lpstrDest, size_t cchDest, LPCTSTR lpstrSrc)
+ {
+#ifdef _UNICODE
+ strcatW_x(lpstrDest, cchDest, lpstrSrc);
+#else
+ strcatA_x(lpstrDest, cchDest, lpstrSrc);
+#endif
+ }
+
+ inline void memcpy_x(void* pDest, size_t cbDest, const void* pSrc, size_t cbSrc)
+ {
+ ATL::Checked::memcpy_s(pDest, cbDest, pSrc, cbSrc);
+ }
+
+ inline void memmove_x(void* pDest, size_t cbDest, const void* pSrc, size_t cbSrc)
+ {
+ ATL::Checked::memmove_s(pDest, cbDest, pSrc, cbSrc);
+ }
+
+ inline int vsprintf_x(LPTSTR lpstrBuff, size_t cchBuff, LPCTSTR lpstrFormat, va_list args)
+ {
+ return _vstprintf_s(lpstrBuff, cchBuff, lpstrFormat, args);
+ }
+
+ inline int wvsprintf_x(LPTSTR lpstrBuff, size_t cchBuff, LPCTSTR lpstrFormat, va_list args)
+ {
+ return _vstprintf_s(lpstrBuff, cchBuff, lpstrFormat, args);
+ }
+
+ inline int sprintf_x(LPTSTR lpstrBuff, size_t cchBuff, LPCTSTR lpstrFormat, ...)
+ {
+ va_list args;
+ va_start(args, lpstrFormat);
+ int nRes = vsprintf_x(lpstrBuff, cchBuff, lpstrFormat, args);
+ va_end(args);
+ return nRes;
+ }
+
+ inline int wsprintf_x(LPTSTR lpstrBuff, size_t cchBuff, LPCTSTR lpstrFormat, ...)
+ {
+ va_list args;
+ va_start(args, lpstrFormat);
+ int nRes = wvsprintf_x(lpstrBuff, cchBuff, lpstrFormat, args);
+ va_end(args);
+ return nRes;
+ }
+} // namespace SecureHelper
+
+
+///////////////////////////////////////////////////////////////////////////////
+// MinCrtHelper - WTL10 doesn't support _ATL_MIN_CRT,
+// these are here only for compatibility with existing projects
+
+namespace MinCrtHelper
+{
+ inline int _isspace(TCHAR ch)
+ {
+ return _istspace(ch);
+ }
+
+ inline int _isdigit(TCHAR ch)
+ {
+ return _istdigit(ch);
+ }
+
+ inline int _atoi(LPCTSTR str)
+ {
+ return _ttoi(str);
+ }
+
+ inline LPCTSTR _strrchr(LPCTSTR str, TCHAR ch)
+ {
+ return _tcsrchr(str, ch);
+ }
+
+ inline LPTSTR _strrchr(LPTSTR str, TCHAR ch)
+ {
+ return _tcsrchr(str, ch);
+ }
+} // namespace MinCrtHelper
+
+
+///////////////////////////////////////////////////////////////////////////////
+// GenericWndClass - generic window class usable for subclassing
+
+// Use in dialog templates to specify a placeholder to be subclassed
+// Specify as a custom control with class name WTL_GenericWindow
+// Call Rregister() before creating dialog (for example, in WinMain)
+namespace GenericWndClass
+{
+ inline LPCTSTR GetName()
+ {
+ return _T("WTL_GenericWindow");
+ }
+
+ inline ATOM Register()
+ {
+ WNDCLASSEX wc = { sizeof(WNDCLASSEX) };
+ wc.lpfnWndProc = ::DefWindowProc;
+ wc.hInstance = ModuleHelper::GetModuleInstance();
+ wc.hCursor = ::LoadCursor(NULL, IDC_ARROW);
+ wc.hbrBackground = (HBRUSH)(COLOR_WINDOW + 1);
+ wc.lpszClassName = GetName();
+ ATOM atom = ::RegisterClassEx(&wc);
+ ATLASSERT(atom != 0);
+ return atom;
+ }
+
+ inline BOOL Unregister() // only needed for DLLs or tmp use
+ {
+ return ::UnregisterClass(GetName(), ModuleHelper::GetModuleInstance());
+ }
+} // namespace GenericWndClass
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CMessageFilter - Interface for message filter support
+
+class ATL_NO_VTABLE CMessageFilter
+{
+public:
+ virtual BOOL PreTranslateMessage(MSG* pMsg) = 0;
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CIdleHandler - Interface for idle processing
+
+class ATL_NO_VTABLE CIdleHandler
+{
+public:
+ virtual BOOL OnIdle() = 0;
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CMessageLoop - message loop implementation
+
+class CMessageLoop
+{
+public:
+ ATL::CSimpleArray<CMessageFilter*> m_aMsgFilter;
+ ATL::CSimpleArray<CIdleHandler*> m_aIdleHandler;
+ MSG m_msg;
+
+ CMessageLoop()
+ {
+ memset(&m_msg, 0, sizeof(m_msg));
+ }
+
+ virtual ~CMessageLoop()
+ { }
+
+// Message filter operations
+ BOOL AddMessageFilter(CMessageFilter* pMessageFilter)
+ {
+ return m_aMsgFilter.Add(pMessageFilter);
+ }
+
+ BOOL RemoveMessageFilter(CMessageFilter* pMessageFilter)
+ {
+ return m_aMsgFilter.Remove(pMessageFilter);
+ }
+
+// Idle handler operations
+ BOOL AddIdleHandler(CIdleHandler* pIdleHandler)
+ {
+ return m_aIdleHandler.Add(pIdleHandler);
+ }
+
+ BOOL RemoveIdleHandler(CIdleHandler* pIdleHandler)
+ {
+ return m_aIdleHandler.Remove(pIdleHandler);
+ }
+
+// message loop
+ int Run()
+ {
+ BOOL bDoIdle = TRUE;
+ int nIdleCount = 0;
+ BOOL bRet = FALSE;
+
+ for(;;)
+ {
+ while(bDoIdle && !::PeekMessage(&m_msg, NULL, 0, 0, PM_NOREMOVE))
+ {
+ if(!OnIdle(nIdleCount++))
+ bDoIdle = FALSE;
+ }
+
+ bRet = ::GetMessage(&m_msg, NULL, 0, 0);
+
+ if(bRet == -1)
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("::GetMessage returned -1 (error)\n"));
+ continue; // error, don't process
+ }
+ else if(!bRet)
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("CMessageLoop::Run - exiting\n"));
+ break; // WM_QUIT, exit message loop
+ }
+
+ if(!PreTranslateMessage(&m_msg))
+ {
+ ::TranslateMessage(&m_msg);
+ ::DispatchMessage(&m_msg);
+ }
+
+ if(IsIdleMessage(&m_msg))
+ {
+ bDoIdle = TRUE;
+ nIdleCount = 0;
+ }
+ }
+
+ return (int)m_msg.wParam;
+ }
+
+// Overrideables
+ // Override to change message filtering
+ virtual BOOL PreTranslateMessage(MSG* pMsg)
+ {
+ // loop backwards
+ for(int i = m_aMsgFilter.GetSize() - 1; i >= 0; i--)
+ {
+ CMessageFilter* pMessageFilter = m_aMsgFilter[i];
+ if((pMessageFilter != NULL) && pMessageFilter->PreTranslateMessage(pMsg))
+ return TRUE;
+ }
+ return FALSE; // not translated
+ }
+
+ // override to change idle processing
+ virtual BOOL OnIdle(int /*nIdleCount*/)
+ {
+ for(int i = 0; i < m_aIdleHandler.GetSize(); i++)
+ {
+ CIdleHandler* pIdleHandler = m_aIdleHandler[i];
+ if(pIdleHandler != NULL)
+ pIdleHandler->OnIdle();
+ }
+ return FALSE; // don't continue
+ }
+
+ // override to change non-idle messages
+ virtual BOOL IsIdleMessage(MSG* pMsg) const
+ {
+ // These messages should NOT cause idle processing
+ switch(pMsg->message)
+ {
+ case WM_MOUSEMOVE:
+ case WM_NCMOUSEMOVE:
+ case WM_PAINT:
+ case 0x0118: // WM_SYSTIMER (caret blink)
+ return FALSE;
+ }
+
+ return TRUE;
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CStaticDataInitCriticalSectionLock and CWindowCreateCriticalSectionLock
+// internal classes to manage critical sections for ATL (deprecated)
+
+class CStaticDataInitCriticalSectionLock
+{
+public:
+ ATL::CComCritSecLock<ATL::CComCriticalSection> m_cslock;
+
+ CStaticDataInitCriticalSectionLock() : m_cslock(ATL::_pAtlModule->m_csStaticDataInitAndTypeInfo, false)
+ { }
+
+ HRESULT Lock()
+ {
+ return m_cslock.Lock();
+ }
+
+ void Unlock()
+ {
+ m_cslock.Unlock();
+ }
+};
+
+
+class CWindowCreateCriticalSectionLock
+{
+public:
+ ATL::CComCritSecLock<ATL::CComCriticalSection> m_cslock;
+
+ CWindowCreateCriticalSectionLock() : m_cslock(ATL::_AtlWinModule.m_csWindowCreate, false)
+ { }
+
+ HRESULT Lock()
+ {
+ return m_cslock.Lock();
+ }
+
+ void Unlock()
+ {
+ m_cslock.Unlock();
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CAppModule - module class for an application
+
+#if (_MSC_VER == 1400) // VS2005
+ #pragma warning(push)
+ #pragma warning(disable : 4244)
+ #pragma warning(disable : 4312)
+#endif
+
+class CAppModule : public ATL::CComModule
+{
+public:
+ DWORD m_dwMainThreadID;
+ ATL::CSimpleMap<DWORD, CMessageLoop*>* m_pMsgLoopMap;
+ ATL::CSimpleArray<HWND>* m_pSettingChangeNotify;
+
+ CAppModule() : m_dwMainThreadID(0), m_pMsgLoopMap(NULL), m_pSettingChangeNotify(NULL)
+ { }
+
+// Overrides of CComModule::Init and Term
+ HRESULT Init(ATL::_ATL_OBJMAP_ENTRY* pObjMap, HINSTANCE hInstance, const GUID* pLibID = NULL)
+ {
+ HRESULT hRet = CComModule::Init(pObjMap, hInstance, pLibID);
+ if(FAILED(hRet))
+ return hRet;
+
+ m_dwMainThreadID = ::GetCurrentThreadId();
+ typedef ATL::CSimpleMap<DWORD, CMessageLoop*> _mapClass;
+ m_pMsgLoopMap = NULL;
+ ATLTRY(m_pMsgLoopMap = new _mapClass);
+ if(m_pMsgLoopMap == NULL)
+ return E_OUTOFMEMORY;
+ m_pSettingChangeNotify = NULL;
+
+ return hRet;
+ }
+
+ void Term()
+ {
+ TermSettingChangeNotify();
+ delete m_pMsgLoopMap;
+ CComModule::Term();
+ }
+
+// Message loop map methods
+ BOOL AddMessageLoop(CMessageLoop* pMsgLoop)
+ {
+ CStaticDataInitCriticalSectionLock lock;
+ if(FAILED(lock.Lock()))
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::AddMessageLoop.\n"));
+ ATLASSERT(FALSE);
+ return FALSE;
+ }
+
+ ATLASSERT(pMsgLoop != NULL);
+ ATLASSERT(m_pMsgLoopMap->Lookup(::GetCurrentThreadId()) == NULL); // not in map yet
+
+ BOOL bRet = m_pMsgLoopMap->Add(::GetCurrentThreadId(), pMsgLoop);
+
+ lock.Unlock();
+
+ return bRet;
+ }
+
+ BOOL RemoveMessageLoop()
+ {
+ CStaticDataInitCriticalSectionLock lock;
+ if(FAILED(lock.Lock()))
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::RemoveMessageLoop.\n"));
+ ATLASSERT(FALSE);
+ return FALSE;
+ }
+
+ BOOL bRet = m_pMsgLoopMap->Remove(::GetCurrentThreadId());
+
+ lock.Unlock();
+
+ return bRet;
+ }
+
+ CMessageLoop* GetMessageLoop(DWORD dwThreadID = ::GetCurrentThreadId()) const
+ {
+ CStaticDataInitCriticalSectionLock lock;
+ if(FAILED(lock.Lock()))
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::GetMessageLoop.\n"));
+ ATLASSERT(FALSE);
+ return NULL;
+ }
+
+ CMessageLoop* pLoop = m_pMsgLoopMap->Lookup(dwThreadID);
+
+ lock.Unlock();
+
+ return pLoop;
+ }
+
+// Setting change notify methods
+ // Note: Call this from the main thread for MSDI apps
+ BOOL InitSettingChangeNotify(DLGPROC pfnDlgProc = _SettingChangeDlgProc)
+ {
+ CStaticDataInitCriticalSectionLock lock;
+ if(FAILED(lock.Lock()))
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::InitSettingChangeNotify.\n"));
+ ATLASSERT(FALSE);
+ return FALSE;
+ }
+
+ if(m_pSettingChangeNotify == NULL)
+ {
+ typedef ATL::CSimpleArray<HWND> _notifyClass;
+ ATLTRY(m_pSettingChangeNotify = new _notifyClass);
+ ATLASSERT(m_pSettingChangeNotify != NULL);
+ }
+
+ BOOL bRet = (m_pSettingChangeNotify != NULL);
+ if(bRet && (m_pSettingChangeNotify->GetSize() == 0))
+ {
+ // init everything
+ _ATL_EMPTY_DLGTEMPLATE templ;
+ HWND hNtfWnd = ::CreateDialogIndirect(GetModuleInstance(), &templ, NULL, pfnDlgProc);
+ ATLASSERT(::IsWindow(hNtfWnd));
+ if(::IsWindow(hNtfWnd))
+ {
+ ::SetWindowLongPtr(hNtfWnd, GWLP_USERDATA, (LONG_PTR)this);
+ bRet = m_pSettingChangeNotify->Add(hNtfWnd);
+ }
+ else
+ {
+ bRet = FALSE;
+ }
+ }
+
+ lock.Unlock();
+
+ return bRet;
+ }
+
+ void TermSettingChangeNotify()
+ {
+ CStaticDataInitCriticalSectionLock lock;
+ if(FAILED(lock.Lock()))
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::TermSettingChangeNotify.\n"));
+ ATLASSERT(FALSE);
+ return;
+ }
+
+ if((m_pSettingChangeNotify != NULL) && (m_pSettingChangeNotify->GetSize() > 0))
+ ::DestroyWindow((*m_pSettingChangeNotify)[0]);
+ delete m_pSettingChangeNotify;
+ m_pSettingChangeNotify = NULL;
+
+ lock.Unlock();
+ }
+
+ BOOL AddSettingChangeNotify(HWND hWnd)
+ {
+ CStaticDataInitCriticalSectionLock lock;
+ if(FAILED(lock.Lock()))
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::AddSettingChangeNotify.\n"));
+ ATLASSERT(FALSE);
+ return FALSE;
+ }
+
+ ATLASSERT(::IsWindow(hWnd));
+ BOOL bRet = FALSE;
+ if(InitSettingChangeNotify() != FALSE)
+ bRet = m_pSettingChangeNotify->Add(hWnd);
+
+ lock.Unlock();
+
+ return bRet;
+ }
+
+ BOOL RemoveSettingChangeNotify(HWND hWnd)
+ {
+ CStaticDataInitCriticalSectionLock lock;
+ if(FAILED(lock.Lock()))
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CAppModule::RemoveSettingChangeNotify.\n"));
+ ATLASSERT(FALSE);
+ return FALSE;
+ }
+
+ BOOL bRet = FALSE;
+ if(m_pSettingChangeNotify != NULL)
+ bRet = m_pSettingChangeNotify->Remove(hWnd);
+
+ lock.Unlock();
+
+ return bRet;
+ }
+
+// Implementation - setting change notify dialog template and dialog procedure
+ struct _ATL_EMPTY_DLGTEMPLATE : DLGTEMPLATE
+ {
+ _ATL_EMPTY_DLGTEMPLATE()
+ {
+ memset(this, 0, sizeof(_ATL_EMPTY_DLGTEMPLATE));
+ style = WS_POPUP;
+ }
+ WORD wMenu, wClass, wTitle;
+ };
+
+ static INT_PTR CALLBACK _SettingChangeDlgProc(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam)
+ {
+ if(uMsg == WM_SETTINGCHANGE)
+ {
+ CAppModule* pModule = (CAppModule*)::GetWindowLongPtr(hWnd, GWLP_USERDATA);
+ ATLASSERT(pModule != NULL);
+ ATLASSERT(pModule->m_pSettingChangeNotify != NULL);
+ const UINT uTimeout = 1500; // ms
+ for(int i = 1; i < pModule->m_pSettingChangeNotify->GetSize(); i++)
+ ::SendMessageTimeout((*pModule->m_pSettingChangeNotify)[i], uMsg, wParam, lParam, SMTO_ABORTIFHUNG, uTimeout, NULL);
+
+ return TRUE;
+ }
+
+ return FALSE;
+ }
+};
+
+#if (_MSC_VER == 1400) // VS2005
+ #pragma warning(pop)
+#endif
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CServerAppModule - module class for a COM server application
+
+class CServerAppModule : public CAppModule
+{
+public:
+ HANDLE m_hEventShutdown;
+ bool m_bActivity;
+ DWORD m_dwTimeOut;
+ DWORD m_dwPause;
+
+ CServerAppModule() : m_hEventShutdown(NULL), m_bActivity(false), m_dwTimeOut(5000), m_dwPause(1000)
+ { }
+
+// Override of CAppModule::Init
+ HRESULT Init(ATL::_ATL_OBJMAP_ENTRY* pObjMap, HINSTANCE hInstance, const GUID* pLibID = NULL)
+ {
+ m_dwTimeOut = 5000;
+ m_dwPause = 1000;
+ return CAppModule::Init(pObjMap, hInstance, pLibID);
+ }
+
+ void Term()
+ {
+ if((m_hEventShutdown != NULL) && ::CloseHandle(m_hEventShutdown))
+ m_hEventShutdown = NULL;
+ CAppModule::Term();
+ }
+
+// COM Server methods
+ LONG Unlock() throw()
+ {
+ LONG lRet = CComModule::Unlock();
+ if(lRet == 0)
+ {
+ m_bActivity = true;
+ ::SetEvent(m_hEventShutdown); // tell monitor that we transitioned to zero
+ }
+ return lRet;
+ }
+
+ void MonitorShutdown()
+ {
+ for(;;)
+ {
+ ::WaitForSingleObject(m_hEventShutdown, INFINITE);
+ DWORD dwWait = 0;
+ do
+ {
+ m_bActivity = false;
+ dwWait = ::WaitForSingleObject(m_hEventShutdown, m_dwTimeOut);
+ }
+ while(dwWait == WAIT_OBJECT_0);
+ // timed out
+ if(!m_bActivity && (m_nLockCnt == 0)) // if no activity let's really bail
+ {
+#if defined(_WIN32_DCOM) && defined(_ATL_FREE_THREADED)
+ ::CoSuspendClassObjects();
+ if(!m_bActivity && (m_nLockCnt == 0))
+#endif
+ break;
+ }
+ }
+ // This handle should be valid now. If it isn't,
+ // check if _Module.Term was called first (it shouldn't)
+ if(::CloseHandle(m_hEventShutdown))
+ m_hEventShutdown = NULL;
+ ::PostThreadMessage(m_dwMainThreadID, WM_QUIT, 0, 0);
+ }
+
+ bool StartMonitor()
+ {
+ m_hEventShutdown = ::CreateEvent(NULL, false, false, NULL);
+ if(m_hEventShutdown == NULL)
+ return false;
+ DWORD dwThreadID = 0;
+#ifdef _MT
+ HANDLE hThread = (HANDLE)_beginthreadex(NULL, 0, (UINT (WINAPI*)(void*))MonitorProc, this, 0, (UINT*)&dwThreadID);
+#else
+ HANDLE hThread = ::CreateThread(NULL, 0, MonitorProc, this, 0, &dwThreadID);
+#endif
+ bool bRet = (hThread != NULL);
+ if(bRet)
+ ::CloseHandle(hThread);
+ return bRet;
+ }
+
+ static DWORD WINAPI MonitorProc(void* pv)
+ {
+ CServerAppModule* p = (CServerAppModule*)pv;
+ p->MonitorShutdown();
+ return 0;
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CRegKeyEx - not used any more, here only for compatibility with old projects
+
+typedef ATL::CRegKey CRegKeyEx;
+
+} // namespace WTL
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CString forward reference (enables CString use in atluser.h and atlgdi.h)
+
+#if (defined(_WTL_USE_CSTRING) || defined(_WTL_FORWARD_DECLARE_CSTRING)) && !defined(__ATLSTR_H__)
+ #include <atlstr.h>
+#endif
+
+// CString namespace
+#define _CSTRING_NS ATL
+
+// Type classes namespace
+#define _WTYPES_NS
+
+
+///////////////////////////////////////////////////////////////////////////////
+// General DLL version helpers (removed in ATL11)
+
+#if (_ATL_VER >= 0x0B00)
+
+namespace ATL
+{
+
+inline HRESULT AtlGetDllVersion(HINSTANCE hInstDLL, DLLVERSIONINFO* pDllVersionInfo)
+{
+ ATLASSERT(pDllVersionInfo != NULL);
+ if(pDllVersionInfo == NULL)
+ return E_INVALIDARG;
+
+ // We must get this function explicitly because some DLLs don't implement it.
+ DLLGETVERSIONPROC pfnDllGetVersion = (DLLGETVERSIONPROC)::GetProcAddress(hInstDLL, "DllGetVersion");
+ if(pfnDllGetVersion == NULL)
+ return E_NOTIMPL;
+
+ return (*pfnDllGetVersion)(pDllVersionInfo);
+}
+
+inline HRESULT AtlGetDllVersion(LPCTSTR lpstrDllName, DLLVERSIONINFO* pDllVersionInfo)
+{
+ HINSTANCE hInstDLL = ::LoadLibrary(lpstrDllName);
+ if(hInstDLL == NULL)
+ return E_FAIL;
+ HRESULT hRet = AtlGetDllVersion(hInstDLL, pDllVersionInfo);
+ ::FreeLibrary(hInstDLL);
+ return hRet;
+}
+
+// Common Control Versions:
+// Win95/WinNT 4.0 maj=4 min=00
+// IE 3.x maj=4 min=70
+// IE 4.0 maj=4 min=71
+inline HRESULT AtlGetCommCtrlVersion(LPDWORD pdwMajor, LPDWORD pdwMinor)
+{
+ ATLASSERT((pdwMajor != NULL) && (pdwMinor != NULL));
+ if((pdwMajor == NULL) || (pdwMinor == NULL))
+ return E_INVALIDARG;
+
+ DLLVERSIONINFO dvi;
+ ::ZeroMemory(&dvi, sizeof(dvi));
+ dvi.cbSize = sizeof(dvi);
+ HRESULT hRet = AtlGetDllVersion(_T("comctl32.dll"), &dvi);
+
+ if(SUCCEEDED(hRet))
+ {
+ *pdwMajor = dvi.dwMajorVersion;
+ *pdwMinor = dvi.dwMinorVersion;
+ }
+ else if(hRet == E_NOTIMPL)
+ {
+ // If DllGetVersion is not there, then the DLL is a version
+ // previous to the one shipped with IE 3.x
+ *pdwMajor = 4;
+ *pdwMinor = 0;
+ hRet = S_OK;
+ }
+
+ return hRet;
+}
+
+// Shell Versions:
+// Win95/WinNT 4.0 maj=4 min=00
+// IE 3.x, IE 4.0 without Web Integrated Desktop maj=4 min=00
+// IE 4.0 with Web Integrated Desktop maj=4 min=71
+// IE 4.01 with Web Integrated Desktop maj=4 min=72
+inline HRESULT AtlGetShellVersion(LPDWORD pdwMajor, LPDWORD pdwMinor)
+{
+ ATLASSERT((pdwMajor != NULL) && (pdwMinor != NULL));
+ if((pdwMajor == NULL) || (pdwMinor == NULL))
+ return E_INVALIDARG;
+
+ DLLVERSIONINFO dvi;
+ ::ZeroMemory(&dvi, sizeof(dvi));
+ dvi.cbSize = sizeof(dvi);
+ HRESULT hRet = AtlGetDllVersion(_T("shell32.dll"), &dvi);
+
+ if(SUCCEEDED(hRet))
+ {
+ *pdwMajor = dvi.dwMajorVersion;
+ *pdwMinor = dvi.dwMinorVersion;
+ }
+ else if(hRet == E_NOTIMPL)
+ {
+ // If DllGetVersion is not there, then the DLL is a version
+ // previous to the one shipped with IE 4.x
+ *pdwMajor = 4;
+ *pdwMinor = 0;
+ hRet = S_OK;
+ }
+
+ return hRet;
+}
+
+} // namespace ATL
+
+#endif // (_ATL_VER >= 0x0B00)
+
+
+// These are always included
+#include "atlwinx.h"
+#include "atluser.h"
+#include "atlgdi.h"
+
+#ifndef _WTL_NO_AUTOMATIC_NAMESPACE
+using namespace WTL;
+#endif // !_WTL_NO_AUTOMATIC_NAMESPACE
+
+#endif // __ATLAPP_H__
diff --git a/Examples/WhisperDesktop/Utils/WTL/atlcrack.h b/Examples/WhisperDesktop/Utils/WTL/atlcrack.h
new file mode 100644
index 0000000..da6a896
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/WTL/atlcrack.h
@@ -0,0 +1,2480 @@
+// Windows Template Library - WTL version 10.0
+// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved.
+//
+// This file is a part of the Windows Template Library.
+// The use and distribution terms for this software are covered by the
+// Microsoft Public License (http://opensource.org/licenses/MS-PL)
+// which can be found in the file MS-PL.txt at the root folder.
+
+#ifndef __ATLCRACK_H__
+#define __ATLCRACK_H__
+
+#pragma once
+
+#ifndef __ATLAPP_H__
+ #error atlcrack.h requires atlapp.h to be included first
+#endif
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Message map macro for cracked handlers
+
+// Note about message maps with cracked handlers:
+// You can use BEGIN_MSG_MAP for classes that derive from CWindowImpl/CDialogImpl,
+// but must use BEGIN_MSG_MAP_EX for classes that don't.
+
+#define BEGIN_MSG_MAP_EX(theClass) \
+public: \
+ BOOL m_bMsgHandled; \
+ /* "handled" management for cracked handlers */ \
+ BOOL IsMsgHandled() const \
+ { \
+ return m_bMsgHandled; \
+ } \
+ void SetMsgHandled(BOOL bHandled) \
+ { \
+ m_bMsgHandled = bHandled; \
+ } \
+ BOOL ProcessWindowMessage(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam, LRESULT& lResult, DWORD dwMsgMapID = 0) \
+ { \
+ BOOL bOldMsgHandled = m_bMsgHandled; \
+ BOOL bRet = _ProcessWindowMessage(hWnd, uMsg, wParam, lParam, lResult, dwMsgMapID); \
+ m_bMsgHandled = bOldMsgHandled; \
+ return bRet; \
+ } \
+ BOOL _ProcessWindowMessage(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam, LRESULT& lResult, DWORD dwMsgMapID) \
+ { \
+ BOOL bHandled = TRUE; \
+ (hWnd); \
+ (uMsg); \
+ (wParam); \
+ (lParam); \
+ (lResult); \
+ (bHandled); \
+ switch(dwMsgMapID) \
+ { \
+ case 0:
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Standard Windows message macros
+
+// int OnCreate(LPCREATESTRUCT lpCreateStruct)
+#define MSG_WM_CREATE(func) \
+ if (uMsg == WM_CREATE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((LPCREATESTRUCT)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnInitDialog(CWindow wndFocus, LPARAM lInitParam)
+#define MSG_WM_INITDIALOG(func) \
+ if (uMsg == WM_INITDIALOG) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HWND)wParam, lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnCopyData(CWindow wnd, PCOPYDATASTRUCT pCopyDataStruct)
+#define MSG_WM_COPYDATA(func) \
+ if (uMsg == WM_COPYDATA) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HWND)wParam, (PCOPYDATASTRUCT)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDestroy()
+#define MSG_WM_DESTROY(func) \
+ if (uMsg == WM_DESTROY) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMove(CPoint ptPos)
+#define MSG_WM_MOVE(func) \
+ if (uMsg == WM_MOVE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSize(UINT nType, CSize size)
+#define MSG_WM_SIZE(func) \
+ if (uMsg == WM_SIZE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CSize(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnActivate(UINT nState, BOOL bMinimized, CWindow wndOther)
+#define MSG_WM_ACTIVATE(func) \
+ if (uMsg == WM_ACTIVATE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)LOWORD(wParam), (BOOL)HIWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSetFocus(CWindow wndOld)
+#define MSG_WM_SETFOCUS(func) \
+ if (uMsg == WM_SETFOCUS) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnKillFocus(CWindow wndFocus)
+#define MSG_WM_KILLFOCUS(func) \
+ if (uMsg == WM_KILLFOCUS) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnEnable(BOOL bEnable)
+#define MSG_WM_ENABLE(func) \
+ if (uMsg == WM_ENABLE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((BOOL)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnPaint(CDCHandle dc)
+#define MSG_WM_PAINT(func) \
+ if (uMsg == WM_PAINT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HDC)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnClose()
+#define MSG_WM_CLOSE(func) \
+ if (uMsg == WM_CLOSE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnQueryEndSession(UINT nSource, UINT uLogOff)
+#define MSG_WM_QUERYENDSESSION(func) \
+ if (uMsg == WM_QUERYENDSESSION) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)wParam, (UINT)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnQueryOpen()
+#define MSG_WM_QUERYOPEN(func) \
+ if (uMsg == WM_QUERYOPEN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func(); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnEraseBkgnd(CDCHandle dc)
+#define MSG_WM_ERASEBKGND(func) \
+ if (uMsg == WM_ERASEBKGND) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSysColorChange()
+#define MSG_WM_SYSCOLORCHANGE(func) \
+ if (uMsg == WM_SYSCOLORCHANGE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnEndSession(BOOL bEnding, UINT uLogOff)
+#define MSG_WM_ENDSESSION(func) \
+ if (uMsg == WM_ENDSESSION) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((BOOL)wParam, (UINT)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnShowWindow(BOOL bShow, UINT nStatus)
+#define MSG_WM_SHOWWINDOW(func) \
+ if (uMsg == WM_SHOWWINDOW) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((BOOL)wParam, (int)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnCtlColorEdit(CDCHandle dc, CEdit edit)
+#define MSG_WM_CTLCOLOREDIT(func) \
+ if (uMsg == WM_CTLCOLOREDIT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnCtlColorListBox(CDCHandle dc, CListBox listBox)
+#define MSG_WM_CTLCOLORLISTBOX(func) \
+ if (uMsg == WM_CTLCOLORLISTBOX) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnCtlColorBtn(CDCHandle dc, CButton button)
+#define MSG_WM_CTLCOLORBTN(func) \
+ if (uMsg == WM_CTLCOLORBTN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnCtlColorDlg(CDCHandle dc, CWindow wnd)
+#define MSG_WM_CTLCOLORDLG(func) \
+ if (uMsg == WM_CTLCOLORDLG) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnCtlColorScrollBar(CDCHandle dc, CScrollBar scrollBar)
+#define MSG_WM_CTLCOLORSCROLLBAR(func) \
+ if (uMsg == WM_CTLCOLORSCROLLBAR) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnCtlColorStatic(CDCHandle dc, CStatic wndStatic)
+#define MSG_WM_CTLCOLORSTATIC(func) \
+ if (uMsg == WM_CTLCOLORSTATIC) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSettingChange(UINT uFlags, LPCTSTR lpszSection)
+#define MSG_WM_SETTINGCHANGE(func) \
+ if (uMsg == WM_SETTINGCHANGE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPCTSTR)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDevModeChange(LPCTSTR lpDeviceName)
+#define MSG_WM_DEVMODECHANGE(func) \
+ if (uMsg == WM_DEVMODECHANGE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((LPCTSTR)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnActivateApp(BOOL bActive, DWORD dwThreadID)
+#define MSG_WM_ACTIVATEAPP(func) \
+ if (uMsg == WM_ACTIVATEAPP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((BOOL)wParam, (DWORD)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnFontChange()
+#define MSG_WM_FONTCHANGE(func) \
+ if (uMsg == WM_FONTCHANGE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnTimeChange()
+#define MSG_WM_TIMECHANGE(func) \
+ if (uMsg == WM_TIMECHANGE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnCancelMode()
+#define MSG_WM_CANCELMODE(func) \
+ if (uMsg == WM_CANCELMODE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnSetCursor(CWindow wnd, UINT nHitTest, UINT message)
+#define MSG_WM_SETCURSOR(func) \
+ if (uMsg == WM_SETCURSOR) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HWND)wParam, (UINT)LOWORD(lParam), (UINT)HIWORD(lParam)); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// int OnMouseActivate(CWindow wndTopLevel, UINT nHitTest, UINT message)
+#define MSG_WM_MOUSEACTIVATE(func) \
+ if (uMsg == WM_MOUSEACTIVATE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HWND)wParam, (UINT)LOWORD(lParam), (UINT)HIWORD(lParam)); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnChildActivate()
+#define MSG_WM_CHILDACTIVATE(func) \
+ if (uMsg == WM_CHILDACTIVATE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnGetMinMaxInfo(LPMINMAXINFO lpMMI)
+#define MSG_WM_GETMINMAXINFO(func) \
+ if (uMsg == WM_GETMINMAXINFO) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((LPMINMAXINFO)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnIconEraseBkgnd(CDCHandle dc)
+#define MSG_WM_ICONERASEBKGND(func) \
+ if (uMsg == WM_ICONERASEBKGND) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HDC)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSpoolerStatus(UINT nStatus, UINT nJobs)
+#define MSG_WM_SPOOLERSTATUS(func) \
+ if (uMsg == WM_SPOOLERSTATUS) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (UINT)LOWORD(lParam)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDrawItem(int nIDCtl, LPDRAWITEMSTRUCT lpDrawItemStruct)
+#define MSG_WM_DRAWITEM(func) \
+ if (uMsg == WM_DRAWITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPDRAWITEMSTRUCT)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMeasureItem(int nIDCtl, LPMEASUREITEMSTRUCT lpMeasureItemStruct)
+#define MSG_WM_MEASUREITEM(func) \
+ if (uMsg == WM_MEASUREITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPMEASUREITEMSTRUCT)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDeleteItem(int nIDCtl, LPDELETEITEMSTRUCT lpDeleteItemStruct)
+#define MSG_WM_DELETEITEM(func) \
+ if (uMsg == WM_DELETEITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPDELETEITEMSTRUCT)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+//int OnCharToItem(UINT nChar, UINT nIndex, CListBox listBox)
+#define MSG_WM_CHARTOITEM(func) \
+ if (uMsg == WM_CHARTOITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// int OnVKeyToItem(UINT nKey, UINT nIndex, CListBox listBox)
+#define MSG_WM_VKEYTOITEM(func) \
+ if (uMsg == WM_VKEYTOITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HCURSOR OnQueryDragIcon()
+#define MSG_WM_QUERYDRAGICON(func) \
+ if (uMsg == WM_QUERYDRAGICON) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func(); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// int OnCompareItem(int nIDCtl, LPCOMPAREITEMSTRUCT lpCompareItemStruct)
+#define MSG_WM_COMPAREITEM(func) \
+ if (uMsg == WM_COMPAREITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)wParam, (LPCOMPAREITEMSTRUCT)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnCompacting(UINT nCpuTime)
+#define MSG_WM_COMPACTING(func) \
+ if (uMsg == WM_COMPACTING) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnNcCreate(LPCREATESTRUCT lpCreateStruct)
+#define MSG_WM_NCCREATE(func) \
+ if (uMsg == WM_NCCREATE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((LPCREATESTRUCT)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcDestroy()
+#define MSG_WM_NCDESTROY(func) \
+ if (uMsg == WM_NCDESTROY) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnNcCalcSize(BOOL bCalcValidRects, LPARAM lParam)
+#define MSG_WM_NCCALCSIZE(func) \
+ if (uMsg == WM_NCCALCSIZE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((BOOL)wParam, lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// UINT OnNcHitTest(CPoint point)
+#define MSG_WM_NCHITTEST(func) \
+ if (uMsg == WM_NCHITTEST) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func(::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcPaint(CRgnHandle rgn)
+#define MSG_WM_NCPAINT(func) \
+ if (uMsg == WM_NCPAINT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HRGN)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnNcActivate(BOOL bActive)
+#define MSG_WM_NCACTIVATE(func) \
+ if (uMsg == WM_NCACTIVATE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((BOOL)wParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// UINT OnGetDlgCode(LPMSG lpMsg)
+#define MSG_WM_GETDLGCODE(func) \
+ if (uMsg == WM_GETDLGCODE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((LPMSG)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcMouseMove(UINT nHitTest, CPoint point)
+#define MSG_WM_NCMOUSEMOVE(func) \
+ if (uMsg == WM_NCMOUSEMOVE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcLButtonDown(UINT nHitTest, CPoint point)
+#define MSG_WM_NCLBUTTONDOWN(func) \
+ if (uMsg == WM_NCLBUTTONDOWN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcLButtonUp(UINT nHitTest, CPoint point)
+#define MSG_WM_NCLBUTTONUP(func) \
+ if (uMsg == WM_NCLBUTTONUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcLButtonDblClk(UINT nHitTest, CPoint point)
+#define MSG_WM_NCLBUTTONDBLCLK(func) \
+ if (uMsg == WM_NCLBUTTONDBLCLK) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcRButtonDown(UINT nHitTest, CPoint point)
+#define MSG_WM_NCRBUTTONDOWN(func) \
+ if (uMsg == WM_NCRBUTTONDOWN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcRButtonUp(UINT nHitTest, CPoint point)
+#define MSG_WM_NCRBUTTONUP(func) \
+ if (uMsg == WM_NCRBUTTONUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcRButtonDblClk(UINT nHitTest, CPoint point)
+#define MSG_WM_NCRBUTTONDBLCLK(func) \
+ if (uMsg == WM_NCRBUTTONDBLCLK) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcMButtonDown(UINT nHitTest, CPoint point)
+#define MSG_WM_NCMBUTTONDOWN(func) \
+ if (uMsg == WM_NCMBUTTONDOWN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcMButtonUp(UINT nHitTest, CPoint point)
+#define MSG_WM_NCMBUTTONUP(func) \
+ if (uMsg == WM_NCMBUTTONUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcMButtonDblClk(UINT nHitTest, CPoint point)
+#define MSG_WM_NCMBUTTONDBLCLK(func) \
+ if (uMsg == WM_NCMBUTTONDBLCLK) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnKeyDown(UINT nChar, UINT nRepCnt, UINT nFlags)
+#define MSG_WM_KEYDOWN(func) \
+ if (uMsg == WM_KEYDOWN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnKeyUp(UINT nChar, UINT nRepCnt, UINT nFlags)
+#define MSG_WM_KEYUP(func) \
+ if (uMsg == WM_KEYUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnChar(TCHAR chChar, UINT nRepCnt, UINT nFlags)
+#define MSG_WM_CHAR(func) \
+ if (uMsg == WM_CHAR) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((TCHAR)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDeadChar(TCHAR chChar, UINT nRepCnt, UINT nFlags)
+#define MSG_WM_DEADCHAR(func) \
+ if (uMsg == WM_DEADCHAR) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((TCHAR)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSysKeyDown(UINT nChar, UINT nRepCnt, UINT nFlags)
+#define MSG_WM_SYSKEYDOWN(func) \
+ if (uMsg == WM_SYSKEYDOWN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSysKeyUp(UINT nChar, UINT nRepCnt, UINT nFlags)
+#define MSG_WM_SYSKEYUP(func) \
+ if (uMsg == WM_SYSKEYUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSysChar(TCHAR chChar, UINT nRepCnt, UINT nFlags)
+#define MSG_WM_SYSCHAR(func) \
+ if (uMsg == WM_SYSCHAR) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((TCHAR)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSysDeadChar(TCHAR chChar, UINT nRepCnt, UINT nFlags)
+#define MSG_WM_SYSDEADCHAR(func) \
+ if (uMsg == WM_SYSDEADCHAR) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((TCHAR)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSysCommand(UINT nID, CPoint point)
+#define MSG_WM_SYSCOMMAND(func) \
+ if (uMsg == WM_SYSCOMMAND) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnTCard(UINT idAction, DWORD dwActionData)
+#define MSG_WM_TCARD(func) \
+ if (uMsg == WM_TCARD) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (DWORD)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnTimer(UINT_PTR nIDEvent)
+#define MSG_WM_TIMER(func) \
+ if (uMsg == WM_TIMER) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT_PTR)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnHScroll(UINT nSBCode, UINT nPos, CScrollBar pScrollBar)
+#define MSG_WM_HSCROLL(func) \
+ if (uMsg == WM_HSCROLL) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((int)LOWORD(wParam), (short)HIWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnVScroll(UINT nSBCode, UINT nPos, CScrollBar pScrollBar)
+#define MSG_WM_VSCROLL(func) \
+ if (uMsg == WM_VSCROLL) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((int)LOWORD(wParam), (short)HIWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnInitMenu(CMenuHandle menu)
+#define MSG_WM_INITMENU(func) \
+ if (uMsg == WM_INITMENU) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HMENU)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnInitMenuPopup(CMenuHandle menuPopup, UINT nIndex, BOOL bSysMenu)
+#define MSG_WM_INITMENUPOPUP(func) \
+ if (uMsg == WM_INITMENUPOPUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HMENU)wParam, (UINT)LOWORD(lParam), (BOOL)HIWORD(lParam)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMenuSelect(UINT nItemID, UINT nFlags, CMenuHandle menu)
+#define MSG_WM_MENUSELECT(func) \
+ if (uMsg == WM_MENUSELECT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (HMENU)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnMenuChar(UINT nChar, UINT nFlags, CMenuHandle menu)
+#define MSG_WM_MENUCHAR(func) \
+ if (uMsg == WM_MENUCHAR) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((TCHAR)LOWORD(wParam), (UINT)HIWORD(wParam), (HMENU)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnNotify(int idCtrl, LPNMHDR pnmh)
+#define MSG_WM_NOTIFY(func) \
+ if (uMsg == WM_NOTIFY) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((int)wParam, (LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnEnterIdle(UINT nWhy, CWindow wndWho)
+#define MSG_WM_ENTERIDLE(func) \
+ if (uMsg == WM_ENTERIDLE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMouseMove(UINT nFlags, CPoint point)
+#define MSG_WM_MOUSEMOVE(func) \
+ if (uMsg == WM_MOUSEMOVE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnMouseWheel(UINT nFlags, short zDelta, CPoint pt)
+#define MSG_WM_MOUSEWHEEL(func) \
+ if (uMsg == WM_MOUSEWHEEL) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)LOWORD(wParam), (short)HIWORD(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnLButtonDown(UINT nFlags, CPoint point)
+#define MSG_WM_LBUTTONDOWN(func) \
+ if (uMsg == WM_LBUTTONDOWN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnLButtonUp(UINT nFlags, CPoint point)
+#define MSG_WM_LBUTTONUP(func) \
+ if (uMsg == WM_LBUTTONUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnLButtonDblClk(UINT nFlags, CPoint point)
+#define MSG_WM_LBUTTONDBLCLK(func) \
+ if (uMsg == WM_LBUTTONDBLCLK) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnRButtonDown(UINT nFlags, CPoint point)
+#define MSG_WM_RBUTTONDOWN(func) \
+ if (uMsg == WM_RBUTTONDOWN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnRButtonUp(UINT nFlags, CPoint point)
+#define MSG_WM_RBUTTONUP(func) \
+ if (uMsg == WM_RBUTTONUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnRButtonDblClk(UINT nFlags, CPoint point)
+#define MSG_WM_RBUTTONDBLCLK(func) \
+ if (uMsg == WM_RBUTTONDBLCLK) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMButtonDown(UINT nFlags, CPoint point)
+#define MSG_WM_MBUTTONDOWN(func) \
+ if (uMsg == WM_MBUTTONDOWN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMButtonUp(UINT nFlags, CPoint point)
+#define MSG_WM_MBUTTONUP(func) \
+ if (uMsg == WM_MBUTTONUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMButtonDblClk(UINT nFlags, CPoint point)
+#define MSG_WM_MBUTTONDBLCLK(func) \
+ if (uMsg == WM_MBUTTONDBLCLK) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnParentNotify(UINT message, UINT nChildID, LPARAM lParam)
+#define MSG_WM_PARENTNOTIFY(func) \
+ if (uMsg == WM_PARENTNOTIFY) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMDIActivate(CWindow wndDeactivate, CWindow wndActivate)
+#define MSG_WM_MDIACTIVATE(func) \
+ if (uMsg == WM_MDIACTIVATE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam, (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnRenderFormat(UINT nFormat)
+#define MSG_WM_RENDERFORMAT(func) \
+ if (uMsg == WM_RENDERFORMAT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnRenderAllFormats()
+#define MSG_WM_RENDERALLFORMATS(func) \
+ if (uMsg == WM_RENDERALLFORMATS) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDestroyClipboard()
+#define MSG_WM_DESTROYCLIPBOARD(func) \
+ if (uMsg == WM_DESTROYCLIPBOARD) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDrawClipboard()
+#define MSG_WM_DRAWCLIPBOARD(func) \
+ if (uMsg == WM_DRAWCLIPBOARD) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnPaintClipboard(CWindow wndViewer, const LPPAINTSTRUCT lpPaintStruct)
+#define MSG_WM_PAINTCLIPBOARD(func) \
+ if (uMsg == WM_PAINTCLIPBOARD) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam, (const LPPAINTSTRUCT)::GlobalLock((HGLOBAL)lParam)); \
+ ::GlobalUnlock((HGLOBAL)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnVScrollClipboard(CWindow wndViewer, UINT nSBCode, UINT nPos)
+#define MSG_WM_VSCROLLCLIPBOARD(func) \
+ if (uMsg == WM_VSCROLLCLIPBOARD) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam, (UINT)LOWORD(lParam), (UINT)HIWORD(lParam)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnContextMenu(CWindow wnd, CPoint point)
+#define MSG_WM_CONTEXTMENU(func) \
+ if (uMsg == WM_CONTEXTMENU) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSizeClipboard(CWindow wndViewer, const LPRECT lpRect)
+#define MSG_WM_SIZECLIPBOARD(func) \
+ if (uMsg == WM_SIZECLIPBOARD) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam, (const LPRECT)::GlobalLock((HGLOBAL)lParam)); \
+ ::GlobalUnlock((HGLOBAL)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnAskCbFormatName(UINT nMaxCount, LPTSTR lpszString)
+#define MSG_WM_ASKCBFORMATNAME(func) \
+ if (uMsg == WM_ASKCBFORMATNAME) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPTSTR)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnChangeCbChain(CWindow wndRemove, CWindow wndAfter)
+#define MSG_WM_CHANGECBCHAIN(func) \
+ if (uMsg == WM_CHANGECBCHAIN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam, (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnHScrollClipboard(CWindow wndViewer, UINT nSBCode, UINT nPos)
+#define MSG_WM_HSCROLLCLIPBOARD(func) \
+ if (uMsg == WM_HSCROLLCLIPBOARD) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam, (UINT)LOWORD(lParam), (UINT)HIWORD(lParam)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnQueryNewPalette()
+#define MSG_WM_QUERYNEWPALETTE(func) \
+ if (uMsg == WM_QUERYNEWPALETTE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func(); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnPaletteChanged(CWindow wndFocus)
+#define MSG_WM_PALETTECHANGED(func) \
+ if (uMsg == WM_PALETTECHANGED) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnPaletteIsChanging(CWindow wndPalChg)
+#define MSG_WM_PALETTEISCHANGING(func) \
+ if (uMsg == WM_PALETTEISCHANGING) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDropFiles(HDROP hDropInfo)
+#define MSG_WM_DROPFILES(func) \
+ if (uMsg == WM_DROPFILES) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HDROP)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnWindowPosChanging(LPWINDOWPOS lpWndPos)
+#define MSG_WM_WINDOWPOSCHANGING(func) \
+ if (uMsg == WM_WINDOWPOSCHANGING) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((LPWINDOWPOS)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnWindowPosChanged(LPWINDOWPOS lpWndPos)
+#define MSG_WM_WINDOWPOSCHANGED(func) \
+ if (uMsg == WM_WINDOWPOSCHANGED) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((LPWINDOWPOS)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnExitMenuLoop(BOOL fIsTrackPopupMenu)
+#define MSG_WM_EXITMENULOOP(func) \
+ if (uMsg == WM_EXITMENULOOP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((BOOL)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnEnterMenuLoop(BOOL fIsTrackPopupMenu)
+#define MSG_WM_ENTERMENULOOP(func) \
+ if (uMsg == WM_ENTERMENULOOP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((BOOL)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnStyleChanged(int nStyleType, LPSTYLESTRUCT lpStyleStruct)
+#define MSG_WM_STYLECHANGED(func) \
+ if (uMsg == WM_STYLECHANGED) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPSTYLESTRUCT)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnStyleChanging(int nStyleType, LPSTYLESTRUCT lpStyleStruct)
+#define MSG_WM_STYLECHANGING(func) \
+ if (uMsg == WM_STYLECHANGING) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPSTYLESTRUCT)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSizing(UINT fwSide, LPRECT pRect)
+#define MSG_WM_SIZING(func) \
+ if (uMsg == WM_SIZING) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPRECT)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMoving(UINT fwSide, LPRECT pRect)
+#define MSG_WM_MOVING(func) \
+ if (uMsg == WM_MOVING) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPRECT)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnCaptureChanged(CWindow wnd)
+#define MSG_WM_CAPTURECHANGED(func) \
+ if (uMsg == WM_CAPTURECHANGED) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnDeviceChange(UINT nEventType, DWORD_PTR dwData)
+#define MSG_WM_DEVICECHANGE(func) \
+ if (uMsg == WM_DEVICECHANGE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)wParam, (DWORD_PTR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnCommand(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define MSG_WM_COMMAND(func) \
+ if (uMsg == WM_COMMAND) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDisplayChange(UINT uBitsPerPixel, CSize sizeScreen)
+#define MSG_WM_DISPLAYCHANGE(func) \
+ if (uMsg == WM_DISPLAYCHANGE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CSize(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnEnterSizeMove()
+#define MSG_WM_ENTERSIZEMOVE(func) \
+ if (uMsg == WM_ENTERSIZEMOVE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnExitSizeMove()
+#define MSG_WM_EXITSIZEMOVE(func) \
+ if (uMsg == WM_EXITSIZEMOVE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HFONT OnGetFont()
+#define MSG_WM_GETFONT(func) \
+ if (uMsg == WM_GETFONT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func(); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnGetHotKey()
+#define MSG_WM_GETHOTKEY(func) \
+ if (uMsg == WM_GETHOTKEY) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func(); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HICON OnGetIcon()
+#define MSG_WM_GETICON(func) \
+ if (uMsg == WM_GETICON) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)wParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// int OnGetText(int cchTextMax, LPTSTR lpszText)
+#define MSG_WM_GETTEXT(func) \
+ if (uMsg == WM_GETTEXT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((int)wParam, (LPTSTR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// int OnGetTextLength()
+#define MSG_WM_GETTEXTLENGTH(func) \
+ if (uMsg == WM_GETTEXTLENGTH) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func(); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnHelp(LPHELPINFO lpHelpInfo)
+#define MSG_WM_HELP(func) \
+ if (uMsg == WM_HELP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((LPHELPINFO)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnHotKey(int nHotKeyID, UINT uModifiers, UINT uVirtKey)
+#define MSG_WM_HOTKEY(func) \
+ if (uMsg == WM_HOTKEY) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((int)wParam, (UINT)LOWORD(lParam), (UINT)HIWORD(lParam)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnInputLangChange(DWORD dwCharSet, HKL hKbdLayout)
+#define MSG_WM_INPUTLANGCHANGE(func) \
+ if (uMsg == WM_INPUTLANGCHANGE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((DWORD)wParam, (HKL)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnInputLangChangeRequest(BOOL bSysCharSet, HKL hKbdLayout)
+#define MSG_WM_INPUTLANGCHANGEREQUEST(func) \
+ if (uMsg == WM_INPUTLANGCHANGEREQUEST) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((BOOL)wParam, (HKL)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNextDlgCtl(BOOL bHandle, WPARAM wCtlFocus)
+#define MSG_WM_NEXTDLGCTL(func) \
+ if (uMsg == WM_NEXTDLGCTL) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((BOOL)LOWORD(lParam), wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNextMenu(int nVirtKey, LPMDINEXTMENU lpMdiNextMenu)
+#define MSG_WM_NEXTMENU(func) \
+ if (uMsg == WM_NEXTMENU) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((int)wParam, (LPMDINEXTMENU)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// int OnNotifyFormat(CWindow wndFrom, int nCommand)
+#define MSG_WM_NOTIFYFORMAT(func) \
+ if (uMsg == WM_NOTIFYFORMAT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HWND)wParam, (int)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnPowerBroadcast(DWORD dwPowerEvent, DWORD_PTR dwData)
+#define MSG_WM_POWERBROADCAST(func) \
+ if (uMsg == WM_POWERBROADCAST) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((DWORD)wParam, (DWORD_PTR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnPrint(CDCHandle dc, UINT uFlags)
+#define MSG_WM_PRINT(func) \
+ if (uMsg == WM_PRINT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HDC)wParam, (UINT)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnPrintClient(CDCHandle dc, UINT uFlags)
+#define MSG_WM_PRINTCLIENT(func) \
+ if (uMsg == WM_PRINTCLIENT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HDC)wParam, (UINT)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnRasDialEvent(RASCONNSTATE rasconnstate, DWORD dwError)
+#define MSG_WM_RASDIALEVENT(func) \
+ if (uMsg == WM_RASDIALEVENT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((RASCONNSTATE)wParam, (DWORD)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSetFont(CFontHandle font, BOOL bRedraw)
+#define MSG_WM_SETFONT(func) \
+ if (uMsg == WM_SETFONT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((HFONT)wParam, (BOOL)LOWORD(lParam)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// int OnSetHotKey(int nVirtKey, UINT uFlags)
+#define MSG_WM_SETHOTKEY(func) \
+ if (uMsg == WM_SETHOTKEY) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((int)LOBYTE(LOWORD(wParam)), (UINT)HIBYTE(LOWORD(wParam))); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HICON OnSetIcon(UINT uType, HICON hIcon)
+#define MSG_WM_SETICON(func) \
+ if (uMsg == WM_SETICON) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)wParam, (HICON)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnSetRedraw(BOOL bRedraw)
+#define MSG_WM_SETREDRAW(func) \
+ if (uMsg == WM_SETREDRAW) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((BOOL)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// int OnSetText(LPCTSTR lpstrText)
+#define MSG_WM_SETTEXT(func) \
+ if (uMsg == WM_SETTEXT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((LPCTSTR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnUserChanged()
+#define MSG_WM_USERCHANGED(func) \
+ if (uMsg == WM_USERCHANGED) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+///////////////////////////////////////////////////////////////////////////////
+// Newer Windows messages
+
+// void OnMouseHover(WPARAM wParam, CPoint ptPos)
+#define MSG_WM_MOUSEHOVER(func) \
+ if (uMsg == WM_MOUSEHOVER) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(wParam, ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMouseLeave()
+#define MSG_WM_MOUSELEAVE(func) \
+ if (uMsg == WM_MOUSELEAVE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcMouseHover(UINT nHitTest, CPoint ptPos)
+#define MSG_WM_NCMOUSEHOVER(func) \
+ if (uMsg == WM_NCMOUSEHOVER) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, ::CPoint(MAKEPOINTS(lParam).x, MAKEPOINTS(lParam).y)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNcMouseLeave()
+#define MSG_WM_NCMOUSELEAVE(func) \
+ if (uMsg == WM_NCMOUSELEAVE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMenuRButtonUp(WPARAM wParam, CMenuHandle menu)
+#define MSG_WM_MENURBUTTONUP(func) \
+ if (uMsg == WM_MENURBUTTONUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(wParam, (HMENU)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnMenuDrag(WPARAM wParam, CMenuHandle menu)
+#define MSG_WM_MENUDRAG(func) \
+ if (uMsg == WM_MENUDRAG) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func(wParam, (HMENU)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnMenuGetObject(PMENUGETOBJECTINFO info)
+#define MSG_WM_MENUGETOBJECT(func) \
+ if (uMsg == WM_MENUGETOBJECT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((PMENUGETOBJECTINFO)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnUnInitMenuPopup(UINT nID, CMenuHandle menu)
+#define MSG_WM_UNINITMENUPOPUP(func) \
+ if (uMsg == WM_UNINITMENUPOPUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(lParam), (HMENU)wParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnMenuCommand(WPARAM nIndex, CMenuHandle menu)
+#define MSG_WM_MENUCOMMAND(func) \
+ if (uMsg == WM_MENUCOMMAND) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(wParam, (HMENU)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnAppCommand(CWindow wndFocus, short cmd, WORD uDevice, int dwKeys)
+#define MSG_WM_APPCOMMAND(func) \
+ if (uMsg == WM_APPCOMMAND) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HWND)wParam, GET_APPCOMMAND_LPARAM(lParam), GET_DEVICE_LPARAM(lParam), GET_KEYSTATE_LPARAM(lParam)); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNCXButtonDown(int fwButton, short nHittest, CPoint ptPos)
+#define MSG_WM_NCXBUTTONDOWN(func) \
+ if (uMsg == WM_NCXBUTTONDOWN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(GET_XBUTTON_WPARAM(wParam), GET_NCHITTEST_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = (LRESULT)TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNCXButtonUp(int fwButton, short nHittest, CPoint ptPos)
+#define MSG_WM_NCXBUTTONUP(func) \
+ if (uMsg == WM_NCXBUTTONUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(GET_XBUTTON_WPARAM(wParam), GET_NCHITTEST_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = (LRESULT)TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnNCXButtonDblClk(int fwButton, short nHittest, CPoint ptPos)
+#define MSG_WM_NCXBUTTONDBLCLK(func) \
+ if (uMsg == WM_NCXBUTTONDBLCLK) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(GET_XBUTTON_WPARAM(wParam), GET_NCHITTEST_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = (LRESULT)TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnXButtonDown(int fwButton, int dwKeys, CPoint ptPos)
+#define MSG_WM_XBUTTONDOWN(func) \
+ if (uMsg == WM_XBUTTONDOWN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(GET_XBUTTON_WPARAM(wParam), GET_KEYSTATE_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = (LRESULT)TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnXButtonUp(int fwButton, int dwKeys, CPoint ptPos)
+#define MSG_WM_XBUTTONUP(func) \
+ if (uMsg == WM_XBUTTONUP) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(GET_XBUTTON_WPARAM(wParam), GET_KEYSTATE_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = (LRESULT)TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnXButtonDblClk(int fwButton, int dwKeys, CPoint ptPos)
+#define MSG_WM_XBUTTONDBLCLK(func) \
+ if (uMsg == WM_XBUTTONDBLCLK) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(GET_XBUTTON_WPARAM(wParam), GET_KEYSTATE_WPARAM(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ lResult = (LRESULT)TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnChangeUIState(WORD nAction, WORD nState)
+#define MSG_WM_CHANGEUISTATE(func) \
+ if (uMsg == WM_CHANGEUISTATE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(LOWORD(wParam), HIWORD(wParam)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnUpdateUIState(WORD nAction, WORD nState)
+#define MSG_WM_UPDATEUISTATE(func) \
+ if (uMsg == WM_UPDATEUISTATE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(LOWORD(wParam), HIWORD(wParam)); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnQueryUIState()
+#define MSG_WM_QUERYUISTATE(func) \
+ if (uMsg == WM_QUERYUISTATE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func(); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnInput(WPARAM RawInputCode, HRAWINPUT hRawInput)
+#define MSG_WM_INPUT(func) \
+ if (uMsg == WM_INPUT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(GET_RAWINPUT_CODE_WPARAM(wParam), (HRAWINPUT)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnUniChar(TCHAR nChar, UINT nRepCnt, UINT nFlags)
+#define MSG_WM_UNICHAR(func) \
+ if (uMsg == WM_UNICHAR) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((TCHAR)wParam, (UINT)lParam & 0xFFFF, (UINT)((lParam & 0xFFFF0000) >> 16)); \
+ if(this->IsMsgHandled()) \
+ { \
+ lResult = (wParam == UNICODE_NOCHAR) ? TRUE : FALSE; \
+ return TRUE; \
+ } \
+ }
+
+// void OnWTSSessionChange(WPARAM nStatusCode, DWORD dwSessionID)
+#define MSG_WM_WTSSESSION_CHANGE(func) \
+ if (uMsg == WM_WTSSESSION_CHANGE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(wParam, (DWORD)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnThemeChanged()
+#define MSG_WM_THEMECHANGED(func) \
+ if (uMsg == WM_THEMECHANGED) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+#if (_WIN32_WINNT >= 0x0600)
+
+// BOOL OnMouseHWheel(UINT nFlags, short zDelta, CPoint pt)
+#define MSG_WM_MOUSEHWHEEL(func) \
+ if (uMsg == WM_MOUSEHWHEEL) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)LOWORD(wParam), (short)HIWORD(wParam), ::CPoint(GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam))); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+#endif // (_WIN32_WINNT >= 0x0600)
+
+#if (WINVER >= 0x0601)
+
+// void OnGesture(ULONGLONG ullArguments, HGESTUREINFO hGestureInfo)
+#define MSG_WM_GESTURE(func) \
+ if (uMsg == WM_GESTURE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((ULONGLONG)wParam, (HGESTUREINFO)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnGestureNotify(PGESTURENOTIFYSTRUCT pGestureNotifyStruct)
+#define MSG_WM_GESTURENOTIFY(func) \
+ if (uMsg == WM_GESTURENOTIFY) \
+ { \
+ func((PGESTURENOTIFYSTRUCT)lParam); \
+ }
+
+// void OnDpiChanged(UINT nDpiX, UINT nDpiY, PRECT pRect)
+#define MSG_WM_DPICHANGED(func) \
+ if (uMsg == WM_DPICHANGED) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (PRECT)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+#endif // (WINVER >= 0x0601)
+
+#if (WINVER >= 0x0605)
+
+// void OnDpiChangedBeforeParent()
+#define MSG_WM_DPICHANGED_BEFOREPARENT(func) \
+ if (uMsg == WM_DPICHANGED_BEFOREPARENT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDpiChangedAfterParent()
+#define MSG_WM_DPICHANGED_AFTERPARENT(func) \
+ if (uMsg == WM_DPICHANGED_AFTERPARENT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// BOOL OnGetDpiScaledSize(UINT uDpi, PSIZE pSize)
+#define MSG_WM_GETDPISCALEDSIZE(func) \
+if (uMsg == WM_GETDPISCALEDSIZE) \
+{ \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)wParam, (PSIZE)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+}
+
+#endif // (WINVER >= 0x0605)
+
+///////////////////////////////////////////////////////////////////////////////
+// ATL defined messages
+
+// BOOL OnForwardMsg(LPMSG Msg, DWORD nUserData)
+#define MSG_WM_FORWARDMSG(func) \
+ if (uMsg == WM_FORWARDMSG) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((LPMSG)lParam, (DWORD)wParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+///////////////////////////////////////////////////////////////////////////////
+// Dialog specific messages
+
+// LRESULT OnDMGetDefID()
+#define MSG_DM_GETDEFID(func) \
+ if (uMsg == DM_GETDEFID) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func(); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDMSetDefID(UINT DefID)
+#define MSG_DM_SETDEFID(func) \
+ if (uMsg == DM_SETDEFID) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnDMReposition()
+#define MSG_DM_REPOSITION(func) \
+ if (uMsg == DM_REPOSITION) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+///////////////////////////////////////////////////////////////////////////////
+// Reflected messages
+
+// void OnReflectedCommand(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define MSG_OCM_COMMAND(func) \
+ if (uMsg == OCM_COMMAND) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnReflectedNotify(int idCtrl, LPNMHDR pnmh)
+#define MSG_OCM_NOTIFY(func) \
+ if (uMsg == OCM_NOTIFY) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((int)wParam, (LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnReflectedParentNotify(UINT message, UINT nChildID, LPARAM lParam)
+#define MSG_OCM_PARENTNOTIFY(func) \
+ if (uMsg == OCM_PARENTNOTIFY) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnReflectedDrawItem(int nIDCtl, LPDRAWITEMSTRUCT lpDrawItemStruct)
+#define MSG_OCM_DRAWITEM(func) \
+ if (uMsg == OCM_DRAWITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPDRAWITEMSTRUCT)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnReflectedMeasureItem(int nIDCtl, LPMEASUREITEMSTRUCT lpMeasureItemStruct)
+#define MSG_OCM_MEASUREITEM(func) \
+ if (uMsg == OCM_MEASUREITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPMEASUREITEMSTRUCT)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// int OnReflectedCompareItem(int nIDCtl, LPCOMPAREITEMSTRUCT lpCompareItemStruct)
+#define MSG_OCM_COMPAREITEM(func) \
+ if (uMsg == OCM_COMPAREITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)wParam, (LPCOMPAREITEMSTRUCT)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnReflectedDeleteItem(int nIDCtl, LPDELETEITEMSTRUCT lpDeleteItemStruct)
+#define MSG_OCM_DELETEITEM(func) \
+ if (uMsg == OCM_DELETEITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)wParam, (LPDELETEITEMSTRUCT)lParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// int OnReflectedVKeyToItem(UINT nKey, UINT nIndex, CListBox listBox)
+#define MSG_OCM_VKEYTOITEM(func) \
+ if (uMsg == OCM_VKEYTOITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+//int OnReflectedCharToItem(UINT nChar, UINT nIndex, CListBox listBox)
+#define MSG_OCM_CHARTOITEM(func) \
+ if (uMsg == OCM_CHARTOITEM) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((UINT)LOWORD(wParam), (UINT)HIWORD(wParam), (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnReflectedHScroll(UINT nSBCode, UINT nPos, CScrollBar pScrollBar)
+#define MSG_OCM_HSCROLL(func) \
+ if (uMsg == OCM_HSCROLL) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((int)LOWORD(wParam), (short)HIWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnReflectedVScroll(UINT nSBCode, UINT nPos, CScrollBar pScrollBar)
+#define MSG_OCM_VSCROLL(func) \
+ if (uMsg == OCM_VSCROLL) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((int)LOWORD(wParam), (short)HIWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnReflectedCtlColorEdit(CDCHandle dc, CEdit edit)
+#define MSG_OCM_CTLCOLOREDIT(func) \
+ if (uMsg == OCM_CTLCOLOREDIT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnReflectedCtlColorListBox(CDCHandle dc, CListBox listBox)
+#define MSG_OCM_CTLCOLORLISTBOX(func) \
+ if (uMsg == OCM_CTLCOLORLISTBOX) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnReflectedCtlColorBtn(CDCHandle dc, CButton button)
+#define MSG_OCM_CTLCOLORBTN(func) \
+ if (uMsg == OCM_CTLCOLORBTN) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnReflectedCtlColorDlg(CDCHandle dc, CWindow wnd)
+#define MSG_OCM_CTLCOLORDLG(func) \
+ if (uMsg == OCM_CTLCOLORDLG) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnReflectedCtlColorScrollBar(CDCHandle dc, CScrollBar scrollBar)
+#define MSG_OCM_CTLCOLORSCROLLBAR(func) \
+ if (uMsg == OCM_CTLCOLORSCROLLBAR) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// HBRUSH OnReflectedCtlColorStatic(CDCHandle dc, CStatic wndStatic)
+#define MSG_OCM_CTLCOLORSTATIC(func) \
+ if (uMsg == OCM_CTLCOLORSTATIC) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = (LRESULT)func((HDC)wParam, (HWND)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+///////////////////////////////////////////////////////////////////////////////
+// Edit specific messages
+
+// void OnClear()
+#define MSG_WM_CLEAR(func) \
+ if (uMsg == WM_CLEAR) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnCopy()
+#define MSG_WM_COPY(func) \
+ if (uMsg == WM_COPY) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnCut()
+#define MSG_WM_CUT(func) \
+ if (uMsg == WM_CUT) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnPaste()
+#define MSG_WM_PASTE(func) \
+ if (uMsg == WM_PASTE) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnUndo()
+#define MSG_WM_UNDO(func) \
+ if (uMsg == WM_UNDO) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+///////////////////////////////////////////////////////////////////////////////
+// Generic message handlers
+
+// LRESULT OnMessageHandlerEX(UINT uMsg, WPARAM wParam, LPARAM lParam)
+#define MESSAGE_HANDLER_EX(msg, func) \
+ if(uMsg == msg) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func(uMsg, wParam, lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnMessageRangeHandlerEX(UINT uMsg, WPARAM wParam, LPARAM lParam)
+#define MESSAGE_RANGE_HANDLER_EX(msgFirst, msgLast, func) \
+ if((uMsg >= msgFirst) && (uMsg <= msgLast)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func(uMsg, wParam, lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+///////////////////////////////////////////////////////////////////////////////
+// Commands and notifications
+
+// void OnCommandHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define COMMAND_HANDLER_EX(id, code, func) \
+ if ((uMsg == WM_COMMAND) && (code == HIWORD(wParam)) && (id == LOWORD(wParam))) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnCommandIDHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define COMMAND_ID_HANDLER_EX(id, func) \
+ if ((uMsg == WM_COMMAND) && (id == LOWORD(wParam))) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnCommandCodeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define COMMAND_CODE_HANDLER_EX(code, func) \
+ if ((uMsg == WM_COMMAND) && (code == HIWORD(wParam))) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnNotifyHandlerEX(LPNMHDR pnmh)
+#define NOTIFY_HANDLER_EX(id, cd, func) \
+ if ((uMsg == WM_NOTIFY) && (cd == ((LPNMHDR)lParam)->code) && (id == ((LPNMHDR)lParam)->idFrom)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnNotifyIDHandlerEX(LPNMHDR pnmh)
+#define NOTIFY_ID_HANDLER_EX(id, func) \
+ if ((uMsg == WM_NOTIFY) && (id == ((LPNMHDR)lParam)->idFrom)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnNotifyCodeHandlerEX(LPNMHDR pnmh)
+#define NOTIFY_CODE_HANDLER_EX(cd, func) \
+ if ((uMsg == WM_NOTIFY) && (cd == ((LPNMHDR)lParam)->code)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnCommandRangeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define COMMAND_RANGE_HANDLER_EX(idFirst, idLast, func) \
+ if((uMsg == WM_COMMAND) && (LOWORD(wParam) >= idFirst) && (LOWORD(wParam) <= idLast)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnCommandRangeCodeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define COMMAND_RANGE_CODE_HANDLER_EX(idFirst, idLast, code, func) \
+ if((uMsg == WM_COMMAND) && (code == HIWORD(wParam)) && (LOWORD(wParam) >= idFirst) && (LOWORD(wParam) <= idLast)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnNotifyRangeHandlerEX(LPNMHDR pnmh)
+#define NOTIFY_RANGE_HANDLER_EX(idFirst, idLast, func) \
+ if((uMsg == WM_NOTIFY) && (((LPNMHDR)lParam)->idFrom >= idFirst) && (((LPNMHDR)lParam)->idFrom <= idLast)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnNotifyRangeCodeHandlerEX(LPNMHDR pnmh)
+#define NOTIFY_RANGE_CODE_HANDLER_EX(idFirst, idLast, cd, func) \
+ if((uMsg == WM_NOTIFY) && (cd == ((LPNMHDR)lParam)->code) && (((LPNMHDR)lParam)->idFrom >= idFirst) && (((LPNMHDR)lParam)->idFrom <= idLast)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnReflectedCommandHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define REFLECTED_COMMAND_HANDLER_EX(id, code, func) \
+ if ((uMsg == OCM_COMMAND) && (code == HIWORD(wParam)) && (id == LOWORD(wParam))) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnReflectedCommandIDHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define REFLECTED_COMMAND_ID_HANDLER_EX(id, func) \
+ if ((uMsg == OCM_COMMAND) && (id == LOWORD(wParam))) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnReflectedCommandCodeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define REFLECTED_COMMAND_CODE_HANDLER_EX(code, func) \
+ if ((uMsg == OCM_COMMAND) && (code == HIWORD(wParam))) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnReflectedNotifyHandlerEX(LPNMHDR pnmh)
+#define REFLECTED_NOTIFY_HANDLER_EX(id, cd, func) \
+ if ((uMsg == OCM_NOTIFY) && (cd == ((LPNMHDR)lParam)->code) && (id == ((LPNMHDR)lParam)->idFrom)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnReflectedNotifyIDHandlerEX(LPNMHDR pnmh)
+#define REFLECTED_NOTIFY_ID_HANDLER_EX(id, func) \
+ if ((uMsg == OCM_NOTIFY) && (id == ((LPNMHDR)lParam)->idFrom)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnReflectedNotifyCodeHandlerEX(LPNMHDR pnmh)
+#define REFLECTED_NOTIFY_CODE_HANDLER_EX(cd, func) \
+ if ((uMsg == OCM_NOTIFY) && (cd == ((LPNMHDR)lParam)->code)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnReflectedCommandRangeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define REFLECTED_COMMAND_RANGE_HANDLER_EX(idFirst, idLast, func) \
+ if((uMsg == OCM_COMMAND) && (LOWORD(wParam) >= idFirst) && (LOWORD(wParam) <= idLast)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnReflectedCommandRangeCodeHandlerEX(UINT uNotifyCode, int nID, CWindow wndCtl)
+#define REFLECTED_COMMAND_RANGE_CODE_HANDLER_EX(idFirst, idLast, code, func) \
+ if((uMsg == OCM_COMMAND) && (code == HIWORD(wParam)) && (LOWORD(wParam) >= idFirst) && (LOWORD(wParam) <= idLast)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func((UINT)HIWORD(wParam), (int)LOWORD(wParam), (HWND)lParam); \
+ lResult = 0; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnReflectedNotifyRangeHandlerEX(LPNMHDR pnmh)
+#define REFLECTED_NOTIFY_RANGE_HANDLER_EX(idFirst, idLast, func) \
+ if((uMsg == OCM_NOTIFY) && (((LPNMHDR)lParam)->idFrom >= idFirst) && (((LPNMHDR)lParam)->idFrom <= idLast)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// LRESULT OnReflectedNotifyRangeCodeHandlerEX(LPNMHDR pnmh)
+#define REFLECTED_NOTIFY_RANGE_CODE_HANDLER_EX(idFirst, idLast, cd, func) \
+ if((uMsg == OCM_NOTIFY) && (cd == ((LPNMHDR)lParam)->code) && (((LPNMHDR)lParam)->idFrom >= idFirst) && (((LPNMHDR)lParam)->idFrom <= idLast)) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ lResult = func((LPNMHDR)lParam); \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+// void OnAppCommandHandler(UINT uDevice, DWORD dwKeys, CWindow wndFocus)
+#define APPCOMMAND_HANDLER_EX(cmd, func) \
+ if((uMsg == WM_APPCOMMAND) && (cmd == GET_APPCOMMAND_LPARAM(lParam))) \
+ { \
+ this->SetMsgHandled(TRUE); \
+ func(GET_DEVICE_LPARAM(lParam), GET_KEYSTATE_LPARAM(lParam), (HWND)wParam); \
+ lResult = TRUE; \
+ if(this->IsMsgHandled()) \
+ return TRUE; \
+ }
+
+#endif // __ATLCRACK_H__
diff --git a/Examples/WhisperDesktop/Utils/WTL/atlctrls.h b/Examples/WhisperDesktop/Utils/WTL/atlctrls.h
new file mode 100644
index 0000000..61df427
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/WTL/atlctrls.h
@@ -0,0 +1,9764 @@
+// Windows Template Library - WTL version 10.0
+// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved.
+//
+// This file is a part of the Windows Template Library.
+// The use and distribution terms for this software are covered by the
+// Microsoft Public License (http://opensource.org/licenses/MS-PL)
+// which can be found in the file MS-PL.txt at the root folder.
+
+#ifndef __ATLCTRLS_H__
+#define __ATLCTRLS_H__
+
+#pragma once
+
+#ifndef __ATLAPP_H__
+ #error atlctrls.h requires atlapp.h to be included first
+#endif
+
+#ifndef __ATLWIN_H__
+ #error atlctrls.h requires atlwin.h to be included first
+#endif
+
+#include <richedit.h>
+#include <richole.h>
+
+#if (_RICHEDIT_VER < 0x0300)
+ #error WTL10 requires _RICHEDIT_VER >= 0x0300
+#endif
+
+// protect template members from windowsx.h macros
+#ifdef _INC_WINDOWSX
+ #undef GetNextSibling
+ #undef GetPrevSibling
+#endif // _INC_WINDOWSX
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Classes in this file:
+//
+// CStaticT<TBase> - CStatic
+// CButtonT<TBase> - CButton
+// CListBoxT<TBase> - CListBox
+// CComboBoxT<TBase> - CComboBox
+// CEditT<TBase> - CEdit
+// CEditCommands<T>
+// CScrollBarT<TBase> - CScrollBar
+//
+// CImageListT<t_bManaged> - CImageList, CImageListManaged
+// CListViewCtrlT<TBase> - CListViewCtrl
+// CTreeViewCtrlT<TBase> - CTreeViewCtrl
+// CTreeItemT<TBase> - CTreeItem
+// CTreeViewCtrlExT<TBase> - CTreeViewCtrlEx
+// CHeaderCtrlT<TBase> - CHeaderCtrl
+// CToolBarCtrlT<TBase> - CToolBarCtrl
+// CStatusBarCtrlT<TBase> - CStatusBarCtrl
+// CTabCtrlT<TBase> - CTabCtrl
+// CToolInfo
+// CToolTipCtrlT<TBase> - CToolTipCtrl
+// CTrackBarCtrlT<TBase> - CTrackBarCtrl
+// CUpDownCtrlT<TBase> - CUpDownCtrl
+// CProgressBarCtrlT<TBase> - CProgressBarCtrl
+// CHotKeyCtrlT<TBase> - CHotKeyCtrl
+// CAnimateCtrlT<TBase> - CAnimateCtrl
+// CRichEditCtrlT<TBase> - CRichEditCtrl
+// CRichEditCommands<T>
+// CDragListBoxT<TBase> - CDragListBox
+// CDragListNotifyImpl<T>
+// CReBarCtrlT<TBase> - CReBarCtrl
+// CComboBoxExT<TBase> - CComboBoxEx
+// CDateTimePickerCtrlT<TBase> - CDateTimePickerCtrl
+// CMonthCalendarCtrlT<TBase> - CMonthCalendarCtrl
+// CFlatScrollBarImpl<T>
+// CFlatScrollBarT<TBase> - CFlatScrollBar
+// CIPAddressCtrlT<TBase> - CIPAddressCtrl
+// CPagerCtrlT<TBase> - CPagerCtrl
+// CLinkCtrlT<TBase> - CLinkCtrl
+//
+// CCustomDraw<T>
+
+
+namespace WTL
+{
+
+// These are wrapper classes for Windows standard and common controls.
+// To implement a window based on a control, use following:
+// Example: Implementing a window based on a list box
+//
+// class CMyListBox : CWindowImpl<CMyListBox, CListBox>
+// {
+// public:
+// BEGIN_MSG_MAP(CMyListBox)
+// // put your message handler entries here
+// END_MSG_MAP()
+// };
+
+
+
+// --- Standard Windows controls ---
+
+///////////////////////////////////////////////////////////////////////////////
+// CStatic - client side for a Windows STATIC control
+
+template <class TBase>
+class CStaticT : public TBase
+{
+public:
+// Constructors
+ CStaticT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CStaticT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return _T("STATIC");
+ }
+
+ HICON GetIcon() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HICON)::SendMessage(this->m_hWnd, STM_GETICON, 0, 0L);
+ }
+
+ HICON SetIcon(HICON hIcon)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HICON)::SendMessage(this->m_hWnd, STM_SETICON, (WPARAM)hIcon, 0L);
+ }
+
+ HENHMETAFILE GetEnhMetaFile() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HENHMETAFILE)::SendMessage(this->m_hWnd, STM_GETIMAGE, IMAGE_ENHMETAFILE, 0L);
+ }
+
+ HENHMETAFILE SetEnhMetaFile(HENHMETAFILE hMetaFile)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HENHMETAFILE)::SendMessage(this->m_hWnd, STM_SETIMAGE, IMAGE_ENHMETAFILE, (LPARAM)hMetaFile);
+ }
+
+ CBitmapHandle GetBitmap() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CBitmapHandle((HBITMAP)::SendMessage(this->m_hWnd, STM_GETIMAGE, IMAGE_BITMAP, 0L));
+ }
+
+ CBitmapHandle SetBitmap(HBITMAP hBitmap)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CBitmapHandle((HBITMAP)::SendMessage(this->m_hWnd, STM_SETIMAGE, IMAGE_BITMAP, (LPARAM)hBitmap));
+ }
+
+ HCURSOR GetCursor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HCURSOR)::SendMessage(this->m_hWnd, STM_GETIMAGE, IMAGE_CURSOR, 0L);
+ }
+
+ HCURSOR SetCursor(HCURSOR hCursor)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HCURSOR)::SendMessage(this->m_hWnd, STM_SETIMAGE, IMAGE_CURSOR, (LPARAM)hCursor);
+ }
+};
+
+typedef CStaticT<ATL::CWindow> CStatic;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CButton - client side for a Windows BUTTON control
+
+template <class TBase>
+class CButtonT : public TBase
+{
+public:
+// Constructors
+ CButtonT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CButtonT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return _T("BUTTON");
+ }
+
+ UINT GetState() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, BM_GETSTATE, 0, 0L);
+ }
+
+ void SetState(BOOL bHighlight)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, BM_SETSTATE, bHighlight, 0L);
+ }
+
+ int GetCheck() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, BM_GETCHECK, 0, 0L);
+ }
+
+ void SetCheck(int nCheck)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, BM_SETCHECK, nCheck, 0L);
+ }
+
+ UINT GetButtonStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::GetWindowLong(this->m_hWnd, GWL_STYLE) & 0xFFFF;
+ }
+
+ void SetButtonStyle(UINT nStyle, BOOL bRedraw = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, BM_SETSTYLE, nStyle, (LPARAM)bRedraw);
+ }
+
+ HICON GetIcon() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HICON)::SendMessage(this->m_hWnd, BM_GETIMAGE, IMAGE_ICON, 0L);
+ }
+
+ HICON SetIcon(HICON hIcon)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HICON)::SendMessage(this->m_hWnd, BM_SETIMAGE, IMAGE_ICON, (LPARAM)hIcon);
+ }
+
+ CBitmapHandle GetBitmap() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CBitmapHandle((HBITMAP)::SendMessage(this->m_hWnd, BM_GETIMAGE, IMAGE_BITMAP, 0L));
+ }
+
+ CBitmapHandle SetBitmap(HBITMAP hBitmap)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CBitmapHandle((HBITMAP)::SendMessage(this->m_hWnd, BM_SETIMAGE, IMAGE_BITMAP, (LPARAM)hBitmap));
+ }
+
+ BOOL GetIdealSize(LPSIZE lpSize) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, BCM_GETIDEALSIZE, 0, (LPARAM)lpSize);
+ }
+
+ BOOL GetImageList(PBUTTON_IMAGELIST pButtonImagelist) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, BCM_GETIMAGELIST, 0, (LPARAM)pButtonImagelist);
+ }
+
+ BOOL SetImageList(PBUTTON_IMAGELIST pButtonImagelist)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, BCM_SETIMAGELIST, 0, (LPARAM)pButtonImagelist);
+ }
+
+ BOOL GetTextMargin(LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, BCM_GETTEXTMARGIN, 0, (LPARAM)lpRect);
+ }
+
+ BOOL SetTextMargin(LPRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, BCM_SETTEXTMARGIN, 0, (LPARAM)lpRect);
+ }
+
+#if (WINVER >= 0x0600)
+ void SetDontClick(BOOL bDontClick)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, BM_SETDONTCLICK, (WPARAM)bDontClick, 0L);
+ }
+#endif // (WINVER >= 0x0600)
+
+#if (_WIN32_WINNT >= 0x0600)
+ BOOL SetDropDownState(BOOL bDropDown)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (BS_SPLITBUTTON | BS_DEFSPLITBUTTON)) != 0);
+ return (BOOL)::SendMessage(this->m_hWnd, BCM_SETDROPDOWNSTATE, (WPARAM)bDropDown, 0L);
+ }
+
+ BOOL GetSplitInfo(PBUTTON_SPLITINFO pSplitInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (BS_SPLITBUTTON | BS_DEFSPLITBUTTON)) != 0);
+ return (BOOL)::SendMessage(this->m_hWnd, BCM_GETSPLITINFO, 0, (LPARAM)pSplitInfo);
+ }
+
+ BOOL SetSplitInfo(PBUTTON_SPLITINFO pSplitInfo)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (BS_SPLITBUTTON | BS_DEFSPLITBUTTON)) != 0);
+ return (BOOL)::SendMessage(this->m_hWnd, BCM_SETSPLITINFO, 0, (LPARAM)pSplitInfo);
+ }
+
+ int GetNoteLength() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (BS_COMMANDLINK | BS_DEFCOMMANDLINK)) != 0);
+ return (int)::SendMessage(this->m_hWnd, BCM_GETNOTELENGTH, 0, 0L);
+ }
+
+ BOOL GetNote(LPWSTR lpstrNoteText, int cchNoteText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (BS_COMMANDLINK | BS_DEFCOMMANDLINK)) != 0);
+ return (BOOL)::SendMessage(this->m_hWnd, BCM_GETNOTE, cchNoteText, (LPARAM)lpstrNoteText);
+ }
+
+ BOOL SetNote(LPCWSTR lpstrNoteText)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (BS_COMMANDLINK | BS_DEFCOMMANDLINK)) != 0);
+ return (BOOL)::SendMessage(this->m_hWnd, BCM_SETNOTE, 0, (LPARAM)lpstrNoteText);
+ }
+
+ LRESULT SetElevationRequiredState(BOOL bSet)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::SendMessage(this->m_hWnd, BCM_SETSHIELD, 0, (LPARAM)bSet);
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+
+// Operations
+ void Click()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, BM_CLICK, 0, 0L);
+ }
+};
+
+typedef CButtonT<ATL::CWindow> CButton;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CListBox - client side for a Windows LISTBOX control
+
+template <class TBase>
+class CListBoxT : public TBase
+{
+public:
+// Constructors
+ CListBoxT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CListBoxT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return _T("LISTBOX");
+ }
+
+ // for entire listbox
+ int GetCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_GETCOUNT, 0, 0L);
+ }
+
+ int SetCount(int cItems)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(((this->GetStyle() & LBS_NODATA) != 0) && ((this->GetStyle() & LBS_HASSTRINGS) == 0));
+ return (int)::SendMessage(this->m_hWnd, LB_SETCOUNT, cItems, 0L);
+ }
+
+ int GetHorizontalExtent() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_GETHORIZONTALEXTENT, 0, 0L);
+ }
+
+ void SetHorizontalExtent(int cxExtent)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LB_SETHORIZONTALEXTENT, cxExtent, 0L);
+ }
+
+ int GetTopIndex() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_GETTOPINDEX, 0, 0L);
+ }
+
+ int SetTopIndex(int nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_SETTOPINDEX, nIndex, 0L);
+ }
+
+ LCID GetLocale() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (LCID)::SendMessage(this->m_hWnd, LB_GETLOCALE, 0, 0L);
+ }
+
+ LCID SetLocale(LCID nNewLocale)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (LCID)::SendMessage(this->m_hWnd, LB_SETLOCALE, (WPARAM)nNewLocale, 0L);
+ }
+
+ DWORD GetListBoxInfo() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, LB_GETLISTBOXINFO, 0, 0L);
+ }
+
+ // for single-selection listboxes
+ int GetCurSel() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) == 0);
+ return (int)::SendMessage(this->m_hWnd, LB_GETCURSEL, 0, 0L);
+ }
+
+ int SetCurSel(int nSelect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) == 0);
+ return (int)::SendMessage(this->m_hWnd, LB_SETCURSEL, nSelect, 0L);
+ }
+
+ // for multiple-selection listboxes
+ int GetSel(int nIndex) const // also works for single-selection
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_GETSEL, nIndex, 0L);
+ }
+
+ int SetSel(int nIndex, BOOL bSelect = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0);
+ return (int)::SendMessage(this->m_hWnd, LB_SETSEL, bSelect, nIndex);
+ }
+
+ int GetSelCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0);
+ return (int)::SendMessage(this->m_hWnd, LB_GETSELCOUNT, 0, 0L);
+ }
+
+ int GetSelItems(int nMaxItems, LPINT rgIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0);
+ return (int)::SendMessage(this->m_hWnd, LB_GETSELITEMS, nMaxItems, (LPARAM)rgIndex);
+ }
+
+ int GetAnchorIndex() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0);
+ return (int)::SendMessage(this->m_hWnd, LB_GETANCHORINDEX, 0, 0L);
+ }
+
+ void SetAnchorIndex(int nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0);
+ ::SendMessage(this->m_hWnd, LB_SETANCHORINDEX, nIndex, 0L);
+ }
+
+ int GetCaretIndex() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_GETCARETINDEX, 0, 0);
+ }
+
+ int SetCaretIndex(int nIndex, BOOL bScroll = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_SETCARETINDEX, nIndex, MAKELONG(bScroll, 0));
+ }
+
+ // for listbox items
+ DWORD_PTR GetItemData(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD_PTR)::SendMessage(this->m_hWnd, LB_GETITEMDATA, nIndex, 0L);
+ }
+
+ int SetItemData(int nIndex, DWORD_PTR dwItemData)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_SETITEMDATA, nIndex, (LPARAM)dwItemData);
+ }
+
+ void* GetItemDataPtr(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (void*)::SendMessage(this->m_hWnd, LB_GETITEMDATA, nIndex, 0L);
+ }
+
+ int SetItemDataPtr(int nIndex, void* pData)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return SetItemData(nIndex, (DWORD_PTR)pData);
+ }
+
+ int GetItemRect(int nIndex, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_GETITEMRECT, nIndex, (LPARAM)lpRect);
+ }
+
+ int GetText(int nIndex, LPTSTR lpszBuffer) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_GETTEXT, nIndex, (LPARAM)lpszBuffer);
+ }
+
+#ifdef _OLEAUTO_H_
+ BOOL GetTextBSTR(int nIndex, BSTR& bstrText) const
+ {
+ USES_CONVERSION;
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(bstrText == NULL);
+
+ int nLen = GetTextLen(nIndex);
+ if(nLen == LB_ERR)
+ return FALSE;
+
+ ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff;
+ LPTSTR lpstrText = buff.Allocate(nLen + 1);
+ if(lpstrText == NULL)
+ return FALSE;
+
+ if(GetText(nIndex, lpstrText) == LB_ERR)
+ return FALSE;
+
+ bstrText = ::SysAllocString(T2OLE(lpstrText));
+ return (bstrText != NULL) ? TRUE : FALSE;
+ }
+#endif // _OLEAUTO_H_
+
+#ifdef __ATLSTR_H__
+ int GetText(int nIndex, ATL::CString& strText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ int cchLen = GetTextLen(nIndex);
+ if(cchLen == LB_ERR)
+ return LB_ERR;
+ int nRet = LB_ERR;
+ LPTSTR lpstr = strText.GetBufferSetLength(cchLen);
+ if(lpstr != NULL)
+ {
+ nRet = GetText(nIndex, lpstr);
+ strText.ReleaseBuffer();
+ }
+ return nRet;
+ }
+#endif // __ATLSTR_H__
+
+ int GetTextLen(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_GETTEXTLEN, nIndex, 0L);
+ }
+
+ int GetItemHeight(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_GETITEMHEIGHT, nIndex, 0L);
+ }
+
+ int SetItemHeight(int nIndex, UINT cyItemHeight)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_SETITEMHEIGHT, nIndex, MAKELONG(cyItemHeight, 0));
+ }
+
+ // Settable only attributes
+ void SetColumnWidth(int cxWidth)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LB_SETCOLUMNWIDTH, cxWidth, 0L);
+ }
+
+ BOOL SetTabStops(int nTabStops, LPINT rgTabStops)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & LBS_USETABSTOPS) != 0);
+ return (BOOL)::SendMessage(this->m_hWnd, LB_SETTABSTOPS, nTabStops, (LPARAM)rgTabStops);
+ }
+
+ BOOL SetTabStops()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & LBS_USETABSTOPS) != 0);
+ return (BOOL)::SendMessage(this->m_hWnd, LB_SETTABSTOPS, 0, 0L);
+ }
+
+ BOOL SetTabStops(const int& cxEachStop) // takes an 'int'
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & LBS_USETABSTOPS) != 0);
+ return (BOOL)::SendMessage(this->m_hWnd, LB_SETTABSTOPS, 1, (LPARAM)(LPINT)&cxEachStop);
+ }
+
+// Operations
+ int InitStorage(int nItems, UINT nBytes)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_INITSTORAGE, (WPARAM)nItems, nBytes);
+ }
+
+ void ResetContent()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LB_RESETCONTENT, 0, 0L);
+ }
+
+ UINT ItemFromPoint(POINT pt, BOOL& bOutside) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dw = (DWORD)::SendMessage(this->m_hWnd, LB_ITEMFROMPOINT, 0, MAKELPARAM(pt.x, pt.y));
+ bOutside = (BOOL)HIWORD(dw);
+ return (UINT)LOWORD(dw);
+ }
+
+ // manipulating listbox items
+ int AddString(LPCTSTR lpszItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_ADDSTRING, 0, (LPARAM)lpszItem);
+ }
+
+ int DeleteString(UINT nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_DELETESTRING, nIndex, 0L);
+ }
+
+ int InsertString(int nIndex, LPCTSTR lpszItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_INSERTSTRING, nIndex, (LPARAM)lpszItem);
+ }
+
+ int Dir(UINT attr, LPCTSTR lpszWildCard)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_DIR, attr, (LPARAM)lpszWildCard);
+ }
+
+ int AddFile(LPCTSTR lpstrFileName)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_ADDFILE, 0, (LPARAM)lpstrFileName);
+ }
+
+ // selection helpers
+ int FindString(int nStartAfter, LPCTSTR lpszItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_FINDSTRING, nStartAfter, (LPARAM)lpszItem);
+ }
+
+ int FindStringExact(int nIndexStart, LPCTSTR lpszFind) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_FINDSTRINGEXACT, nIndexStart, (LPARAM)lpszFind);
+ }
+
+ int SelectString(int nStartAfter, LPCTSTR lpszItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LB_SELECTSTRING, nStartAfter, (LPARAM)lpszItem);
+ }
+
+ int SelItemRange(BOOL bSelect, int nFirstItem, int nLastItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) != 0);
+ ATLASSERT(nFirstItem <= nLastItem);
+ return bSelect ? (int)::SendMessage(this->m_hWnd, LB_SELITEMRANGEEX, nFirstItem, nLastItem) : (int)::SendMessage(this->m_hWnd, LB_SELITEMRANGEEX, nLastItem, nFirstItem);
+ }
+};
+
+typedef CListBoxT<ATL::CWindow> CListBox;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CComboBox - client side for a Windows COMBOBOX control
+
+template <class TBase>
+class CComboBoxT : public TBase
+{
+public:
+// Constructors
+ CComboBoxT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CComboBoxT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return _T("COMBOBOX");
+ }
+
+ // for entire combo box
+ int GetCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_GETCOUNT, 0, 0L);
+ }
+
+ int GetCurSel() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_GETCURSEL, 0, 0L);
+ }
+
+ int SetCurSel(int nSelect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_SETCURSEL, nSelect, 0L);
+ }
+
+ LCID GetLocale() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (LCID)::SendMessage(this->m_hWnd, CB_GETLOCALE, 0, 0L);
+ }
+
+ LCID SetLocale(LCID nNewLocale)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (LCID)::SendMessage(this->m_hWnd, CB_SETLOCALE, (WPARAM)nNewLocale, 0L);
+ }
+
+ int GetTopIndex() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_GETTOPINDEX, 0, 0L);
+ }
+
+ int SetTopIndex(int nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_SETTOPINDEX, nIndex, 0L);
+ }
+
+ UINT GetHorizontalExtent() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, CB_GETHORIZONTALEXTENT, 0, 0L);
+ }
+
+ void SetHorizontalExtent(UINT nExtent)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, CB_SETHORIZONTALEXTENT, nExtent, 0L);
+ }
+
+ int GetDroppedWidth() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_GETDROPPEDWIDTH, 0, 0L);
+ }
+
+ int SetDroppedWidth(UINT nWidth)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_SETDROPPEDWIDTH, nWidth, 0L);
+ }
+
+ BOOL GetComboBoxInfo(PCOMBOBOXINFO pComboBoxInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CB_GETCOMBOBOXINFO, 0, (LPARAM)pComboBoxInfo);
+ }
+
+ // for edit control
+ DWORD GetEditSel() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, CB_GETEDITSEL, 0, 0L);
+ }
+
+ BOOL SetEditSel(int nStartChar, int nEndChar)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CB_SETEDITSEL, 0, MAKELONG(nStartChar, nEndChar));
+ }
+
+ // for combobox item
+ DWORD_PTR GetItemData(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD_PTR)::SendMessage(this->m_hWnd, CB_GETITEMDATA, nIndex, 0L);
+ }
+
+ int SetItemData(int nIndex, DWORD_PTR dwItemData)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_SETITEMDATA, nIndex, (LPARAM)dwItemData);
+ }
+
+ void* GetItemDataPtr(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (void*)GetItemData(nIndex);
+ }
+
+ int SetItemDataPtr(int nIndex, void* pData)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return SetItemData(nIndex, (DWORD_PTR)pData);
+ }
+
+ int GetLBText(int nIndex, LPTSTR lpszText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_GETLBTEXT, nIndex, (LPARAM)lpszText);
+ }
+
+ BOOL GetLBTextBSTR(int nIndex, BSTR& bstrText) const
+ {
+ USES_CONVERSION;
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(bstrText == NULL);
+
+ int nLen = GetLBTextLen(nIndex);
+ if(nLen == CB_ERR)
+ return FALSE;
+
+ ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff;
+ LPTSTR lpstrText = buff.Allocate(nLen + 1);
+ if(lpstrText == NULL)
+ return FALSE;
+
+ if(GetLBText(nIndex, lpstrText) == CB_ERR)
+ return FALSE;
+
+ bstrText = ::SysAllocString(T2OLE(lpstrText));
+ return (bstrText != NULL) ? TRUE : FALSE;
+ }
+
+#ifdef __ATLSTR_H__
+ int GetLBText(int nIndex, ATL::CString& strText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ int cchLen = GetLBTextLen(nIndex);
+ if(cchLen == CB_ERR)
+ return CB_ERR;
+ int nRet = CB_ERR;
+ LPTSTR lpstr = strText.GetBufferSetLength(cchLen);
+ if(lpstr != NULL)
+ {
+ nRet = GetLBText(nIndex, lpstr);
+ strText.ReleaseBuffer();
+ }
+ return nRet;
+ }
+#endif // __ATLSTR_H__
+
+ int GetLBTextLen(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_GETLBTEXTLEN, nIndex, 0L);
+ }
+
+ int GetItemHeight(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_GETITEMHEIGHT, nIndex, 0L);
+ }
+
+ int SetItemHeight(int nIndex, UINT cyItemHeight)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_SETITEMHEIGHT, nIndex, MAKELONG(cyItemHeight, 0));
+ }
+
+ BOOL GetExtendedUI() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CB_GETEXTENDEDUI, 0, 0L);
+ }
+
+ int SetExtendedUI(BOOL bExtended = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_SETEXTENDEDUI, bExtended, 0L);
+ }
+
+ void GetDroppedControlRect(LPRECT lprect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, CB_GETDROPPEDCONTROLRECT, 0, (LPARAM)lprect);
+ }
+
+ BOOL GetDroppedState() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CB_GETDROPPEDSTATE, 0, 0L);
+ }
+
+ int GetMinVisible() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_GETMINVISIBLE, 0, 0L);
+ }
+
+ BOOL SetMinVisible(int nMinVisible)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CB_SETMINVISIBLE, nMinVisible, 0L);
+ }
+
+ // Vista only
+ BOOL GetCueBannerText(LPWSTR lpwText, int cchText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CB_GETCUEBANNER, (WPARAM)lpwText, cchText);
+ }
+
+ // Vista only
+ BOOL SetCueBannerText(LPCWSTR lpcwText)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CB_SETCUEBANNER, 0, (LPARAM)lpcwText);
+ }
+
+// Operations
+ int InitStorage(int nItems, UINT nBytes)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_INITSTORAGE, (WPARAM)nItems, nBytes);
+ }
+
+ void ResetContent()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, CB_RESETCONTENT, 0, 0L);
+ }
+
+ // for edit control
+ BOOL LimitText(int nMaxChars)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CB_LIMITTEXT, nMaxChars, 0L);
+ }
+
+ // for drop-down combo boxes
+ void ShowDropDown(BOOL bShowIt = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, CB_SHOWDROPDOWN, bShowIt, 0L);
+ }
+
+ // manipulating listbox items
+ int AddString(LPCTSTR lpszString)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_ADDSTRING, 0, (LPARAM)lpszString);
+ }
+
+ int DeleteString(UINT nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_DELETESTRING, nIndex, 0L);
+ }
+
+ int InsertString(int nIndex, LPCTSTR lpszString)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_INSERTSTRING, nIndex, (LPARAM)lpszString);
+ }
+
+ int Dir(UINT attr, LPCTSTR lpszWildCard)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_DIR, attr, (LPARAM)lpszWildCard);
+ }
+
+ // selection helpers
+ int FindString(int nStartAfter, LPCTSTR lpszString) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_FINDSTRING, nStartAfter, (LPARAM)lpszString);
+ }
+
+ int FindStringExact(int nIndexStart, LPCTSTR lpszFind) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_FINDSTRINGEXACT, nIndexStart, (LPARAM)lpszFind);
+ }
+
+ int SelectString(int nStartAfter, LPCTSTR lpszString)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CB_SELECTSTRING, nStartAfter, (LPARAM)lpszString);
+ }
+
+ // Clipboard operations
+ void Clear()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_CLEAR, 0, 0L);
+ }
+
+ void Copy()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_COPY, 0, 0L);
+ }
+
+ void Cut()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_CUT, 0, 0L);
+ }
+
+ void Paste()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_PASTE, 0, 0L);
+ }
+};
+
+typedef CComboBoxT<ATL::CWindow> CComboBox;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CEdit - client side for a Windows EDIT control
+
+template <class TBase>
+class CEditT : public TBase
+{
+public:
+// Constructors
+ CEditT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CEditT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return _T("EDIT");
+ }
+
+ BOOL CanUndo() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_CANUNDO, 0, 0L);
+ }
+
+ int GetLineCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_GETLINECOUNT, 0, 0L);
+ }
+
+ BOOL GetModify() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_GETMODIFY, 0, 0L);
+ }
+
+ void SetModify(BOOL bModified = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETMODIFY, bModified, 0L);
+ }
+
+ void GetRect(LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_GETRECT, 0, (LPARAM)lpRect);
+ }
+
+ DWORD GetSel() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETSEL, 0, 0L);
+ }
+
+ void GetSel(int& nStartChar, int& nEndChar) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_GETSEL, (WPARAM)&nStartChar, (LPARAM)&nEndChar);
+ }
+
+ HLOCAL GetHandle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HLOCAL)::SendMessage(this->m_hWnd, EM_GETHANDLE, 0, 0L);
+ }
+
+ void SetHandle(HLOCAL hBuffer)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETHANDLE, (WPARAM)hBuffer, 0L);
+ }
+
+ DWORD GetMargins() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETMARGINS, 0, 0L);
+ }
+
+ void GetMargins(UINT& nLeft, UINT& nRight) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, EM_GETMARGINS, 0, 0L);
+ nLeft = LOWORD(dwRet);
+ nRight = HIWORD(dwRet);
+ }
+
+ void SetMargins(UINT nLeft, UINT nRight, WORD wFlags = EC_LEFTMARGIN | EC_RIGHTMARGIN)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETMARGINS, wFlags, MAKELONG(nLeft, nRight));
+ }
+
+ UINT GetLimitText() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, EM_GETLIMITTEXT, 0, 0L);
+ }
+
+ void SetLimitText(UINT nMax)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETLIMITTEXT, nMax, 0L);
+ }
+
+ POINT PosFromChar(UINT nChar) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, EM_POSFROMCHAR, nChar, 0);
+ POINT point = { GET_X_LPARAM(dwRet), GET_Y_LPARAM(dwRet) };
+ return point;
+ }
+
+ int CharFromPos(POINT pt, int* pLine = NULL) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, EM_CHARFROMPOS, 0, MAKELPARAM(pt.x, pt.y));
+ if(pLine != NULL)
+ *pLine = (int)(short)HIWORD(dwRet);
+ return (int)(short)LOWORD(dwRet);
+ }
+
+ // NOTE: first word in lpszBuffer must contain the size of the buffer!
+ int GetLine(int nIndex, LPTSTR lpszBuffer) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_GETLINE, nIndex, (LPARAM)lpszBuffer);
+ }
+
+ int GetLine(int nIndex, LPTSTR lpszBuffer, int nMaxLength) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ *(LPWORD)lpszBuffer = (WORD)nMaxLength;
+ return (int)::SendMessage(this->m_hWnd, EM_GETLINE, nIndex, (LPARAM)lpszBuffer);
+ }
+
+ TCHAR GetPasswordChar() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (TCHAR)::SendMessage(this->m_hWnd, EM_GETPASSWORDCHAR, 0, 0L);
+ }
+
+ void SetPasswordChar(TCHAR ch)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETPASSWORDCHAR, ch, 0L);
+ }
+
+ EDITWORDBREAKPROC GetWordBreakProc() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (EDITWORDBREAKPROC)::SendMessage(this->m_hWnd, EM_GETWORDBREAKPROC, 0, 0L);
+ }
+
+ void SetWordBreakProc(EDITWORDBREAKPROC ewbprc)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETWORDBREAKPROC, 0, (LPARAM)ewbprc);
+ }
+
+ int GetFirstVisibleLine() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_GETFIRSTVISIBLELINE, 0, 0L);
+ }
+
+ int GetThumb() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & ES_MULTILINE) != 0);
+ return (int)::SendMessage(this->m_hWnd, EM_GETTHUMB, 0, 0L);
+ }
+
+ BOOL SetReadOnly(BOOL bReadOnly = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETREADONLY, bReadOnly, 0L);
+ }
+
+ UINT GetImeStatus(UINT uStatus) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, EM_GETIMESTATUS, uStatus, 0L);
+ }
+
+ UINT SetImeStatus(UINT uStatus, UINT uData)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, EM_SETIMESTATUS, uStatus, uData);
+ }
+
+ BOOL GetCueBannerText(LPCWSTR lpstrText, int cchText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_GETCUEBANNER, (WPARAM)lpstrText, cchText);
+ }
+
+ // bKeepWithFocus - Vista only
+ BOOL SetCueBannerText(LPCWSTR lpstrText, BOOL bKeepWithFocus = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETCUEBANNER, (WPARAM)bKeepWithFocus, (LPARAM)(lpstrText));
+ }
+
+// Operations
+ void EmptyUndoBuffer()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_EMPTYUNDOBUFFER, 0, 0L);
+ }
+
+ BOOL FmtLines(BOOL bAddEOL)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_FMTLINES, bAddEOL, 0L);
+ }
+
+ void LimitText(int nChars = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_LIMITTEXT, nChars, 0L);
+ }
+
+ int LineFromChar(int nIndex = -1) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_LINEFROMCHAR, nIndex, 0L);
+ }
+
+ int LineIndex(int nLine = -1) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_LINEINDEX, nLine, 0L);
+ }
+
+ int LineLength(int nLine = -1) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_LINELENGTH, nLine, 0L);
+ }
+
+ void LineScroll(int nLines, int nChars = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_LINESCROLL, nChars, nLines);
+ }
+
+ void ReplaceSel(LPCTSTR lpszNewText, BOOL bCanUndo = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_REPLACESEL, (WPARAM) bCanUndo, (LPARAM)lpszNewText);
+ }
+
+ void SetRect(LPCRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETRECT, 0, (LPARAM)lpRect);
+ }
+
+ void SetRectNP(LPCRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETRECTNP, 0, (LPARAM)lpRect);
+ }
+
+ void SetSel(DWORD dwSelection, BOOL bNoScroll = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETSEL, LOWORD(dwSelection), HIWORD(dwSelection));
+ if(!bNoScroll)
+ ::SendMessage(this->m_hWnd, EM_SCROLLCARET, 0, 0L);
+ }
+
+ void SetSel(int nStartChar, int nEndChar, BOOL bNoScroll = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETSEL, nStartChar, nEndChar);
+ if(!bNoScroll)
+ ::SendMessage(this->m_hWnd, EM_SCROLLCARET, 0, 0L);
+ }
+
+ void SetSelAll(BOOL bNoScroll = FALSE)
+ {
+ SetSel(0, -1, bNoScroll);
+ }
+
+ void SetSelNone(BOOL bNoScroll = FALSE)
+ {
+ SetSel(-1, 0, bNoScroll);
+ }
+
+ BOOL SetTabStops(int nTabStops, LPINT rgTabStops)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, nTabStops, (LPARAM)rgTabStops);
+ }
+
+ BOOL SetTabStops()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, 0, 0L);
+ }
+
+ BOOL SetTabStops(const int& cxEachStop) // takes an 'int'
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, 1, (LPARAM)(LPINT)&cxEachStop);
+ }
+
+ void ScrollCaret()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SCROLLCARET, 0, 0L);
+ }
+
+ int Scroll(int nScrollAction)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & ES_MULTILINE) != 0);
+ LRESULT lRet = ::SendMessage(this->m_hWnd, EM_SCROLL, nScrollAction, 0L);
+ if(!(BOOL)HIWORD(lRet))
+ return -1; // failed
+ return (int)(short)LOWORD(lRet);
+
+ }
+
+ void InsertText(int nInsertAfterChar, LPCTSTR lpstrText, BOOL bNoScroll = FALSE, BOOL bCanUndo = FALSE)
+ {
+ SetSel(nInsertAfterChar, nInsertAfterChar, bNoScroll);
+ ReplaceSel(lpstrText, bCanUndo);
+ }
+
+ void AppendText(LPCTSTR lpstrText, BOOL bNoScroll = FALSE, BOOL bCanUndo = FALSE)
+ {
+ InsertText(this->GetWindowTextLength(), lpstrText, bNoScroll, bCanUndo);
+ }
+
+ BOOL ShowBalloonTip(PEDITBALLOONTIP pEditBaloonTip)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SHOWBALLOONTIP, 0, (LPARAM)pEditBaloonTip);
+ }
+
+ BOOL HideBalloonTip()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_HIDEBALLOONTIP, 0, 0L);
+ }
+
+#if (_WIN32_WINNT >= 0x0600)
+ DWORD GetHilite() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETHILITE, 0, 0L);
+ }
+
+ void GetHilite(int& nStartChar, int& nEndChar) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, EM_GETHILITE, 0, 0L);
+ nStartChar = (int)(short)LOWORD(dwRet);
+ nEndChar = (int)(short)HIWORD(dwRet);
+ }
+
+ void SetHilite(int nStartChar, int nEndChar)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETHILITE, nStartChar, nEndChar);
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+
+ // Clipboard operations
+ BOOL Undo()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_UNDO, 0, 0L);
+ }
+
+ void Clear()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_CLEAR, 0, 0L);
+ }
+
+ void Copy()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_COPY, 0, 0L);
+ }
+
+ void Cut()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_CUT, 0, 0L);
+ }
+
+ void Paste()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_PASTE, 0, 0L);
+ }
+
+ // New messages added in Windows 10.0.17763
+#if defined(NTDDI_VERSION) && defined(NTDDI_WIN10_RS5) && (NTDDI_VERSION >= NTDDI_WIN10_RS5)
+ DWORD SetExtendedStyle(DWORD dwStyle, DWORD dwMask)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::SendMessage(this->m_hWnd, EM_SETEXTENDEDSTYLE, dwMask, dwStyle);
+ }
+
+ DWORD GetExtendedStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::SendMessage(this->m_hWnd, EM_GETEXTENDEDSTYLE, 0, 0L);
+ }
+
+ BOOL SetEndOfLine(EC_ENDOFLINE eolType)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETENDOFLINE, eolType, 0L);
+ }
+
+ EC_ENDOFLINE GetEndOfLine() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (EC_ENDOFLINE)::SendMessage(this->m_hWnd, EM_GETENDOFLINE, 0, 0L);
+ }
+
+ BOOL EnableSearchWeb(BOOL bEnable)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_ENABLESEARCHWEB, (WPARAM)bEnable, 0L);
+ }
+
+ void SearchWeb()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SEARCHWEB, 0, 0L);
+ }
+
+ BOOL SetCaretIndex(DWORD dwCaretIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETCARETINDEX, dwCaretIndex, 0L);
+ }
+
+ DWORD GetCaretIndex() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::SendMessage(this->m_hWnd, EM_GETCARETINDEX, 0, 0L);
+ }
+
+ BOOL GetZoom(int& nNum, int& nDen) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_GETZOOM, (WPARAM)&nNum, (LPARAM)&nDen);
+ }
+
+ BOOL SetZoom(int nNum, int nDen)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((nNum >= 0) && (nNum <= 64));
+ ATLASSERT((nDen >= 0) && (nDen <= 64));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETZOOM, nNum, nDen);
+ }
+
+ DWORD GetFileLineFromChar(DWORD dwCharIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::SendMessage(this->m_hWnd, EM_FILELINEFROMCHAR, dwCharIndex, 0L);
+ }
+
+ DWORD GetFileLineIndex(DWORD dwLineNum) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::SendMessage(this->m_hWnd, EM_FILELINEINDEX, dwLineNum, 0L);
+ }
+
+ DWORD GetFileLineLength(DWORD dwCharIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::SendMessage(this->m_hWnd, EM_FILELINELENGTH, dwCharIndex, 0L);
+ }
+
+ DWORD GetFileLine(DWORD dwLineNum, LPTSTR lpstrLine, WORD wLen) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ WORD* pw = (WORD*)lpstrLine;
+ *pw = wLen;
+ return ::SendMessage(this->m_hWnd, EM_GETFILELINE, dwLineNum, (LPARAM)lpstrLine);
+ }
+
+#ifdef __ATLSTR_H__
+ ATL::CString GetFileLine(DWORD dwLineNum) const
+ {
+ ATL::CString strLine;
+ DWORD dwCharIndex = GetFileLineIndex(dwLineNum);
+ if(dwCharIndex != (DWORD)-1)
+ {
+ DWORD dwLen = GetFileLineLength(dwCharIndex);
+ if(dwLen > 0)
+ {
+ LPTSTR lpstrLine = strLine.GetBufferSetLength(dwLen);
+ ATLVERIFY(GetFileLine(dwLineNum, lpstrLine, (WORD)dwLen) == dwLen);
+ strLine.ReleaseBuffer();
+ }
+ }
+
+ return strLine;
+ }
+#endif // __ATLSTR_H__
+
+ DWORD GetFileLineCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::SendMessage(this->m_hWnd, EM_GETFILELINECOUNT, 0, 0L);
+ }
+#endif // defined(NTDDI_VERSION) && defined(NTDDI_WIN10_RS5) && (NTDDI_VERSION >= NTDDI_WIN10_RS5)
+};
+
+typedef CEditT<ATL::CWindow> CEdit;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CEditCommands - message handlers for standard EDIT commands
+
+// Chain to CEditCommands message map. Your class must also derive from CEdit.
+// Example:
+// class CMyEdit : public CWindowImpl<CMyEdit, CEdit>,
+// public CEditCommands<CMyEdit>
+// {
+// public:
+// BEGIN_MSG_MAP(CMyEdit)
+// // your handlers...
+// CHAIN_MSG_MAP_ALT(CEditCommands<CMyEdit>, 1)
+// END_MSG_MAP()
+// // other stuff...
+// };
+
+template <class T>
+class CEditCommands
+{
+public:
+ BEGIN_MSG_MAP(CEditCommands< T >)
+ ALT_MSG_MAP(1)
+ COMMAND_ID_HANDLER(ID_EDIT_CLEAR, OnEditClear)
+ COMMAND_ID_HANDLER(ID_EDIT_CLEAR_ALL, OnEditClearAll)
+ COMMAND_ID_HANDLER(ID_EDIT_COPY, OnEditCopy)
+ COMMAND_ID_HANDLER(ID_EDIT_CUT, OnEditCut)
+ COMMAND_ID_HANDLER(ID_EDIT_PASTE, OnEditPaste)
+ COMMAND_ID_HANDLER(ID_EDIT_SELECT_ALL, OnEditSelectAll)
+ COMMAND_ID_HANDLER(ID_EDIT_UNDO, OnEditUndo)
+ END_MSG_MAP()
+
+ LRESULT OnEditClear(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/)
+ {
+ T* pT = static_cast<T*>(this);
+ pT->Clear();
+ return 0;
+ }
+
+ LRESULT OnEditClearAll(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/)
+ {
+ T* pT = static_cast<T*>(this);
+ pT->SetSel(0, -1);
+ pT->Clear();
+ return 0;
+ }
+
+ LRESULT OnEditCopy(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/)
+ {
+ T* pT = static_cast<T*>(this);
+ pT->Copy();
+ return 0;
+ }
+
+ LRESULT OnEditCut(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/)
+ {
+ T* pT = static_cast<T*>(this);
+ pT->Cut();
+ return 0;
+ }
+
+ LRESULT OnEditPaste(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/)
+ {
+ T* pT = static_cast<T*>(this);
+ pT->Paste();
+ return 0;
+ }
+
+ LRESULT OnEditSelectAll(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/)
+ {
+ T* pT = static_cast<T*>(this);
+ pT->SetSel(0, -1);
+ return 0;
+ }
+
+ LRESULT OnEditUndo(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/)
+ {
+ T* pT = static_cast<T*>(this);
+ pT->Undo();
+ return 0;
+ }
+
+// State (update UI) helpers
+ BOOL CanCut() const
+ { return HasSelection(); }
+
+ BOOL CanCopy() const
+ { return HasSelection(); }
+
+ BOOL CanClear() const
+ { return HasSelection(); }
+
+ BOOL CanSelectAll() const
+ { return HasText(); }
+
+ BOOL CanFind() const
+ { return HasText(); }
+
+ BOOL CanRepeat() const
+ { return HasText(); }
+
+ BOOL CanReplace() const
+ { return HasText(); }
+
+ BOOL CanClearAll() const
+ { return HasText(); }
+
+// Implementation
+ BOOL HasSelection() const
+ {
+ const T* pT = static_cast<const T*>(this);
+ int nMin = 0, nMax = 0;
+ ::SendMessage(pT->m_hWnd, EM_GETSEL, (WPARAM)&nMin, (LPARAM)&nMax);
+ return (nMin != nMax);
+ }
+
+ BOOL HasText() const
+ {
+ const T* pT = static_cast<const T*>(this);
+ return (pT->GetWindowTextLength() > 0);
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CScrollBar - client side for a Windows SCROLLBAR control
+
+template <class TBase>
+class CScrollBarT : public TBase
+{
+public:
+// Constructors
+ CScrollBarT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CScrollBarT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return _T("SCROLLBAR");
+ }
+
+ int GetScrollPos() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::GetScrollPos(this->m_hWnd, SB_CTL);
+ }
+
+ int SetScrollPos(int nPos, BOOL bRedraw = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::SetScrollPos(this->m_hWnd, SB_CTL, nPos, bRedraw);
+ }
+
+ void GetScrollRange(LPINT lpMinPos, LPINT lpMaxPos) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::GetScrollRange(this->m_hWnd, SB_CTL, lpMinPos, lpMaxPos);
+ }
+
+ void SetScrollRange(int nMinPos, int nMaxPos, BOOL bRedraw = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SetScrollRange(this->m_hWnd, SB_CTL, nMinPos, nMaxPos, bRedraw);
+ }
+
+ BOOL GetScrollInfo(LPSCROLLINFO lpScrollInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::GetScrollInfo(this->m_hWnd, SB_CTL, lpScrollInfo);
+ }
+
+ int SetScrollInfo(LPSCROLLINFO lpScrollInfo, BOOL bRedraw = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::SetScrollInfo(this->m_hWnd, SB_CTL, lpScrollInfo, bRedraw);
+ }
+
+ int GetScrollLimit() const
+ {
+ SCROLLINFO info = { sizeof(SCROLLINFO), SIF_RANGE | SIF_PAGE };
+ ::GetScrollInfo(this->m_hWnd, SB_CTL, &info);
+ if(info.nPage > 1)
+ info.nMax -= info.nPage - 1;
+
+ return info.nMax;
+ }
+
+ BOOL GetScrollBarInfo(PSCROLLBARINFO pScrollBarInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, SBM_GETSCROLLBARINFO, 0, (LPARAM)pScrollBarInfo);
+ }
+
+// Operations
+ void ShowScrollBar(BOOL bShow = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::ShowScrollBar(this->m_hWnd, SB_CTL, bShow);
+ }
+
+ BOOL EnableScrollBar(UINT nArrowFlags = ESB_ENABLE_BOTH)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::EnableScrollBar(this->m_hWnd, SB_CTL, nArrowFlags);
+ }
+};
+
+typedef CScrollBarT<ATL::CWindow> CScrollBar;
+
+
+// --- Windows Common Controls ---
+
+///////////////////////////////////////////////////////////////////////////////
+// CImageList
+
+// forward declarations
+template <bool t_bManaged> class CImageListT;
+typedef CImageListT<false> CImageList;
+typedef CImageListT<true> CImageListManaged;
+
+
+template <bool t_bManaged>
+class CImageListT
+{
+public:
+// Data members
+ HIMAGELIST m_hImageList;
+
+// Constructor/destructor/operators
+ CImageListT(HIMAGELIST hImageList = NULL) : m_hImageList(hImageList)
+ { }
+
+ ~CImageListT()
+ {
+ if(t_bManaged && (m_hImageList != NULL))
+ Destroy();
+ }
+
+ CImageListT<t_bManaged>& operator =(HIMAGELIST hImageList)
+ {
+ Attach(hImageList);
+ return *this;
+ }
+
+ void Attach(HIMAGELIST hImageList)
+ {
+ if(t_bManaged && (m_hImageList != NULL) && (m_hImageList != hImageList))
+ ImageList_Destroy(m_hImageList);
+ m_hImageList = hImageList;
+ }
+
+ HIMAGELIST Detach()
+ {
+ HIMAGELIST hImageList = m_hImageList;
+ m_hImageList = NULL;
+ return hImageList;
+ }
+
+ operator HIMAGELIST() const { return m_hImageList; }
+
+ bool IsNull() const { return (m_hImageList == NULL); }
+
+// Attributes
+ int GetImageCount() const
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_GetImageCount(m_hImageList);
+ }
+
+ COLORREF GetBkColor() const
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_GetBkColor(m_hImageList);
+ }
+
+ COLORREF SetBkColor(COLORREF cr)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_SetBkColor(m_hImageList, cr);
+ }
+
+ BOOL GetImageInfo(int nImage, IMAGEINFO* pImageInfo) const
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_GetImageInfo(m_hImageList, nImage, pImageInfo);
+ }
+
+ HICON GetIcon(int nIndex, UINT uFlags = ILD_NORMAL) const
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_GetIcon(m_hImageList, nIndex, uFlags);
+ }
+
+ BOOL GetIconSize(int& cx, int& cy) const
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_GetIconSize(m_hImageList, &cx, &cy);
+ }
+
+ BOOL GetIconSize(SIZE& size) const
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_GetIconSize(m_hImageList, (int*)&size.cx, (int*)&size.cy);
+ }
+
+ BOOL SetIconSize(int cx, int cy)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_SetIconSize(m_hImageList, cx, cy);
+ }
+
+ BOOL SetIconSize(SIZE size)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_SetIconSize(m_hImageList, size.cx, size.cy);
+ }
+
+ BOOL SetImageCount(UINT uNewCount)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_SetImageCount(m_hImageList, uNewCount);
+ }
+
+ BOOL SetOverlayImage(int nImage, int nOverlay)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_SetOverlayImage(m_hImageList, nImage, nOverlay);
+ }
+
+// Operations
+ BOOL Create(int cx, int cy, UINT nFlags, int nInitial, int nGrow)
+ {
+ ATLASSERT(m_hImageList == NULL);
+ m_hImageList = ImageList_Create(cx, cy, nFlags, nInitial, nGrow);
+ return (m_hImageList != NULL) ? TRUE : FALSE;
+ }
+
+ BOOL Create(ATL::_U_STRINGorID bitmap, int cx, int nGrow, COLORREF crMask)
+ {
+ ATLASSERT(m_hImageList == NULL);
+ m_hImageList = ImageList_LoadBitmap(ModuleHelper::GetResourceInstance(), bitmap.m_lpstr, cx, nGrow, crMask);
+ return (m_hImageList != NULL) ? TRUE : FALSE;
+ }
+
+ BOOL CreateFromImage(ATL::_U_STRINGorID image, int cx, int nGrow, COLORREF crMask, UINT uType, UINT uFlags = LR_DEFAULTCOLOR | LR_DEFAULTSIZE)
+ {
+ ATLASSERT(m_hImageList == NULL);
+ m_hImageList = ImageList_LoadImage(ModuleHelper::GetResourceInstance(), image.m_lpstr, cx, nGrow, crMask, uType, uFlags);
+ return (m_hImageList != NULL) ? TRUE : FALSE;
+ }
+
+ BOOL Merge(HIMAGELIST hImageList1, int nImage1, HIMAGELIST hImageList2, int nImage2, int dx, int dy)
+ {
+ ATLASSERT(m_hImageList == NULL);
+ m_hImageList = ImageList_Merge(hImageList1, nImage1, hImageList2, nImage2, dx, dy);
+ return (m_hImageList != NULL) ? TRUE : FALSE;
+ }
+
+#ifdef __IStream_INTERFACE_DEFINED__
+ BOOL CreateFromStream(LPSTREAM lpStream)
+ {
+ ATLASSERT(m_hImageList == NULL);
+ m_hImageList = ImageList_Read(lpStream);
+ return (m_hImageList != NULL) ? TRUE : FALSE;
+ }
+#endif // __IStream_INTERFACE_DEFINED__
+
+ BOOL Destroy()
+ {
+ if (m_hImageList == NULL)
+ return FALSE;
+ BOOL bRet = ImageList_Destroy(m_hImageList);
+ if(bRet)
+ m_hImageList = NULL;
+ return bRet;
+ }
+
+ int Add(HBITMAP hBitmap, HBITMAP hBitmapMask = NULL)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_Add(m_hImageList, hBitmap, hBitmapMask);
+ }
+
+ int Add(HBITMAP hBitmap, COLORREF crMask)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_AddMasked(m_hImageList, hBitmap, crMask);
+ }
+
+ BOOL Remove(int nImage)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_Remove(m_hImageList, nImage);
+ }
+
+ BOOL RemoveAll()
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_RemoveAll(m_hImageList);
+ }
+
+ BOOL Replace(int nImage, HBITMAP hBitmap, HBITMAP hBitmapMask)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_Replace(m_hImageList, nImage, hBitmap, hBitmapMask);
+ }
+
+ int AddIcon(HICON hIcon)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_AddIcon(m_hImageList, hIcon);
+ }
+
+ int ReplaceIcon(int nImage, HICON hIcon)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_ReplaceIcon(m_hImageList, nImage, hIcon);
+ }
+
+ HICON ExtractIcon(int nImage)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_ExtractIcon(NULL, m_hImageList, nImage);
+ }
+
+ BOOL Draw(HDC hDC, int nImage, int x, int y, UINT nStyle)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ ATLASSERT(hDC != NULL);
+ return ImageList_Draw(m_hImageList, nImage, hDC, x, y, nStyle);
+ }
+
+ BOOL Draw(HDC hDC, int nImage, POINT pt, UINT nStyle)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ ATLASSERT(hDC != NULL);
+ return ImageList_Draw(m_hImageList, nImage, hDC, pt.x, pt.y, nStyle);
+ }
+
+ BOOL DrawEx(int nImage, HDC hDC, int x, int y, int dx, int dy, COLORREF rgbBk, COLORREF rgbFg, UINT fStyle)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ ATLASSERT(hDC != NULL);
+ return ImageList_DrawEx(m_hImageList, nImage, hDC, x, y, dx, dy, rgbBk, rgbFg, fStyle);
+ }
+
+ BOOL DrawEx(int nImage, HDC hDC, RECT& rect, COLORREF rgbBk, COLORREF rgbFg, UINT fStyle)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ ATLASSERT(hDC != NULL);
+ return ImageList_DrawEx(m_hImageList, nImage, hDC, rect.left, rect.top, rect.right - rect.left, rect.bottom - rect.top, rgbBk, rgbFg, fStyle);
+ }
+
+ static BOOL DrawIndirect(IMAGELISTDRAWPARAMS* pimldp)
+ {
+ return ImageList_DrawIndirect(pimldp);
+ }
+
+ BOOL Copy(int nSrc, int nDst, UINT uFlags = ILCF_MOVE)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_Copy(m_hImageList, nDst, m_hImageList, nSrc, uFlags);
+ }
+
+#ifdef __IStream_INTERFACE_DEFINED__
+ static HIMAGELIST Read(LPSTREAM lpStream)
+ {
+ return ImageList_Read(lpStream);
+ }
+
+ BOOL Write(LPSTREAM lpStream)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_Write(m_hImageList, lpStream);
+ }
+
+ static HRESULT ReadEx(DWORD dwFlags, LPSTREAM lpStream, REFIID riid, PVOID* ppv)
+ {
+ return ImageList_ReadEx(dwFlags, lpStream, riid, ppv);
+ }
+
+ HRESULT WriteEx(DWORD dwFlags, LPSTREAM lpStream)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_WriteEx(m_hImageList, dwFlags, lpStream);
+ }
+#endif // __IStream_INTERFACE_DEFINED__
+
+ // Drag operations
+ BOOL BeginDrag(int nImage, POINT ptHotSpot)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_BeginDrag(m_hImageList, nImage, ptHotSpot.x, ptHotSpot.y);
+ }
+
+ BOOL BeginDrag(int nImage, int xHotSpot, int yHotSpot)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_BeginDrag(m_hImageList, nImage, xHotSpot, yHotSpot);
+ }
+
+ static void EndDrag()
+ {
+ ImageList_EndDrag();
+ }
+
+ static BOOL DragMove(POINT pt)
+ {
+ return ImageList_DragMove(pt.x, pt.y);
+ }
+
+ static BOOL DragMove(int x, int y)
+ {
+ return ImageList_DragMove(x, y);
+ }
+
+ BOOL SetDragCursorImage(int nDrag, POINT ptHotSpot)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_SetDragCursorImage(m_hImageList, nDrag, ptHotSpot.x, ptHotSpot.y);
+ }
+
+ BOOL SetDragCursorImage(int nDrag, int xHotSpot, int yHotSpot)
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return ImageList_SetDragCursorImage(m_hImageList, nDrag, xHotSpot, yHotSpot);
+ }
+
+ static BOOL DragShowNolock(BOOL bShow = TRUE)
+ {
+ return ImageList_DragShowNolock(bShow);
+ }
+
+ static CImageList GetDragImage(LPPOINT lpPoint, LPPOINT lpPointHotSpot)
+ {
+ return CImageList(ImageList_GetDragImage(lpPoint, lpPointHotSpot));
+ }
+
+ static BOOL DragEnter(HWND hWnd, POINT point)
+ {
+ return ImageList_DragEnter(hWnd, point.x, point.y);
+ }
+
+ static BOOL DragEnter(HWND hWnd, int x, int y)
+ {
+ return ImageList_DragEnter(hWnd, x, y);
+ }
+
+ static BOOL DragLeave(HWND hWnd)
+ {
+ return ImageList_DragLeave(hWnd);
+ }
+
+ CImageList Duplicate() const
+ {
+ ATLASSERT(m_hImageList != NULL);
+ return CImageList(ImageList_Duplicate(m_hImageList));
+ }
+
+ static CImageList Duplicate(HIMAGELIST hImageList)
+ {
+ ATLASSERT(hImageList != NULL);
+ return CImageList(ImageList_Duplicate(hImageList));
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CToolTipCtrl
+
+class CToolInfo : public TOOLINFO
+{
+public:
+ CToolInfo(UINT nFlags, HWND hWnd, UINT_PTR nIDTool = 0, LPRECT lpRect = NULL, LPTSTR lpstrText = LPSTR_TEXTCALLBACK, LPARAM lUserParam = NULL)
+ {
+ Init(nFlags, hWnd, nIDTool, lpRect, lpstrText, lUserParam);
+ }
+
+ operator LPTOOLINFO() { return this; }
+
+ operator LPARAM() { return (LPARAM)this; }
+
+ void Init(UINT nFlags, HWND hWnd, UINT_PTR nIDTool = 0, LPRECT lpRect = NULL, LPTSTR lpstrText = LPSTR_TEXTCALLBACK, LPARAM lUserParam = NULL)
+ {
+ ATLASSERT(::IsWindow(hWnd));
+ memset(this, 0, sizeof(TOOLINFO));
+ cbSize = RunTimeHelper::SizeOf_TOOLINFO();
+ uFlags = nFlags;
+ if(nIDTool == 0)
+ {
+ hwnd = ::GetParent(hWnd);
+ uFlags |= TTF_IDISHWND;
+ uId = (UINT_PTR)hWnd;
+ }
+ else
+ {
+ hwnd = hWnd;
+ uId = nIDTool;
+ }
+ if(lpRect != NULL)
+ rect = *lpRect;
+ hinst = ModuleHelper::GetResourceInstance();
+ lpszText = lpstrText;
+ lParam = lUserParam;
+ }
+};
+
+template <class TBase>
+class CToolTipCtrlT : public TBase
+{
+public:
+// Constructors
+ CToolTipCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CToolTipCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return TOOLTIPS_CLASS;
+ }
+
+ void GetText(LPTOOLINFO lpToolInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_GETTEXT, 0, (LPARAM)&lpToolInfo);
+ }
+
+ void GetText(LPTSTR lpstrText, HWND hWnd, UINT_PTR nIDTool = 0) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(hWnd != NULL);
+ CToolInfo ti(0, hWnd, nIDTool, NULL, lpstrText);
+ ::SendMessage(this->m_hWnd, TTM_GETTEXT, 0, ti);
+ }
+
+ BOOL GetToolInfo(LPTOOLINFO lpToolInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TTM_GETTOOLINFO, 0, (LPARAM)lpToolInfo);
+ }
+
+ BOOL GetToolInfo(HWND hWnd, UINT_PTR nIDTool, UINT* puFlags, LPRECT lpRect, LPTSTR lpstrText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(hWnd != NULL);
+ ATLASSERT(puFlags != NULL);
+ ATLASSERT(lpRect != NULL);
+ CToolInfo ti(0, hWnd, nIDTool, NULL, lpstrText);
+ BOOL bRet = (BOOL)::SendMessage(this->m_hWnd, TTM_GETTOOLINFO, 0, ti);
+ if(bRet != FALSE)
+ {
+ *puFlags = ti.uFlags;
+ *lpRect = ti.rect;
+ }
+ return bRet;
+ }
+
+ void SetToolInfo(LPTOOLINFO lpToolInfo)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_SETTOOLINFO, 0, (LPARAM)lpToolInfo);
+ }
+
+ void SetToolRect(LPTOOLINFO lpToolInfo)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_NEWTOOLRECT, 0, (LPARAM)lpToolInfo);
+ }
+
+ void SetToolRect(HWND hWnd, UINT_PTR nIDTool, LPCRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(hWnd != NULL);
+ ATLASSERT(nIDTool != 0);
+
+ CToolInfo ti(0, hWnd, nIDTool, (LPRECT)lpRect, NULL);
+ ::SendMessage(this->m_hWnd, TTM_NEWTOOLRECT, 0, ti);
+ }
+
+ int GetToolCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TTM_GETTOOLCOUNT, 0, 0L);
+ }
+
+ int GetDelayTime(DWORD dwType) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TTM_GETDELAYTIME, dwType, 0L);
+ }
+
+ void SetDelayTime(DWORD dwType, int nTime)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_SETDELAYTIME, dwType, MAKELPARAM(nTime, 0));
+ }
+
+ void GetMargin(LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_GETMARGIN, 0, (LPARAM)lpRect);
+ }
+
+ void SetMargin(LPRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_SETMARGIN, 0, (LPARAM)lpRect);
+ }
+
+ int GetMaxTipWidth() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TTM_GETMAXTIPWIDTH, 0, 0L);
+ }
+
+ int SetMaxTipWidth(int nWidth)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TTM_SETMAXTIPWIDTH, 0, nWidth);
+ }
+
+ COLORREF GetTipBkColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TTM_GETTIPBKCOLOR, 0, 0L);
+ }
+
+ void SetTipBkColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_SETTIPBKCOLOR, (WPARAM)clr, 0L);
+ }
+
+ COLORREF GetTipTextColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TTM_GETTIPTEXTCOLOR, 0, 0L);
+ }
+
+ void SetTipTextColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_SETTIPTEXTCOLOR, (WPARAM)clr, 0L);
+ }
+
+ BOOL GetCurrentTool(LPTOOLINFO lpToolInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TTM_GETCURRENTTOOL, 0, (LPARAM)lpToolInfo);
+ }
+
+ SIZE GetBubbleSize(LPTOOLINFO lpToolInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, TTM_GETBUBBLESIZE, 0, (LPARAM)lpToolInfo);
+ SIZE size = { GET_X_LPARAM(dwRet), GET_Y_LPARAM(dwRet) };
+ return size;
+ }
+
+ BOOL SetTitle(UINT_PTR uIcon, LPCTSTR lpstrTitle)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TTM_SETTITLE, uIcon, (LPARAM)lpstrTitle);
+ }
+
+
+ BOOL SetTitle(HICON hIcon, LPCTSTR lpstrTitle)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TTM_SETTITLE, (WPARAM)hIcon, (LPARAM)lpstrTitle);
+ }
+
+ void GetTitle(PTTGETTITLE pTTGetTitle) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_GETTITLE, 0, (LPARAM)pTTGetTitle);
+ }
+
+ void SetWindowTheme(LPCWSTR lpstrTheme)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_SETWINDOWTHEME, 0, (LPARAM)lpstrTheme);
+ }
+
+// Operations
+ void Activate(BOOL bActivate)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_ACTIVATE, bActivate, 0L);
+ }
+
+ BOOL AddTool(LPTOOLINFO lpToolInfo)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TTM_ADDTOOL, 0, (LPARAM)lpToolInfo);
+ }
+
+ BOOL AddTool(HWND hWnd, ATL::_U_STRINGorID text = LPSTR_TEXTCALLBACK, LPCRECT lpRectTool = NULL, UINT_PTR nIDTool = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(hWnd != NULL);
+ // the toolrect and toolid must both be zero or both valid
+ ATLASSERT(((lpRectTool != NULL) && (nIDTool != 0)) || ((lpRectTool == NULL) && (nIDTool == 0)));
+
+ CToolInfo ti(0, hWnd, nIDTool, (LPRECT)lpRectTool, (LPTSTR)text.m_lpstr);
+ return (BOOL)::SendMessage(this->m_hWnd, TTM_ADDTOOL, 0, ti);
+ }
+
+ void DelTool(LPTOOLINFO lpToolInfo)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_DELTOOL, 0, (LPARAM)lpToolInfo);
+ }
+
+ void DelTool(HWND hWnd, UINT_PTR nIDTool = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(hWnd != NULL);
+
+ CToolInfo ti(0, hWnd, nIDTool, NULL, NULL);
+ ::SendMessage(this->m_hWnd, TTM_DELTOOL, 0, ti);
+ }
+
+ BOOL HitTest(LPTTHITTESTINFO lpHitTestInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TTM_HITTEST, 0, (LPARAM)lpHitTestInfo);
+ }
+
+ BOOL HitTest(HWND hWnd, POINT pt, LPTOOLINFO lpToolInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(hWnd != NULL);
+ ATLASSERT(lpToolInfo != NULL);
+
+ TTHITTESTINFO hti = {};
+ hti.ti.cbSize = RunTimeHelper::SizeOf_TOOLINFO();
+ hti.hwnd = hWnd;
+ hti.pt.x = pt.x;
+ hti.pt.y = pt.y;
+ if((BOOL)::SendMessage(this->m_hWnd, TTM_HITTEST, 0, (LPARAM)&hti) != FALSE)
+ {
+ *lpToolInfo = hti.ti;
+ return TRUE;
+ }
+ return FALSE;
+ }
+
+ void RelayEvent(LPMSG lpMsg)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_RELAYEVENT, 0, (LPARAM)lpMsg);
+ }
+
+ void UpdateTipText(LPTOOLINFO lpToolInfo)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_UPDATETIPTEXT, 0, (LPARAM)lpToolInfo);
+ }
+
+ void UpdateTipText(ATL::_U_STRINGorID text, HWND hWnd, UINT_PTR nIDTool = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(hWnd != NULL);
+
+ CToolInfo ti(0, hWnd, nIDTool, NULL, (LPTSTR)text.m_lpstr);
+ ::SendMessage(this->m_hWnd, TTM_UPDATETIPTEXT, 0, ti);
+ }
+
+ BOOL EnumTools(UINT_PTR nTool, LPTOOLINFO lpToolInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TTM_ENUMTOOLS, nTool, (LPARAM)lpToolInfo);
+ }
+
+ void Pop()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_POP, 0, 0L);
+ }
+
+ void TrackActivate(LPTOOLINFO lpToolInfo, BOOL bActivate)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_TRACKACTIVATE, bActivate, (LPARAM)lpToolInfo);
+ }
+
+ void TrackActivate(HWND hWnd, UINT_PTR nIDTool, BOOL bActivate)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(hWnd != NULL);
+
+ CToolInfo ti(0, hWnd, nIDTool);
+ ::SendMessage(this->m_hWnd, TTM_TRACKACTIVATE, bActivate, ti);
+ }
+
+ void TrackPosition(int xPos, int yPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_TRACKPOSITION, 0, MAKELPARAM(xPos, yPos));
+ }
+
+ void Update()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_UPDATE, 0, 0L);
+ }
+
+ BOOL AdjustRect(LPRECT lpRect, BOOL bLarger /*= TRUE*/)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TTM_ADJUSTRECT, bLarger, (LPARAM)lpRect);
+ }
+
+ void Popup()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TTM_POPUP, 0, 0L);
+ }
+};
+
+typedef CToolTipCtrlT<ATL::CWindow> CToolTipCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CHeaderCtrl
+
+template <class TBase>
+class CHeaderCtrlT : public TBase
+{
+public:
+// Constructors
+ CHeaderCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CHeaderCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return WC_HEADER;
+ }
+
+ int GetItemCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_GETITEMCOUNT, 0, 0L);
+ }
+
+ BOOL GetItem(int nIndex, LPHDITEM pHeaderItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_GETITEM, nIndex, (LPARAM)pHeaderItem);
+ }
+
+ BOOL SetItem(int nIndex, LPHDITEM pHeaderItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_SETITEM, nIndex, (LPARAM)pHeaderItem);
+ }
+
+ CImageList GetImageList() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, HDM_GETIMAGELIST, 0, 0L));
+ }
+
+ CImageList SetImageList(HIMAGELIST hImageList)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, HDM_SETIMAGELIST, 0, (LPARAM)hImageList));
+ }
+
+ BOOL GetOrderArray(int nSize, int* lpnArray) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_GETORDERARRAY, nSize, (LPARAM)lpnArray);
+ }
+
+ BOOL SetOrderArray(int nSize, int* lpnArray)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_SETORDERARRAY, nSize, (LPARAM)lpnArray);
+ }
+
+ BOOL GetItemRect(int nIndex, LPRECT lpItemRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_GETITEMRECT, nIndex, (LPARAM)lpItemRect);
+ }
+
+ int SetHotDivider(BOOL bPos, DWORD dwInputValue)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_SETHOTDIVIDER, bPos, dwInputValue);
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+ int GetBitmapMargin() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_GETBITMAPMARGIN, 0, 0L);
+ }
+
+ int SetBitmapMargin(int nWidth)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_SETBITMAPMARGIN, nWidth, 0L);
+ }
+
+ int SetFilterChangeTimeout(DWORD dwTimeOut)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_SETFILTERCHANGETIMEOUT, 0, dwTimeOut);
+ }
+
+#if (_WIN32_WINNT >= 0x0600)
+ BOOL GetItemDropDownRect(int nIndex, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_GETITEMDROPDOWNRECT, nIndex, (LPARAM)lpRect);
+ }
+
+ BOOL GetOverflowRect(LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_GETOVERFLOWRECT, 0, (LPARAM)lpRect);
+ }
+
+ int GetFocusedItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_GETFOCUSEDITEM, 0, 0L);
+ }
+
+ BOOL SetFocusedItem(int nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_SETFOCUSEDITEM, 0, nIndex);
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+
+// Operations
+ int InsertItem(int nIndex, LPHDITEM phdi)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_INSERTITEM, nIndex, (LPARAM)phdi);
+ }
+
+ int AddItem(LPHDITEM phdi)
+ {
+ return InsertItem(GetItemCount(), phdi);
+ }
+
+ BOOL DeleteItem(int nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_DELETEITEM, nIndex, 0L);
+ }
+
+ BOOL Layout(HD_LAYOUT* pHeaderLayout)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, HDM_LAYOUT, 0, (LPARAM)pHeaderLayout);
+ }
+
+ int HitTest(LPHDHITTESTINFO lpHitTestInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_HITTEST, 0, (LPARAM)lpHitTestInfo);
+ }
+
+ int OrderToIndex(int nOrder)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_ORDERTOINDEX, nOrder, 0L);
+ }
+
+ CImageList CreateDragImage(int nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, HDM_CREATEDRAGIMAGE, nIndex, 0L));
+ }
+
+ int EditFilter(int nColumn, BOOL bDiscardChanges)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_EDITFILTER, nColumn, MAKELPARAM(bDiscardChanges, 0));
+ }
+
+ int ClearFilter(int nColumn)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_CLEARFILTER, nColumn, 0L);
+ }
+
+ int ClearAllFilters()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, HDM_CLEARFILTER, (WPARAM)-1, 0L);
+ }
+};
+
+typedef CHeaderCtrlT<ATL::CWindow> CHeaderCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CListViewCtrl
+
+template <class TBase>
+class CListViewCtrlT : public TBase
+{
+public:
+// Constructors
+ CListViewCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CListViewCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return WC_LISTVIEW;
+ }
+
+ COLORREF GetBkColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, LVM_GETBKCOLOR, 0, 0L);
+ }
+
+ BOOL SetBkColor(COLORREF cr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETBKCOLOR, 0, cr);
+ }
+
+ CImageList GetImageList(int nImageListType) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, LVM_GETIMAGELIST, nImageListType, 0L));
+ }
+
+ CImageList SetImageList(HIMAGELIST hImageList, int nImageList)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, LVM_SETIMAGELIST, nImageList, (LPARAM)hImageList));
+ }
+
+ int GetItemCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETITEMCOUNT, 0, 0L);
+ }
+
+ BOOL SetItemCount(int nItems)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMCOUNT, nItems, 0L);
+ }
+
+ BOOL GetItem(LPLVITEM pItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEM, 0, (LPARAM)pItem);
+ }
+
+ BOOL SetItem(const LVITEM* pItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEM, 0, (LPARAM)pItem);
+ }
+
+ BOOL SetItem(int nItem, int nSubItem, UINT nMask, LPCTSTR lpszItem,
+ int nImage, UINT nState, UINT nStateMask, LPARAM lParam)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVITEM lvi = {};
+ lvi.mask = nMask;
+ lvi.iItem = nItem;
+ lvi.iSubItem = nSubItem;
+ lvi.stateMask = nStateMask;
+ lvi.state = nState;
+ lvi.pszText = (LPTSTR) lpszItem;
+ lvi.iImage = nImage;
+ lvi.lParam = lParam;
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEM, 0, (LPARAM)&lvi);
+ }
+
+ UINT GetItemState(int nItem, UINT nMask) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, LVM_GETITEMSTATE, nItem, nMask);
+ }
+
+ BOOL SetItemState(int nItem, UINT nState, UINT nStateMask)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVITEM lvi = {};
+ lvi.state = nState;
+ lvi.stateMask = nStateMask;
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMSTATE, nItem, (LPARAM)&lvi);
+ }
+
+ BOOL SetItemState(int nItem, LPLVITEM pItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMSTATE, nItem, (LPARAM)pItem);
+ }
+
+ BOOL GetItemText(int nItem, int nSubItem, BSTR& bstrText) const
+ {
+ USES_CONVERSION;
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(bstrText == NULL);
+ LVITEM lvi = {};
+ lvi.iSubItem = nSubItem;
+
+ LPTSTR lpstrText = NULL;
+ int nRes = 0;
+ for(int nLen = 256; ; nLen *= 2)
+ {
+ ATLTRY(lpstrText = new TCHAR[nLen]);
+ if(lpstrText == NULL)
+ break;
+ lpstrText[0] = NULL;
+ lvi.cchTextMax = nLen;
+ lvi.pszText = lpstrText;
+ nRes = (int)::SendMessage(this->m_hWnd, LVM_GETITEMTEXT, (WPARAM)nItem, (LPARAM)&lvi);
+ if(nRes < nLen - 1)
+ break;
+ delete [] lpstrText;
+ lpstrText = NULL;
+ }
+
+ if(lpstrText != NULL)
+ {
+ if(nRes != 0)
+ bstrText = ::SysAllocString(T2OLE(lpstrText));
+ delete [] lpstrText;
+ }
+
+ return (bstrText != NULL) ? TRUE : FALSE;
+ }
+
+#ifdef __ATLSTR_H__
+ int GetItemText(int nItem, int nSubItem, ATL::CString& strText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVITEM lvi = {};
+ lvi.iSubItem = nSubItem;
+
+ strText.Empty();
+ int nRes = 0;
+ for(int nLen = 256; ; nLen *= 2)
+ {
+ lvi.cchTextMax = nLen;
+ lvi.pszText = strText.GetBufferSetLength(nLen);
+ if(lvi.pszText == NULL)
+ {
+ nRes = 0;
+ break;
+ }
+ nRes = (int)::SendMessage(this->m_hWnd, LVM_GETITEMTEXT, (WPARAM)nItem, (LPARAM)&lvi);
+ if(nRes < nLen - 1)
+ break;
+ }
+ strText.ReleaseBuffer();
+ return nRes;
+ }
+#endif // __ATLSTR_H__
+
+ int GetItemText(int nItem, int nSubItem, LPTSTR lpszText, int nLen) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVITEM lvi = {};
+ lvi.iSubItem = nSubItem;
+ lvi.cchTextMax = nLen;
+ lvi.pszText = lpszText;
+ return (int)::SendMessage(this->m_hWnd, LVM_GETITEMTEXT, (WPARAM)nItem, (LPARAM)&lvi);
+ }
+
+ BOOL SetItemText(int nItem, int nSubItem, LPCTSTR lpszText)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return SetItem(nItem, nSubItem, LVIF_TEXT, lpszText, 0, 0, 0, 0);
+ }
+
+ DWORD_PTR GetItemData(int nItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVITEM lvi = {};
+ lvi.iItem = nItem;
+ lvi.mask = LVIF_PARAM;
+ BOOL bRet = (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEM, 0, (LPARAM)&lvi);
+ return (DWORD_PTR)(bRet ? lvi.lParam : NULL);
+ }
+
+ BOOL SetItemData(int nItem, DWORD_PTR dwData)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return SetItem(nItem, 0, LVIF_PARAM, NULL, 0, 0, 0, (LPARAM)dwData);
+ }
+
+ UINT GetCallbackMask() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, LVM_GETCALLBACKMASK, 0, 0L);
+ }
+
+ BOOL SetCallbackMask(UINT nMask)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETCALLBACKMASK, nMask, 0L);
+ }
+
+ BOOL GetItemPosition(int nItem, LPPOINT lpPoint) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEMPOSITION, nItem, (LPARAM)lpPoint);
+ }
+
+ BOOL SetItemPosition(int nItem, POINT pt)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(((this->GetStyle() & LVS_TYPEMASK) == LVS_ICON) || ((this->GetStyle() & LVS_TYPEMASK) == LVS_SMALLICON));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMPOSITION32, nItem, (LPARAM)&pt);
+ }
+
+ BOOL SetItemPosition(int nItem, int x, int y)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(((this->GetStyle() & LVS_TYPEMASK) == LVS_ICON) || ((this->GetStyle() & LVS_TYPEMASK) == LVS_SMALLICON));
+ POINT pt = { x, y };
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMPOSITION32, nItem, (LPARAM)&pt);
+ }
+
+ int GetStringWidth(LPCTSTR lpsz) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETSTRINGWIDTH, 0, (LPARAM)lpsz);
+ }
+
+ CEdit GetEditControl() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CEdit((HWND)::SendMessage(this->m_hWnd, LVM_GETEDITCONTROL, 0, 0L));
+ }
+
+ BOOL GetColumn(int nCol, LVCOLUMN* pColumn) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETCOLUMN, nCol, (LPARAM)pColumn);
+ }
+
+ BOOL SetColumn(int nCol, const LVCOLUMN* pColumn)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETCOLUMN, nCol, (LPARAM)pColumn);
+ }
+
+ int GetColumnWidth(int nCol) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETCOLUMNWIDTH, nCol, 0L);
+ }
+
+ BOOL SetColumnWidth(int nCol, int cx)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETCOLUMNWIDTH, nCol, MAKELPARAM(cx, 0));
+ }
+
+ BOOL GetViewRect(LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETVIEWRECT, 0, (LPARAM)lpRect);
+ }
+
+ COLORREF GetTextColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, LVM_GETTEXTCOLOR, 0, 0L);
+ }
+
+ BOOL SetTextColor(COLORREF cr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETTEXTCOLOR, 0, cr);
+ }
+
+ COLORREF GetTextBkColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, LVM_GETTEXTBKCOLOR, 0, 0L);
+ }
+
+ BOOL SetTextBkColor(COLORREF cr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETTEXTBKCOLOR, 0, cr);
+ }
+
+ int GetTopIndex() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETTOPINDEX, 0, 0L);
+ }
+
+ int GetCountPerPage() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETCOUNTPERPAGE, 0, 0L);
+ }
+
+ BOOL GetOrigin(LPPOINT lpPoint) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETORIGIN, 0, (LPARAM)lpPoint);
+ }
+
+ UINT GetSelectedCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, LVM_GETSELECTEDCOUNT, 0, 0L);
+ }
+
+ BOOL GetItemRect(int nItem, LPRECT lpRect, UINT nCode) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ lpRect->left = nCode;
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEMRECT, (WPARAM)nItem, (LPARAM)lpRect);
+ }
+
+ HCURSOR GetHotCursor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HCURSOR)::SendMessage(this->m_hWnd, LVM_GETHOTCURSOR, 0, 0L);
+ }
+
+ HCURSOR SetHotCursor(HCURSOR hHotCursor)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HCURSOR)::SendMessage(this->m_hWnd, LVM_SETHOTCURSOR, 0, (LPARAM)hHotCursor);
+ }
+
+ int GetHotItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETHOTITEM, 0, 0L);
+ }
+
+ int SetHotItem(int nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_SETHOTITEM, nIndex, 0L);
+ }
+
+ BOOL GetColumnOrderArray(int nCount, int* lpnArray) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETCOLUMNORDERARRAY, nCount, (LPARAM)lpnArray);
+ }
+
+ BOOL SetColumnOrderArray(int nCount, int* lpnArray)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETCOLUMNORDERARRAY, nCount, (LPARAM)lpnArray);
+ }
+
+ CHeaderCtrl GetHeader() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CHeaderCtrl((HWND)::SendMessage(this->m_hWnd, LVM_GETHEADER, 0, 0L));
+ }
+
+ BOOL GetSubItemRect(int nItem, int nSubItem, int nFlag, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & LVS_TYPEMASK) == LVS_REPORT);
+ ATLASSERT(lpRect != NULL);
+ lpRect->top = nSubItem;
+ lpRect->left = nFlag;
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETSUBITEMRECT, nItem, (LPARAM)lpRect);
+ }
+
+ DWORD SetIconSpacing(int cx, int cy)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & LVS_TYPEMASK) == LVS_ICON);
+ return (DWORD)::SendMessage(this->m_hWnd, LVM_SETICONSPACING, 0, MAKELPARAM(cx, cy));
+ }
+
+ int GetISearchString(LPTSTR lpstr) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETISEARCHSTRING, 0, (LPARAM)lpstr);
+ }
+
+ void GetItemSpacing(SIZE& sizeSpacing, BOOL bSmallIconView = FALSE) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, LVM_GETITEMSPACING, bSmallIconView, 0L);
+ sizeSpacing.cx = GET_X_LPARAM(dwRet);
+ sizeSpacing.cy = GET_Y_LPARAM(dwRet);
+ }
+
+ // single-selection only
+ int GetSelectedIndex() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & LVS_SINGLESEL) != 0);
+ return (int)::SendMessage(this->m_hWnd, LVM_GETNEXTITEM, (WPARAM)-1, MAKELPARAM(LVNI_ALL | LVNI_SELECTED, 0));
+ }
+
+ BOOL GetSelectedItem(LPLVITEM pItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & LVS_SINGLESEL) != 0);
+ ATLASSERT(pItem != NULL);
+ pItem->iItem = (int)::SendMessage(this->m_hWnd, LVM_GETNEXTITEM, (WPARAM)-1, MAKELPARAM(LVNI_ALL | LVNI_SELECTED, 0));
+ if(pItem->iItem == -1)
+ return FALSE;
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEM, 0, (LPARAM)pItem);
+ }
+
+ // extended list view styles
+ DWORD GetExtendedListViewStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, LVM_GETEXTENDEDLISTVIEWSTYLE, 0, 0L);
+ }
+
+ // dwExMask = 0 means all styles
+ DWORD SetExtendedListViewStyle(DWORD dwExStyle, DWORD dwExMask = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, LVM_SETEXTENDEDLISTVIEWSTYLE, dwExMask, dwExStyle);
+ }
+
+ // checkboxes only
+ BOOL GetCheckState(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((GetExtendedListViewStyle() & LVS_EX_CHECKBOXES) != 0);
+ UINT uRet = GetItemState(nIndex, LVIS_STATEIMAGEMASK);
+ return (uRet >> 12) - 1;
+ }
+
+ BOOL SetCheckState(int nItem, BOOL bCheck)
+ {
+ int nCheck = bCheck ? 2 : 1; // one based index
+ return SetItemState(nItem, INDEXTOSTATEIMAGEMASK(nCheck), LVIS_STATEIMAGEMASK);
+ }
+
+ // view type
+ DWORD GetViewType() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (this->GetStyle() & LVS_TYPEMASK);
+ }
+
+ DWORD SetViewType(DWORD dwType)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((dwType == LVS_ICON) || (dwType == LVS_SMALLICON) || (dwType == LVS_LIST) || (dwType == LVS_REPORT));
+ DWORD dwOldType = GetViewType();
+ if(dwType != dwOldType)
+ this->ModifyStyle(LVS_TYPEMASK, (dwType & LVS_TYPEMASK));
+ return dwOldType;
+ }
+
+ BOOL GetBkImage(LPLVBKIMAGE plvbki) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETBKIMAGE, 0, (LPARAM)plvbki);
+ }
+
+ BOOL SetBkImage(LPLVBKIMAGE plvbki)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETBKIMAGE, 0, (LPARAM)plvbki);
+ }
+
+ int GetSelectionMark() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETSELECTIONMARK, 0, 0L);
+ }
+
+ int SetSelectionMark(int nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_SETSELECTIONMARK, 0, nIndex);
+ }
+
+ BOOL GetWorkAreas(int nWorkAreas, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETWORKAREAS, nWorkAreas, (LPARAM)lpRect);
+ }
+
+ BOOL SetWorkAreas(int nWorkAreas, LPRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETWORKAREAS, nWorkAreas, (LPARAM)lpRect);
+ }
+
+ DWORD GetHoverTime() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((GetExtendedListViewStyle() & (LVS_EX_TRACKSELECT | LVS_EX_ONECLICKACTIVATE | LVS_EX_TWOCLICKACTIVATE)) != 0);
+ return (DWORD)::SendMessage(this->m_hWnd, LVM_GETHOVERTIME, 0, 0L);
+ }
+
+ DWORD SetHoverTime(DWORD dwHoverTime)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((GetExtendedListViewStyle() & (LVS_EX_TRACKSELECT | LVS_EX_ONECLICKACTIVATE | LVS_EX_TWOCLICKACTIVATE)) != 0);
+ return (DWORD)::SendMessage(this->m_hWnd, LVM_SETHOVERTIME, 0, dwHoverTime);
+ }
+
+ BOOL GetNumberOfWorkAreas(int* pnWorkAreas) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETNUMBEROFWORKAREAS, 0, (LPARAM)pnWorkAreas);
+ }
+
+ BOOL SetItemCountEx(int nItems, DWORD dwFlags)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(((this->GetStyle() & LVS_OWNERDATA) != 0) && (((this->GetStyle() & LVS_TYPEMASK) == LVS_REPORT) || ((this->GetStyle() & LVS_TYPEMASK) == LVS_LIST)));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMCOUNT, nItems, dwFlags);
+ }
+
+ CToolTipCtrl GetToolTips() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, LVM_GETTOOLTIPS, 0, 0L));
+ }
+
+ CToolTipCtrl SetToolTips(HWND hWndTT)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, LVM_SETTOOLTIPS, (WPARAM)hWndTT, 0L));
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+ int GetSelectedColumn() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETSELECTEDCOLUMN, 0, 0L);
+ }
+
+ void SetSelectedColumn(int nColumn)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LVM_SETSELECTEDCOLUMN, nColumn, 0L);
+ }
+
+ DWORD GetView() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, LVM_GETVIEW, 0, 0L);
+ }
+
+ int SetView(DWORD dwView)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_SETVIEW, dwView, 0L);
+ }
+
+ BOOL IsGroupViewEnabled() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_ISGROUPVIEWENABLED, 0, 0L);
+ }
+
+ int GetGroupInfo(int nGroupID, PLVGROUP pGroup) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETGROUPINFO, nGroupID, (LPARAM)pGroup);
+ }
+
+ int SetGroupInfo(int nGroupID, PLVGROUP pGroup)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_SETGROUPINFO, nGroupID, (LPARAM)pGroup);
+ }
+
+ void GetGroupMetrics(PLVGROUPMETRICS pGroupMetrics) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LVM_GETGROUPMETRICS, 0, (LPARAM)pGroupMetrics);
+ }
+
+ void SetGroupMetrics(PLVGROUPMETRICS pGroupMetrics)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LVM_SETGROUPMETRICS, 0, (LPARAM)pGroupMetrics);
+ }
+
+ void GetTileViewInfo(PLVTILEVIEWINFO pTileViewInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LVM_GETTILEVIEWINFO, 0, (LPARAM)pTileViewInfo);
+ }
+
+ BOOL SetTileViewInfo(PLVTILEVIEWINFO pTileViewInfo)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETTILEVIEWINFO, 0, (LPARAM)pTileViewInfo);
+ }
+
+ void GetTileInfo(PLVTILEINFO pTileInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LVM_GETTILEINFO, 0, (LPARAM)pTileInfo);
+ }
+
+ BOOL SetTileInfo(PLVTILEINFO pTileInfo)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETTILEINFO, 0, (LPARAM)pTileInfo);
+ }
+
+ BOOL GetInsertMark(LPLVINSERTMARK pInsertMark) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETINSERTMARK, 0, (LPARAM)pInsertMark);
+ }
+
+ BOOL SetInsertMark(LPLVINSERTMARK pInsertMark)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETINSERTMARK, 0, (LPARAM)pInsertMark);
+ }
+
+ int GetInsertMarkRect(LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETINSERTMARKRECT, 0, (LPARAM)lpRect);
+ }
+
+ COLORREF GetInsertMarkColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, LVM_GETINSERTMARKCOLOR, 0, 0L);
+ }
+
+ COLORREF SetInsertMarkColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, LVM_SETINSERTMARKCOLOR, 0, clr);
+ }
+
+ COLORREF GetOutlineColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, LVM_GETOUTLINECOLOR, 0, 0L);
+ }
+
+ COLORREF SetOutlineColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, LVM_SETOUTLINECOLOR, 0, clr);
+ }
+
+#if (_WIN32_WINNT >= 0x0600)
+ int GetGroupCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETGROUPCOUNT, 0, 0L);
+ }
+
+ BOOL GetGroupInfoByIndex(int nIndex, PLVGROUP pGroup) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETGROUPINFOBYINDEX, nIndex, (LPARAM)pGroup);
+ }
+
+ BOOL GetGroupRect(int nGroupID, int nType, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(lpRect != NULL);
+ if(lpRect != NULL)
+ lpRect->top = nType;
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETGROUPRECT, nGroupID, (LPARAM)lpRect);
+ }
+
+ UINT GetGroupState(int nGroupID, UINT uMask) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, LVM_GETGROUPSTATE, nGroupID, (LPARAM)uMask);
+ }
+
+ int GetFocusedGroup() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETFOCUSEDGROUP, 0, 0L);
+ }
+
+ BOOL GetEmptyText(LPWSTR lpstrText, int cchText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETEMPTYTEXT, cchText, (LPARAM)lpstrText);
+ }
+
+ BOOL GetFooterRect(LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETFOOTERRECT, 0, (LPARAM)lpRect);
+ }
+
+ BOOL GetFooterInfo(LPLVFOOTERINFO lpFooterInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETFOOTERINFO, 0, (LPARAM)lpFooterInfo);
+ }
+
+ BOOL GetFooterItemRect(int nItem, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETFOOTERITEMRECT, nItem, (LPARAM)lpRect);
+ }
+
+ BOOL GetFooterItem(int nItem, LPLVFOOTERITEM lpFooterItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETFOOTERITEM, nItem, (LPARAM)lpFooterItem);
+ }
+
+ BOOL GetItemIndexRect(PLVITEMINDEX pItemIndex, int nSubItem, int nType, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(pItemIndex != NULL);
+ ATLASSERT(lpRect != NULL);
+ if(lpRect != NULL)
+ {
+ lpRect->top = nSubItem;
+ lpRect->left = nType;
+ }
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETITEMINDEXRECT, (WPARAM)pItemIndex, (LPARAM)lpRect);
+ }
+
+ BOOL SetItemIndexState(PLVITEMINDEX pItemIndex, UINT uState, UINT dwMask)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVITEM lvi = {};
+ lvi.state = uState;
+ lvi.stateMask = dwMask;
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETITEMINDEXSTATE, (WPARAM)pItemIndex, (LPARAM)&lvi);
+ }
+
+ BOOL GetNextItemIndex(PLVITEMINDEX pItemIndex, WORD wFlags) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_GETNEXTITEMINDEX, (WPARAM)pItemIndex, MAKELPARAM(wFlags, 0));
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+
+// Operations
+ int InsertColumn(int nCol, const LVCOLUMN* pColumn)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_INSERTCOLUMN, nCol, (LPARAM)pColumn);
+ }
+
+ int InsertColumn(int nCol, LPCTSTR lpszColumnHeading, int nFormat = LVCFMT_LEFT,
+ int nWidth = -1, int nSubItem = -1, int iImage = -1, int iOrder = -1)
+ {
+ LVCOLUMN column = {};
+ column.mask = LVCF_TEXT | LVCF_FMT;
+ column.pszText = (LPTSTR)lpszColumnHeading;
+ column.fmt = nFormat;
+ if (nWidth != -1)
+ {
+ column.mask |= LVCF_WIDTH;
+ column.cx = nWidth;
+ }
+ if (nSubItem != -1)
+ {
+ column.mask |= LVCF_SUBITEM;
+ column.iSubItem = nSubItem;
+ }
+ if (iImage != -1)
+ {
+ column.mask |= LVCF_IMAGE;
+ column.iImage = iImage;
+ }
+ if (iOrder != -1)
+ {
+ column.mask |= LVCF_ORDER;
+ column.iOrder = iOrder;
+ }
+ return InsertColumn(nCol, &column);
+ }
+
+ BOOL DeleteColumn(int nCol)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_DELETECOLUMN, nCol, 0L);
+ }
+
+ int InsertItem(UINT nMask, int nItem, LPCTSTR lpszItem, UINT nState, UINT nStateMask, int nImage, LPARAM lParam)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVITEM item = {};
+ item.mask = nMask;
+ item.iItem = nItem;
+ item.iSubItem = 0;
+ item.pszText = (LPTSTR)lpszItem;
+ item.state = nState;
+ item.stateMask = nStateMask;
+ item.iImage = nImage;
+ item.lParam = lParam;
+ return InsertItem(&item);
+ }
+
+ int InsertItem(const LVITEM* pItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_INSERTITEM, 0, (LPARAM)pItem);
+ }
+
+ int InsertItem(int nItem, LPCTSTR lpszItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return InsertItem(LVIF_TEXT, nItem, lpszItem, 0, 0, 0, 0);
+ }
+
+ int InsertItem(int nItem, LPCTSTR lpszItem, int nImage)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return InsertItem(LVIF_TEXT|LVIF_IMAGE, nItem, lpszItem, 0, 0, nImage, 0);
+ }
+
+ int GetNextItem(int nItem, int nFlags) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_GETNEXTITEM, nItem, MAKELPARAM(nFlags, 0));
+ }
+
+ BOOL DeleteItem(int nItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_DELETEITEM, nItem, 0L);
+ }
+
+ BOOL DeleteAllItems()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_DELETEALLITEMS, 0, 0L);
+ }
+
+ int FindItem(LVFINDINFO* pFindInfo, int nStart = -1) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_FINDITEM, nStart, (LPARAM)pFindInfo);
+ }
+
+ int FindItem(LPCTSTR lpstrFind, bool bPartial = true, bool bWrap = false, int nStart = -1) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVFINDINFO lvfi = {};
+ lvfi.flags = LVFI_STRING | (bWrap ? LVFI_WRAP : 0) | (bPartial ? LVFI_PARTIAL : 0);
+ lvfi.psz = lpstrFind;
+ return (int)::SendMessage(this->m_hWnd, LVM_FINDITEM, nStart, (LPARAM)&lvfi);
+ }
+
+ int HitTest(LVHITTESTINFO* pHitTestInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_HITTEST, 0, (LPARAM)pHitTestInfo);
+ }
+
+ int HitTest(POINT pt, UINT* pFlags) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVHITTESTINFO hti = {};
+ hti.pt = pt;
+ int nRes = (int)::SendMessage(this->m_hWnd, LVM_HITTEST, 0, (LPARAM)&hti);
+ if (pFlags != NULL)
+ *pFlags = hti.flags;
+ return nRes;
+ }
+
+ BOOL EnsureVisible(int nItem, BOOL bPartialOK)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_ENSUREVISIBLE, nItem, MAKELPARAM(bPartialOK, 0));
+ }
+
+ BOOL Scroll(int cx, int cy)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SCROLL, cx, cy);
+ }
+
+ BOOL Scroll(SIZE size)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SCROLL, size.cx, size.cy);
+ }
+
+ BOOL RedrawItems(int nFirst, int nLast)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_REDRAWITEMS, nFirst, nLast);
+ }
+
+ BOOL Arrange(UINT nCode)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_ARRANGE, nCode, 0L);
+ }
+
+ CEdit EditLabel(int nItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CEdit((HWND)::SendMessage(this->m_hWnd, LVM_EDITLABEL, nItem, 0L));
+ }
+
+ BOOL Update(int nItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_UPDATE, nItem, 0L);
+ }
+
+ BOOL SortItems(PFNLVCOMPARE pfnCompare, LPARAM lParamSort)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SORTITEMS, (WPARAM)lParamSort, (LPARAM)pfnCompare);
+ }
+
+ CImageList RemoveImageList(int nImageList)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, LVM_SETIMAGELIST, (WPARAM)nImageList, NULL));
+ }
+
+ CImageList CreateDragImage(int nItem, LPPOINT lpPoint)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, LVM_CREATEDRAGIMAGE, nItem, (LPARAM)lpPoint));
+ }
+
+ DWORD ApproximateViewRect(int cx = -1, int cy = -1, int nCount = -1)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, LVM_APPROXIMATEVIEWRECT, nCount, MAKELPARAM(cx, cy));
+ }
+
+ int SubItemHitTest(LPLVHITTESTINFO lpInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_SUBITEMHITTEST, 0, (LPARAM)lpInfo);
+ }
+
+ int AddColumn(LPCTSTR strColumn, int nItem, int nSubItem = -1,
+ int nMask = LVCF_FMT | LVCF_WIDTH | LVCF_TEXT | LVCF_SUBITEM,
+ int nFmt = LVCFMT_LEFT)
+ {
+ const int cxOffset = 15;
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVCOLUMN lvc = {};
+ lvc.mask = nMask;
+ lvc.fmt = nFmt;
+ lvc.pszText = (LPTSTR)strColumn;
+ lvc.cx = GetStringWidth(lvc.pszText) + cxOffset;
+ if(nMask & LVCF_SUBITEM)
+ lvc.iSubItem = (nSubItem != -1) ? nSubItem : nItem;
+ return InsertColumn(nItem, &lvc);
+ }
+
+ int AddItem(int nItem, int nSubItem, LPCTSTR strItem, int nImageIndex = -3)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVITEM lvItem = {};
+ lvItem.mask = LVIF_TEXT;
+ lvItem.iItem = nItem;
+ lvItem.iSubItem = nSubItem;
+ lvItem.pszText = (LPTSTR)strItem;
+ if(nImageIndex != -3)
+ {
+ lvItem.mask |= LVIF_IMAGE;
+ lvItem.iImage = nImageIndex;
+ }
+ if(nSubItem == 0)
+ return InsertItem(&lvItem);
+ return SetItem(&lvItem) ? nItem : -1;
+ }
+
+ BOOL SortItemsEx(PFNLVCOMPARE pfnCompare, LPARAM lParamSort)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SORTITEMSEX, (WPARAM)lParamSort, (LPARAM)pfnCompare);
+ }
+
+ int InsertGroup(int nItem, PLVGROUP pGroup)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_INSERTGROUP, nItem, (LPARAM)pGroup);
+ }
+
+ int AddGroup(PLVGROUP pGroup)
+ {
+ return InsertGroup(-1, pGroup);
+ }
+
+ int RemoveGroup(int nGroupID)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_REMOVEGROUP, nGroupID, 0L);
+ }
+
+ void MoveGroup(int nGroupID, int nItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LVM_MOVEGROUP, nGroupID, nItem);
+ }
+
+ void MoveItemToGroup(int nItem, int nGroupID)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LVM_MOVEITEMTOGROUP, nItem, nGroupID);
+ }
+
+ int EnableGroupView(BOOL bEnable)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_ENABLEGROUPVIEW, bEnable, 0L);
+ }
+
+ int SortGroups(PFNLVGROUPCOMPARE pCompareFunc, LPVOID lpVoid = NULL)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_SORTGROUPS, (WPARAM)pCompareFunc, (LPARAM)lpVoid);
+ }
+
+ void InsertGroupSorted(PLVINSERTGROUPSORTED pInsertGroupSorted)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LVM_INSERTGROUPSORTED, (WPARAM)pInsertGroupSorted, 0L);
+ }
+
+ void RemoveAllGroups()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LVM_REMOVEALLGROUPS, 0, 0L);
+ }
+
+ BOOL HasGroup(int nGroupID)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_HASGROUP, nGroupID, 0L);
+ }
+
+ BOOL InsertMarkHitTest(LPPOINT lpPoint, LPLVINSERTMARK pInsertMark) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_INSERTMARKHITTEST, (WPARAM)lpPoint, (LPARAM)pInsertMark);
+ }
+
+ BOOL SetInfoTip(PLVSETINFOTIP pSetInfoTip)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_SETINFOTIP, 0, (LPARAM)pSetInfoTip);
+ }
+
+ void CancelEditLabel()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, LVM_CANCELEDITLABEL, 0, 0L);
+ }
+
+ UINT MapIndexToID(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, LVM_MAPINDEXTOID, nIndex, 0L);
+ }
+
+ int MapIDToIndex(UINT uID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_MAPIDTOINDEX, uID, 0L);
+ }
+
+ BOOL IsItemVisible(int nItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LVM_ISITEMVISIBLE, nItem, 0L);
+ }
+
+#if (_WIN32_WINNT >= 0x0600)
+ int HitTestEx(LPLVHITTESTINFO lpHitTestInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_HITTEST, (WPARAM)-1, (LPARAM)lpHitTestInfo);
+ }
+
+ int HitTestEx(POINT pt, UINT* pFlags) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ LVHITTESTINFO hti = {};
+ hti.pt = pt;
+ int nRes = (int)::SendMessage(this->m_hWnd, LVM_HITTEST, (WPARAM)-1, (LPARAM)&hti);
+ if (pFlags != NULL)
+ *pFlags = hti.flags;
+ return nRes;
+ }
+
+ int SubItemHitTestEx(LPLVHITTESTINFO lpHitTestInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LVM_SUBITEMHITTEST, (WPARAM)-1, (LPARAM)lpHitTestInfo);
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+
+ // Note: selects only one item
+ BOOL SelectItem(int nIndex) // -1 to select none
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+
+ BOOL bRet = FALSE;
+ if(nIndex != -1)
+ {
+ // multi-selection only: de-select all items
+ if((this->GetStyle() & LVS_SINGLESEL) == 0)
+ SetItemState(-1, 0, LVIS_SELECTED);
+
+ bRet = SetItemState(nIndex, LVIS_SELECTED | LVIS_FOCUSED, LVIS_SELECTED | LVIS_FOCUSED);
+ if(bRet)
+ {
+ SetSelectionMark(nIndex);
+ bRet = EnsureVisible(nIndex, FALSE);
+ }
+ }
+ else // no item specified, just de-select
+ {
+ bRet = SetItemState(-1, 0, LVIS_SELECTED);
+ }
+
+ return bRet;
+ }
+
+ // multi-selection only
+ BOOL SelectAllItems(bool bSelect = true)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & LVS_SINGLESEL) == 0);
+
+ return SetItemState(-1, bSelect ? LVIS_SELECTED : 0, LVIS_SELECTED);
+ }
+};
+
+typedef CListViewCtrlT<ATL::CWindow> CListViewCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CTreeViewCtrl
+
+template <class TBase>
+class CTreeViewCtrlT : public TBase
+{
+public:
+// Constructors
+ CTreeViewCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CTreeViewCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return WC_TREEVIEW;
+ }
+
+ UINT GetCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, TVM_GETCOUNT, 0, 0L);
+ }
+
+ UINT GetIndent() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, TVM_GETINDENT, 0, 0L);
+ }
+
+ void SetIndent(UINT nIndent)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TVM_SETINDENT, nIndent, 0L);
+ }
+
+ CImageList GetImageList(int nImageListType = TVSIL_NORMAL) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TVM_GETIMAGELIST, (WPARAM)nImageListType, 0L));
+ }
+
+ CImageList SetImageList(HIMAGELIST hImageList, int nImageListType = TVSIL_NORMAL)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TVM_SETIMAGELIST, (WPARAM)nImageListType, (LPARAM)hImageList));
+ }
+
+ BOOL GetItem(LPTVITEM pItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)pItem);
+ }
+
+ BOOL SetItem(LPTVITEM pItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SETITEM, 0, (LPARAM)pItem);
+ }
+
+ BOOL SetItem(HTREEITEM hItem, UINT nMask, LPCTSTR lpszItem, int nImage,
+ int nSelectedImage, UINT nState, UINT nStateMask, LPARAM lParam)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TVITEM item = {};
+ item.hItem = hItem;
+ item.mask = nMask;
+ item.pszText = (LPTSTR) lpszItem;
+ item.iImage = nImage;
+ item.iSelectedImage = nSelectedImage;
+ item.state = nState;
+ item.stateMask = nStateMask;
+ item.lParam = lParam;
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SETITEM, 0, (LPARAM)&item);
+ }
+
+ BOOL GetItemText(HTREEITEM hItem, LPTSTR lpstrText, int nLen) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(lpstrText != NULL);
+
+ TVITEM item = {};
+ item.hItem = hItem;
+ item.mask = TVIF_TEXT;
+ item.pszText = lpstrText;
+ item.cchTextMax = nLen;
+
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item);
+ }
+
+ BOOL GetItemText(HTREEITEM hItem, BSTR& bstrText) const
+ {
+ USES_CONVERSION;
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(bstrText == NULL);
+ TVITEM item = {};
+ item.hItem = hItem;
+ item.mask = TVIF_TEXT;
+
+ LPTSTR lpstrText = NULL;
+ BOOL bRet = FALSE;
+ for(int nLen = 256; ; nLen *= 2)
+ {
+ ATLTRY(lpstrText = new TCHAR[nLen]);
+ if(lpstrText == NULL)
+ break;
+ lpstrText[0] = NULL;
+ item.pszText = lpstrText;
+ item.cchTextMax = nLen;
+ bRet = (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item);
+ if(!bRet || (lstrlen(item.pszText) < (nLen - 1)))
+ break;
+ delete [] lpstrText;
+ lpstrText = NULL;
+ }
+
+ if(lpstrText != NULL)
+ {
+ if(bRet)
+ bstrText = ::SysAllocString(T2OLE(lpstrText));
+ delete [] lpstrText;
+ }
+
+ return (bstrText != NULL) ? TRUE : FALSE;
+ }
+
+#ifdef __ATLSTR_H__
+ BOOL GetItemText(HTREEITEM hItem, ATL::CString& strText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TVITEM item = {};
+ item.hItem = hItem;
+ item.mask = TVIF_TEXT;
+
+ strText.Empty();
+ BOOL bRet = FALSE;
+ for(int nLen = 256; ; nLen *= 2)
+ {
+ item.pszText = strText.GetBufferSetLength(nLen);
+ if(item.pszText == NULL)
+ {
+ bRet = FALSE;
+ break;
+ }
+ item.cchTextMax = nLen;
+ bRet = (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item);
+ if(!bRet || (lstrlen(item.pszText) < (nLen - 1)))
+ break;
+ }
+ strText.ReleaseBuffer();
+ return bRet;
+ }
+#endif // __ATLSTR_H__
+
+ BOOL SetItemText(HTREEITEM hItem, LPCTSTR lpszItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return SetItem(hItem, TVIF_TEXT, lpszItem, 0, 0, 0, 0, NULL);
+ }
+
+ BOOL GetItemImage(HTREEITEM hItem, int& nImage, int& nSelectedImage) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TVITEM item = {};
+ item.hItem = hItem;
+ item.mask = TVIF_IMAGE|TVIF_SELECTEDIMAGE;
+ BOOL bRes = (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item);
+ if (bRes)
+ {
+ nImage = item.iImage;
+ nSelectedImage = item.iSelectedImage;
+ }
+ return bRes;
+ }
+
+ BOOL SetItemImage(HTREEITEM hItem, int nImage, int nSelectedImage)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return SetItem(hItem, TVIF_IMAGE|TVIF_SELECTEDIMAGE, NULL, nImage, nSelectedImage, 0, 0, NULL);
+ }
+
+ UINT GetItemState(HTREEITEM hItem, UINT nStateMask) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (((UINT)::SendMessage(this->m_hWnd, TVM_GETITEMSTATE, (WPARAM)hItem, (LPARAM)nStateMask)) & nStateMask);
+ }
+
+ BOOL SetItemState(HTREEITEM hItem, UINT nState, UINT nStateMask)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return SetItem(hItem, TVIF_STATE, NULL, 0, 0, nState, nStateMask, NULL);
+ }
+
+ DWORD_PTR GetItemData(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TVITEM item = {};
+ item.hItem = hItem;
+ item.mask = TVIF_PARAM;
+ BOOL bRet = (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item);
+ return (DWORD_PTR)(bRet ? item.lParam : NULL);
+ }
+
+ BOOL SetItemData(HTREEITEM hItem, DWORD_PTR dwData)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return SetItem(hItem, TVIF_PARAM, NULL, 0, 0, 0, 0, (LPARAM)dwData);
+ }
+
+ CEdit GetEditControl() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CEdit((HWND)::SendMessage(this->m_hWnd, TVM_GETEDITCONTROL, 0, 0L));
+ }
+
+ UINT GetVisibleCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, TVM_GETVISIBLECOUNT, 0, 0L);
+ }
+
+ BOOL GetItemRect(HTREEITEM hItem, LPRECT lpRect, BOOL bTextOnly) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ *(HTREEITEM*)lpRect = hItem;
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEMRECT, (WPARAM)bTextOnly, (LPARAM)lpRect);
+ }
+
+ BOOL ItemHasChildren(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TVITEM item = {};
+ item.hItem = hItem;
+ item.mask = TVIF_CHILDREN;
+ ::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)&item);
+ return item.cChildren;
+ }
+
+ CToolTipCtrl GetToolTips() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, TVM_GETTOOLTIPS, 0, 0L));
+ }
+
+ CToolTipCtrl SetToolTips(HWND hWndTT)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, TVM_SETTOOLTIPS, (WPARAM)hWndTT, 0L));
+ }
+
+ int GetISearchString(LPTSTR lpstr) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TVM_GETISEARCHSTRING, 0, (LPARAM)lpstr);
+ }
+
+ // checkboxes only
+ BOOL GetCheckState(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & TVS_CHECKBOXES) != 0);
+ UINT uRet = GetItemState(hItem, TVIS_STATEIMAGEMASK);
+ return (uRet >> 12) - 1;
+ }
+
+ BOOL SetCheckState(HTREEITEM hItem, BOOL bCheck)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & TVS_CHECKBOXES) != 0);
+ int nCheck = bCheck ? 2 : 1; // one based index
+ return SetItemState(hItem, INDEXTOSTATEIMAGEMASK(nCheck), TVIS_STATEIMAGEMASK);
+ }
+
+ // for standard and extended checkboxes (0 = no checkbox, 1 = unchecked, 2 = checked, >2 = optional extended check states)
+ UINT GetCheckStateEx(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(this->GetImageList(TVSIL_STATE) != NULL);
+ UINT uRet = GetItemState(hItem, TVIS_STATEIMAGEMASK);
+ return (uRet >> 12);
+ }
+
+ BOOL SetCheckStateEx(HTREEITEM hItem, UINT uCheckState)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(this->GetImageList(TVSIL_STATE) != NULL);
+ ATLASSERT(uCheckState < (UINT)::ImageList_GetImageCount(this->GetImageList(TVSIL_STATE)));
+ return SetItemState(hItem, INDEXTOSTATEIMAGEMASK(uCheckState), TVIS_STATEIMAGEMASK);
+ }
+
+ COLORREF GetBkColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TVM_GETBKCOLOR, 0, 0L);
+ }
+
+ COLORREF SetBkColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TVM_SETBKCOLOR, 0, (LPARAM)clr);
+ }
+
+ COLORREF GetInsertMarkColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TVM_GETINSERTMARKCOLOR, 0, 0L);
+ }
+
+ COLORREF SetInsertMarkColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TVM_SETINSERTMARKCOLOR, 0, (LPARAM)clr);
+ }
+
+ int GetItemHeight() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TVM_GETITEMHEIGHT, 0, 0L);
+ }
+
+ int SetItemHeight(int cyHeight)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TVM_SETITEMHEIGHT, cyHeight, 0L);
+ }
+
+ int GetScrollTime() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TVM_GETSCROLLTIME, 0, 0L);
+ }
+
+ int SetScrollTime(int nScrollTime)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TVM_SETSCROLLTIME, nScrollTime, 0L);
+ }
+
+ COLORREF GetTextColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TVM_GETTEXTCOLOR, 0, 0L);
+ }
+
+ COLORREF SetTextColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TVM_SETTEXTCOLOR, 0, (LPARAM)clr);
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+ COLORREF GetLineColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TVM_GETLINECOLOR, 0, 0L);
+ }
+
+ COLORREF SetLineColor(COLORREF clrNew /*= CLR_DEFAULT*/)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TVM_SETLINECOLOR, 0, (LPARAM)clrNew);
+ }
+
+ BOOL GetItem(LPTVITEMEX pItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEM, 0, (LPARAM)pItem);
+ }
+
+ BOOL SetItem(LPTVITEMEX pItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SETITEM, 0, (LPARAM)pItem);
+ }
+
+ DWORD GetExtendedStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TVM_GETEXTENDEDSTYLE, 0, 0L);
+ }
+
+ DWORD SetExtendedStyle(DWORD dwStyle, DWORD dwMask)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TVM_SETEXTENDEDSTYLE, dwMask, dwStyle);
+ }
+
+#if (_WIN32_WINNT >= 0x0600)
+ BOOL SetAutoScrollInfo(UINT uPixPerSec, UINT uUpdateTime)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SETAUTOSCROLLINFO, (WPARAM)uPixPerSec, (LPARAM)uUpdateTime);
+ }
+
+ DWORD GetSelectedCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TVM_GETSELECTEDCOUNT, 0, 0L);
+ }
+
+ BOOL GetItemPartRect(HTREEITEM hItem, TVITEMPART partID, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TVGETITEMPARTRECTINFO gipri = { hItem, lpRect, partID };
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_GETITEMPARTRECT, 0, (LPARAM)&gipri);
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+
+// Operations
+ HTREEITEM InsertItem(LPTVINSERTSTRUCT lpInsertStruct)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_INSERTITEM, 0, (LPARAM)lpInsertStruct);
+ }
+
+ HTREEITEM InsertItem(LPCTSTR lpszItem, int nImage,
+ int nSelectedImage, HTREEITEM hParent, HTREEITEM hInsertAfter)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return InsertItem(TVIF_TEXT | TVIF_IMAGE | TVIF_SELECTEDIMAGE, lpszItem, nImage, nSelectedImage, 0, 0, 0, hParent, hInsertAfter);
+ }
+
+ HTREEITEM InsertItem(LPCTSTR lpszItem, HTREEITEM hParent, HTREEITEM hInsertAfter)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return InsertItem(TVIF_TEXT, lpszItem, 0, 0, 0, 0, 0, hParent, hInsertAfter);
+ }
+
+ HTREEITEM InsertItem(UINT nMask, LPCTSTR lpszItem, int nImage,
+ int nSelectedImage, UINT nState, UINT nStateMask, LPARAM lParam,
+ HTREEITEM hParent, HTREEITEM hInsertAfter)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TVINSERTSTRUCT tvis = {};
+ tvis.hParent = hParent;
+ tvis.hInsertAfter = hInsertAfter;
+ tvis.item.mask = nMask;
+ tvis.item.pszText = (LPTSTR) lpszItem;
+ tvis.item.iImage = nImage;
+ tvis.item.iSelectedImage = nSelectedImage;
+ tvis.item.state = nState;
+ tvis.item.stateMask = nStateMask;
+ tvis.item.lParam = lParam;
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_INSERTITEM, 0, (LPARAM)&tvis);
+ }
+
+ BOOL DeleteItem(HTREEITEM hItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_DELETEITEM, 0, (LPARAM)hItem);
+ }
+
+ BOOL DeleteAllItems()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_DELETEITEM, 0, (LPARAM)TVI_ROOT);
+ }
+
+ BOOL Expand(HTREEITEM hItem, UINT nCode = TVE_EXPAND)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_EXPAND, nCode, (LPARAM)hItem);
+ }
+
+ HTREEITEM GetNextItem(HTREEITEM hItem, UINT nCode) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, nCode, (LPARAM)hItem);
+ }
+
+ HTREEITEM GetChildItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_CHILD, (LPARAM)hItem);
+ }
+
+ HTREEITEM GetNextSiblingItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXT, (LPARAM)hItem);
+ }
+
+ HTREEITEM GetPrevSiblingItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PREVIOUS, (LPARAM)hItem);
+ }
+
+ HTREEITEM GetParentItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PARENT, (LPARAM)hItem);
+ }
+
+ HTREEITEM GetFirstVisibleItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_FIRSTVISIBLE, 0L);
+ }
+
+ HTREEITEM GetNextVisibleItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXTVISIBLE, (LPARAM)hItem);
+ }
+
+ HTREEITEM GetPrevVisibleItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PREVIOUSVISIBLE, (LPARAM)hItem);
+ }
+
+ HTREEITEM GetSelectedItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_CARET, 0L);
+ }
+
+ HTREEITEM GetDropHilightItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_DROPHILITE, 0L);
+ }
+
+ HTREEITEM GetRootItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_ROOT, 0L);
+ }
+
+ HTREEITEM GetLastVisibleItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_LASTVISIBLE, 0L);
+ }
+
+ HTREEITEM GetNextSelectedItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXTSELECTED, (LPARAM)hItem);
+ }
+
+ BOOL Select(HTREEITEM hItem, UINT nCode)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SELECTITEM, nCode, (LPARAM)hItem);
+ }
+
+ BOOL SelectItem(HTREEITEM hItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SELECTITEM, TVGN_CARET, (LPARAM)hItem);
+ }
+
+ BOOL SelectDropTarget(HTREEITEM hItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SELECTITEM, TVGN_DROPHILITE, (LPARAM)hItem);
+ }
+
+ BOOL SelectSetFirstVisible(HTREEITEM hItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SELECTITEM, TVGN_FIRSTVISIBLE, (LPARAM)hItem);
+ }
+
+ CEdit EditLabel(HTREEITEM hItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CEdit((HWND)::SendMessage(this->m_hWnd, TVM_EDITLABEL, 0, (LPARAM)hItem));
+ }
+
+ BOOL EndEditLabelNow(BOOL bCancel)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_ENDEDITLABELNOW, bCancel, 0L);
+ }
+
+ HTREEITEM HitTest(TVHITTESTINFO* pHitTestInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_HITTEST, 0, (LPARAM)pHitTestInfo);
+ }
+
+ HTREEITEM HitTest(POINT pt, UINT* pFlags) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TVHITTESTINFO hti = {};
+ hti.pt = pt;
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_HITTEST, 0, (LPARAM)&hti);
+ if (pFlags != NULL)
+ *pFlags = hti.flags;
+ return hTreeItem;
+ }
+
+ BOOL SortChildren(HTREEITEM hItem, BOOL bRecurse = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SORTCHILDREN, (WPARAM)bRecurse, (LPARAM)hItem);
+ }
+
+ BOOL EnsureVisible(HTREEITEM hItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_ENSUREVISIBLE, 0, (LPARAM)hItem);
+ }
+
+ BOOL SortChildrenCB(LPTVSORTCB pSort, BOOL bRecurse = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SORTCHILDRENCB, (WPARAM)bRecurse, (LPARAM)pSort);
+ }
+
+ CImageList RemoveImageList(int nImageList)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TVM_SETIMAGELIST, (WPARAM)nImageList, NULL));
+ }
+
+ CImageList CreateDragImage(HTREEITEM hItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TVM_CREATEDRAGIMAGE, 0, (LPARAM)hItem));
+ }
+
+ BOOL SetInsertMark(HTREEITEM hTreeItem, BOOL bAfter)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SETINSERTMARK, bAfter, (LPARAM)hTreeItem);
+ }
+
+ BOOL RemoveInsertMark()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TVM_SETINSERTMARK, 0, 0L);
+ }
+
+ HTREEITEM MapAccIDToHTREEITEM(UINT uID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HTREEITEM)::SendMessage(this->m_hWnd, TVM_MAPACCIDTOHTREEITEM, uID, 0L);
+ }
+
+ UINT MapHTREEITEMToAccID(HTREEITEM hTreeItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, TVM_MAPHTREEITEMTOACCID, (WPARAM)hTreeItem, 0L);
+ }
+
+#if (_WIN32_WINNT >= 0x0600)
+ void ShowInfoTip(HTREEITEM hItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TVM_SHOWINFOTIP, 0, (LPARAM)hItem);
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+};
+
+typedef CTreeViewCtrlT<ATL::CWindow> CTreeViewCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CTreeViewCtrlEx
+
+// forward declaration
+template <class TBase> class CTreeViewCtrlExT;
+
+// Note: TBase here is for CTreeViewCtrlExT, and not for CTreeItemT itself
+template <class TBase>
+class CTreeItemT
+{
+public:
+ HTREEITEM m_hTreeItem;
+ CTreeViewCtrlExT<TBase>* m_pTreeView;
+
+// Construction
+ CTreeItemT(HTREEITEM hTreeItem = NULL, CTreeViewCtrlExT<TBase>* pTreeView = NULL) : m_hTreeItem(hTreeItem), m_pTreeView(pTreeView)
+ { }
+
+ CTreeItemT(const CTreeItemT<TBase>& posSrc)
+ {
+ *this = posSrc;
+ }
+
+ operator HTREEITEM() { return m_hTreeItem; }
+
+ CTreeItemT<TBase>& operator =(const CTreeItemT<TBase>& itemSrc)
+ {
+ m_hTreeItem = itemSrc.m_hTreeItem;
+ m_pTreeView = itemSrc.m_pTreeView;
+ return *this;
+ }
+
+// Attributes
+ CTreeViewCtrlExT<TBase>* GetTreeView() const { return m_pTreeView; }
+
+ BOOL operator !() const { return m_hTreeItem == NULL; }
+
+ BOOL IsNull() const { return m_hTreeItem == NULL; }
+
+ BOOL GetRect(LPRECT lpRect, BOOL bTextOnly) const;
+ BOOL GetText(LPTSTR lpstrText, int nLen) const;
+ BOOL GetText(BSTR& bstrText) const;
+#ifdef __ATLSTR_H__
+ BOOL GetText(ATL::CString& strText) const;
+#endif // __ATLSTR_H__
+ BOOL SetText(LPCTSTR lpszItem);
+ BOOL GetImage(int& nImage, int& nSelectedImage) const;
+ BOOL SetImage(int nImage, int nSelectedImage);
+ UINT GetState(UINT nStateMask) const;
+ BOOL SetState(UINT nState, UINT nStateMask);
+ DWORD_PTR GetData() const;
+ BOOL SetData(DWORD_PTR dwData);
+ BOOL SetItem(UINT nMask, LPCTSTR lpszItem, int nImage, int nSelectedImage, UINT nState, UINT nStateMask, LPARAM lParam);
+
+// Operations
+ CTreeItemT<TBase> InsertAfter(LPCTSTR lpstrItem, HTREEITEM hItemAfter, int nImageIndex)
+ {
+ return _Insert(lpstrItem, nImageIndex, hItemAfter);
+ }
+
+ CTreeItemT<TBase> AddHead(LPCTSTR lpstrItem, int nImageIndex)
+ {
+ return _Insert(lpstrItem, nImageIndex, TVI_FIRST);
+ }
+
+ CTreeItemT<TBase> AddTail(LPCTSTR lpstrItem, int nImageIndex)
+ {
+ return _Insert(lpstrItem, nImageIndex, TVI_LAST);
+ }
+
+ CTreeItemT<TBase> GetChild() const;
+ CTreeItemT<TBase> GetNext(UINT nCode) const;
+ CTreeItemT<TBase> GetNextSibling() const;
+ CTreeItemT<TBase> GetPrevSibling() const;
+ CTreeItemT<TBase> GetParent() const;
+ CTreeItemT<TBase> GetFirstVisible() const;
+ CTreeItemT<TBase> GetNextVisible() const;
+ CTreeItemT<TBase> GetPrevVisible() const;
+ CTreeItemT<TBase> GetSelected() const;
+ CTreeItemT<TBase> GetDropHilight() const;
+ CTreeItemT<TBase> GetRoot() const;
+ CTreeItemT<TBase> GetLastVisible() const;
+ CTreeItemT<TBase> GetNextSelected() const;
+ BOOL HasChildren() const;
+ BOOL Delete();
+ BOOL Expand(UINT nCode = TVE_EXPAND);
+ BOOL Select(UINT nCode);
+ BOOL Select();
+ BOOL SelectDropTarget();
+ BOOL SelectSetFirstVisible();
+ HWND EditLabel();
+ HIMAGELIST CreateDragImage();
+ BOOL SortChildren(BOOL bRecurse = FALSE);
+ BOOL EnsureVisible();
+ CTreeItemT<TBase> _Insert(LPCTSTR lpstrItem, int nImageIndex, HTREEITEM hItemAfter);
+ int GetImageIndex() const;
+ BOOL SetInsertMark(BOOL bAfter);
+ UINT MapHTREEITEMToAccID() const;
+#if (_WIN32_WINNT >= 0x0600)
+ void ShowInfoTip();
+ BOOL GetPartRect(TVITEMPART partID, LPRECT lpRect) const;
+#endif // (_WIN32_WINNT >= 0x0600)
+};
+
+typedef CTreeItemT<ATL::CWindow> CTreeItem;
+
+
+template <class TBase>
+class CTreeViewCtrlExT : public CTreeViewCtrlT< TBase >
+{
+public:
+// Constructors
+ CTreeViewCtrlExT(HWND hWnd = NULL) : CTreeViewCtrlT< TBase >(hWnd)
+ { }
+
+ CTreeViewCtrlExT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+// Operations (overides that return CTreeItem)
+ CTreeItemT<TBase> InsertItem(LPTVINSERTSTRUCT lpInsertStruct)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_INSERTITEM, 0, (LPARAM)lpInsertStruct);
+ return CTreeItemT<TBase>(hTreeItem, this);
+ }
+
+ CTreeItemT<TBase> InsertItem(LPCTSTR lpszItem, int nImage,
+ int nSelectedImage, HTREEITEM hParent, HTREEITEM hInsertAfter)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return InsertItem(TVIF_TEXT | TVIF_IMAGE | TVIF_SELECTEDIMAGE, lpszItem, nImage, nSelectedImage, 0, 0, 0, hParent, hInsertAfter);
+ }
+
+ CTreeItemT<TBase> InsertItem(LPCTSTR lpszItem, HTREEITEM hParent, HTREEITEM hInsertAfter)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return InsertItem(TVIF_TEXT, lpszItem, 0, 0, 0, 0, 0, hParent, hInsertAfter);
+ }
+
+ CTreeItemT<TBase> GetNextItem(HTREEITEM hItem, UINT nCode) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, nCode, (LPARAM)hItem);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetChildItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_CHILD, (LPARAM)hItem);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetNextSiblingItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXT, (LPARAM)hItem);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetPrevSiblingItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PREVIOUS, (LPARAM)hItem);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetParentItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PARENT, (LPARAM)hItem);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetFirstVisibleItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_FIRSTVISIBLE, 0L);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetNextVisibleItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXTVISIBLE, (LPARAM)hItem);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetPrevVisibleItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_PREVIOUSVISIBLE, (LPARAM)hItem);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetSelectedItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_CARET, 0L);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetDropHilightItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_DROPHILITE, 0L);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetRootItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_ROOT, 0L);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetLastVisibleItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_LASTVISIBLE, 0L);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> GetNextSelectedItem(HTREEITEM hItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_GETNEXTITEM, TVGN_NEXTSELECTED, (LPARAM)hItem);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> HitTest(TVHITTESTINFO* pHitTestInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_HITTEST, 0, (LPARAM)pHitTestInfo);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> InsertItem(UINT nMask, LPCTSTR lpszItem, int nImage,
+ int nSelectedImage, UINT nState, UINT nStateMask, LPARAM lParam,
+ HTREEITEM hParent, HTREEITEM hInsertAfter)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TVINSERTSTRUCT tvis = {};
+ tvis.hParent = hParent;
+ tvis.hInsertAfter = hInsertAfter;
+ tvis.item.mask = nMask;
+ tvis.item.pszText = (LPTSTR) lpszItem;
+ tvis.item.iImage = nImage;
+ tvis.item.iSelectedImage = nSelectedImage;
+ tvis.item.state = nState;
+ tvis.item.stateMask = nStateMask;
+ tvis.item.lParam = lParam;
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_INSERTITEM, 0, (LPARAM)&tvis);
+ return CTreeItemT<TBase>(hTreeItem, this);
+ }
+
+ CTreeItemT<TBase> HitTest(POINT pt, UINT* pFlags) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TVHITTESTINFO hti = {};
+ hti.pt = pt;
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_HITTEST, 0, (LPARAM)&hti);
+ if (pFlags != NULL)
+ *pFlags = hti.flags;
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+
+ CTreeItemT<TBase> MapAccIDToHTREEITEM(UINT uID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ HTREEITEM hTreeItem = (HTREEITEM)::SendMessage(this->m_hWnd, TVM_MAPACCIDTOHTREEITEM, uID, 0L);
+ return CTreeItemT<TBase>(hTreeItem, (CTreeViewCtrlExT<TBase>*)this);
+ }
+};
+
+typedef CTreeViewCtrlExT<ATL::CWindow> CTreeViewCtrlEx;
+
+
+// CTreeItem inline methods
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::GetRect(LPRECT lpRect, BOOL bTextOnly) const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetItemRect(m_hTreeItem,lpRect,bTextOnly);
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetNext(UINT nCode) const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetNextItem(m_hTreeItem,nCode);
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetChild() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetChildItem(m_hTreeItem);
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetNextSibling() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetNextSiblingItem(m_hTreeItem);
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetPrevSibling() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetPrevSiblingItem(m_hTreeItem);
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetParent() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetParentItem(m_hTreeItem);
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetFirstVisible() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetFirstVisibleItem();
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetNextVisible() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetNextVisibleItem(m_hTreeItem);
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetPrevVisible() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetPrevVisibleItem(m_hTreeItem);
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetSelected() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetSelectedItem();
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetDropHilight() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetDropHilightItem();
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetRoot() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetRootItem();
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetLastVisible() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetLastVisibleItem();
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::GetNextSelected() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetNextSelectedItem(m_hTreeItem);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::GetText(LPTSTR lpstrText, int nLen) const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetItemText(m_hTreeItem, lpstrText, nLen);
+}
+
+#ifdef _OLEAUTO_H_
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::GetText(BSTR& bstrText) const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetItemText(m_hTreeItem, bstrText);
+}
+#endif // _OLEAUTO_H_
+
+#ifdef __ATLSTR_H__
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::GetText(ATL::CString& strText) const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetItemText(m_hTreeItem, strText);
+}
+#endif // __ATLSTR_H__
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::GetImage(int& nImage, int& nSelectedImage) const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetItemImage(m_hTreeItem,nImage,nSelectedImage);
+}
+
+template <class TBase>
+inline UINT CTreeItemT<TBase>::GetState(UINT nStateMask) const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetItemState(m_hTreeItem,nStateMask);
+}
+
+template <class TBase>
+inline DWORD_PTR CTreeItemT<TBase>::GetData() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetItemData(m_hTreeItem);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::SetItem(UINT nMask, LPCTSTR lpszItem, int nImage,
+ int nSelectedImage, UINT nState, UINT nStateMask, LPARAM lParam)
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->SetItem(m_hTreeItem, nMask, lpszItem, nImage, nSelectedImage, nState, nStateMask, lParam);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::SetText(LPCTSTR lpszItem)
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->SetItemText(m_hTreeItem,lpszItem);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::SetImage(int nImage, int nSelectedImage)
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->SetItemImage(m_hTreeItem,nImage,nSelectedImage);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::SetState(UINT nState, UINT nStateMask)
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->SetItemState(m_hTreeItem,nState,nStateMask);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::SetData(DWORD_PTR dwData)
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->SetItemData(m_hTreeItem,dwData);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::HasChildren() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->ItemHasChildren(m_hTreeItem);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::Delete()
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->DeleteItem(m_hTreeItem);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::Expand(UINT nCode /*= TVE_EXPAND*/)
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->Expand(m_hTreeItem,nCode);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::Select(UINT nCode)
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->Select(m_hTreeItem,nCode);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::Select()
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->SelectItem(m_hTreeItem);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::SelectDropTarget()
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->SelectDropTarget(m_hTreeItem);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::SelectSetFirstVisible()
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->SelectSetFirstVisible(m_hTreeItem);
+}
+
+template <class TBase>
+inline HWND CTreeItemT<TBase>::EditLabel()
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->EditLabel(m_hTreeItem);
+}
+
+template <class TBase>
+inline HIMAGELIST CTreeItemT<TBase>::CreateDragImage()
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->CreateDragImage(m_hTreeItem);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::SortChildren(BOOL bRecurse /*= FALSE*/)
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->SortChildren(m_hTreeItem, bRecurse);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::EnsureVisible()
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->EnsureVisible(m_hTreeItem);
+}
+
+template <class TBase>
+inline CTreeItemT<TBase> CTreeItemT<TBase>::_Insert(LPCTSTR lpstrItem, int nImageIndex, HTREEITEM hItemAfter)
+{
+ ATLASSERT(m_pTreeView != NULL);
+ TVINSERTSTRUCT ins = {};
+ ins.hParent = m_hTreeItem;
+ ins.hInsertAfter = hItemAfter;
+ ins.item.mask = TVIF_TEXT;
+ ins.item.pszText = (LPTSTR)lpstrItem;
+ if(nImageIndex != -1)
+ {
+ ins.item.mask |= TVIF_IMAGE | TVIF_SELECTEDIMAGE;
+ ins.item.iImage = nImageIndex;
+ ins.item.iSelectedImage = nImageIndex;
+ }
+ return CTreeItemT<TBase>(m_pTreeView->InsertItem(&ins), m_pTreeView);
+}
+
+template <class TBase>
+inline int CTreeItemT<TBase>::GetImageIndex() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ TVITEM item = {};
+ item.mask = TVIF_HANDLE | TVIF_IMAGE;
+ item.hItem = m_hTreeItem;
+ m_pTreeView->GetItem(&item);
+ return item.iImage;
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::SetInsertMark(BOOL bAfter)
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->SetInsertMark(m_hTreeItem, bAfter);
+}
+
+template <class TBase>
+inline UINT CTreeItemT<TBase>::MapHTREEITEMToAccID() const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->MapHTREEITEMToAccID(m_hTreeItem);
+}
+
+#if (_WIN32_WINNT >= 0x0600)
+template <class TBase>
+inline void CTreeItemT<TBase>::ShowInfoTip()
+{
+ ATLASSERT(m_pTreeView != NULL);
+ m_pTreeView->ShowInfoTip(m_hTreeItem);
+}
+
+template <class TBase>
+inline BOOL CTreeItemT<TBase>::GetPartRect(TVITEMPART partID, LPRECT lpRect) const
+{
+ ATLASSERT(m_pTreeView != NULL);
+ return m_pTreeView->GetItemPartRect(m_hTreeItem, partID, lpRect);
+}
+#endif // (_WIN32_WINNT >= 0x0600)
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CToolBarCtrl
+
+template <class TBase>
+class CToolBarCtrlT : public TBase
+{
+public:
+// Construction
+ CToolBarCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CToolBarCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return TOOLBARCLASSNAME;
+ }
+
+ BOOL IsButtonEnabled(int nID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_ISBUTTONENABLED, nID, 0L);
+ }
+
+ BOOL IsButtonChecked(int nID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_ISBUTTONCHECKED, nID, 0L);
+ }
+
+ BOOL IsButtonPressed(int nID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_ISBUTTONPRESSED, nID, 0L);
+ }
+
+ BOOL IsButtonHidden(int nID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return(BOOL) ::SendMessage(this->m_hWnd, TB_ISBUTTONHIDDEN, nID, 0L);
+ }
+
+ BOOL IsButtonIndeterminate(int nID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_ISBUTTONINDETERMINATE, nID, 0L);
+ }
+
+ int GetState(int nID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_GETSTATE, nID, 0L);
+ }
+
+ BOOL SetState(int nID, UINT nState)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETSTATE, nID, MAKELPARAM(nState, 0));
+ }
+
+ BOOL GetButton(int nIndex, LPTBBUTTON lpButton) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_GETBUTTON, nIndex, (LPARAM)lpButton);
+ }
+
+ int GetButtonCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_BUTTONCOUNT, 0, 0L);
+ }
+
+ BOOL GetItemRect(int nIndex, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_GETITEMRECT, nIndex, (LPARAM)lpRect);
+ }
+
+ void SetButtonStructSize(int nSize = sizeof(TBBUTTON))
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_BUTTONSTRUCTSIZE, nSize, 0L);
+ }
+
+ BOOL SetButtonSize(SIZE size)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETBUTTONSIZE, 0, MAKELPARAM(size.cx, size.cy));
+ }
+
+ BOOL SetButtonSize(int cx, int cy)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETBUTTONSIZE, 0, MAKELPARAM(cx, cy));
+ }
+
+ BOOL SetBitmapSize(SIZE size)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETBITMAPSIZE, 0, MAKELPARAM(size.cx, size.cy));
+ }
+
+ BOOL SetBitmapSize(int cx, int cy)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETBITMAPSIZE, 0, MAKELPARAM(cx, cy));
+ }
+
+ CToolTipCtrl GetToolTips() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, TB_GETTOOLTIPS, 0, 0L));
+ }
+
+ void SetToolTips(HWND hWndToolTip)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_SETTOOLTIPS, (WPARAM)hWndToolTip, 0L);
+ }
+
+ void SetNotifyWnd(HWND hWnd)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_SETPARENT, (WPARAM)hWnd, 0L);
+ }
+
+ int GetRows() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_GETROWS, 0, 0L);
+ }
+
+ void SetRows(int nRows, BOOL bLarger, LPRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_SETROWS, MAKELPARAM(nRows, bLarger), (LPARAM)lpRect);
+ }
+
+ BOOL SetCmdID(int nIndex, UINT nID)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETCMDID, nIndex, nID);
+ }
+
+ DWORD GetBitmapFlags() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TB_GETBITMAPFLAGS, 0, 0L);
+ }
+
+ int GetBitmap(int nID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_GETBITMAP, nID, 0L);
+ }
+
+ int GetButtonText(int nID, LPTSTR lpstrText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_GETBUTTONTEXT, nID, (LPARAM)lpstrText);
+ }
+
+ // nIndex - IE5 or higher only
+ CImageList GetImageList(int nIndex = 0) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_GETIMAGELIST, nIndex, 0L));
+ }
+
+ // nIndex - IE5 or higher only
+ CImageList SetImageList(HIMAGELIST hImageList, int nIndex = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_SETIMAGELIST, nIndex, (LPARAM)hImageList));
+ }
+
+ // nIndex - IE5 or higher only
+ CImageList GetDisabledImageList(int nIndex = 0) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_GETDISABLEDIMAGELIST, nIndex, 0L));
+ }
+
+ // nIndex - IE5 or higher only
+ CImageList SetDisabledImageList(HIMAGELIST hImageList, int nIndex = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_SETDISABLEDIMAGELIST, nIndex, (LPARAM)hImageList));
+ }
+
+ // nIndex - IE5 or higher only
+ CImageList GetHotImageList(int nIndex = 0) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_GETHOTIMAGELIST, nIndex, 0L));
+ }
+
+ // nIndex - IE5 or higher only
+ CImageList SetHotImageList(HIMAGELIST hImageList, int nIndex = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_SETHOTIMAGELIST, nIndex, (LPARAM)hImageList));
+ }
+
+ DWORD GetStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TB_GETSTYLE, 0, 0L);
+ }
+
+ void SetStyle(DWORD dwStyle)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_SETSTYLE, 0, dwStyle);
+ }
+
+ DWORD GetButtonSize() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TB_GETBUTTONSIZE, 0, 0L);
+ }
+
+ void GetButtonSize(SIZE& size) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, TB_GETBUTTONSIZE, 0, 0L);
+ size.cx = LOWORD(dwRet);
+ size.cy = HIWORD(dwRet);
+ }
+
+ BOOL GetRect(int nID, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_GETRECT, nID, (LPARAM)lpRect);
+ }
+
+ int GetTextRows() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_GETTEXTROWS, 0, 0L);
+ }
+
+ BOOL SetButtonWidth(int cxMin, int cxMax)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETBUTTONWIDTH, 0, MAKELPARAM(cxMin, cxMax));
+ }
+
+ BOOL SetIndent(int nIndent)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETINDENT, nIndent, 0L);
+ }
+
+ BOOL SetMaxTextRows(int nMaxTextRows)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETMAXTEXTROWS, nMaxTextRows, 0L);
+ }
+
+ BOOL GetAnchorHighlight() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_GETANCHORHIGHLIGHT, 0, 0L);
+ }
+
+ BOOL SetAnchorHighlight(BOOL bEnable = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETANCHORHIGHLIGHT, bEnable, 0L);
+ }
+
+ int GetButtonInfo(int nID, LPTBBUTTONINFO lptbbi) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_GETBUTTONINFO, nID, (LPARAM)lptbbi);
+ }
+
+ BOOL SetButtonInfo(int nID, LPTBBUTTONINFO lptbbi)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETBUTTONINFO, nID, (LPARAM)lptbbi);
+ }
+
+ BOOL SetButtonInfo(int nID, DWORD dwMask, BYTE Style, BYTE State, LPCTSTR lpszItem,
+ int iImage, WORD cx, int iCommand, DWORD_PTR lParam)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TBBUTTONINFO tbbi = {};
+ tbbi.cbSize = sizeof(TBBUTTONINFO);
+ tbbi.dwMask = dwMask;
+ tbbi.idCommand = iCommand;
+ tbbi.iImage = iImage;
+ tbbi.fsState = State;
+ tbbi.fsStyle = Style;
+ tbbi.cx = cx;
+ tbbi.pszText = (LPTSTR) lpszItem;
+ tbbi.lParam = lParam;
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETBUTTONINFO, nID, (LPARAM)&tbbi);
+ }
+
+ int GetHotItem() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_GETHOTITEM, 0, 0L);
+ }
+
+ int SetHotItem(int nItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_SETHOTITEM, nItem, 0L);
+ }
+
+ BOOL IsButtonHighlighted(int nButtonID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_ISBUTTONHIGHLIGHTED, nButtonID, 0L);
+ }
+
+ DWORD SetDrawTextFlags(DWORD dwMask, DWORD dwFlags)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TB_SETDRAWTEXTFLAGS, dwMask, dwFlags);
+ }
+
+ BOOL GetColorScheme(LPCOLORSCHEME lpcs) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_GETCOLORSCHEME, 0, (LPARAM)lpcs);
+ }
+
+ void SetColorScheme(LPCOLORSCHEME lpcs)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_SETCOLORSCHEME, 0, (LPARAM)lpcs);
+ }
+
+ DWORD GetExtendedStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TB_GETEXTENDEDSTYLE, 0, 0L);
+ }
+
+ DWORD SetExtendedStyle(DWORD dwStyle)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TB_SETEXTENDEDSTYLE, 0, dwStyle);
+ }
+
+ void GetInsertMark(LPTBINSERTMARK lptbim) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_GETINSERTMARK, 0, (LPARAM)lptbim);
+ }
+
+ void SetInsertMark(LPTBINSERTMARK lptbim)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_SETINSERTMARK, 0, (LPARAM)lptbim);
+ }
+
+ COLORREF GetInsertMarkColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TB_GETINSERTMARKCOLOR, 0, 0L);
+ }
+
+ COLORREF SetInsertMarkColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, TB_SETINSERTMARKCOLOR, 0, (LPARAM)clr);
+ }
+
+ BOOL GetMaxSize(LPSIZE lpSize) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_GETMAXSIZE, 0, (LPARAM)lpSize);
+ }
+
+ void GetPadding(LPSIZE lpSizePadding) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(lpSizePadding != NULL);
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, TB_GETPADDING, 0, 0L);
+ lpSizePadding->cx = GET_X_LPARAM(dwRet);
+ lpSizePadding->cy = GET_Y_LPARAM(dwRet);
+ }
+
+ void SetPadding(int cx, int cy, LPSIZE lpSizePadding = NULL)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, TB_SETPADDING, 0, MAKELPARAM(cx, cy));
+ if(lpSizePadding != NULL)
+ {
+ lpSizePadding->cx = GET_X_LPARAM(dwRet);
+ lpSizePadding->cy = GET_Y_LPARAM(dwRet);
+ }
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+ int GetString(int nString, LPTSTR lpstrString, int cchMaxLen) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_GETSTRING, MAKEWPARAM(cchMaxLen, nString), (LPARAM)lpstrString);
+ }
+
+ int GetStringBSTR(int nString, BSTR& bstrString) const
+ {
+ USES_CONVERSION;
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(bstrString == NULL);
+ int nLength = (int)(short)LOWORD(::SendMessage(this->m_hWnd, TB_GETSTRING, MAKEWPARAM(0, nString), NULL));
+ if(nLength != -1)
+ {
+ ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff;
+ LPTSTR lpstrText = buff.Allocate(nLength + 1);
+ if(lpstrText != NULL)
+ {
+ nLength = (int)::SendMessage(this->m_hWnd, TB_GETSTRING, MAKEWPARAM(nLength + 1, nString), (LPARAM)lpstrText);
+ if(nLength != -1)
+ bstrString = ::SysAllocString(T2OLE(lpstrText));
+ }
+ else
+ {
+ nLength = -1;
+ }
+ }
+
+ return nLength;
+ }
+
+#ifdef __ATLSTR_H__
+ int GetString(int nString, ATL::CString& str) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ int nLength = (int)(short)LOWORD(::SendMessage(this->m_hWnd, TB_GETSTRING, MAKEWPARAM(0, nString), NULL));
+ if(nLength != -1)
+ {
+ LPTSTR lpstr = str.GetBufferSetLength(nLength + 1);
+ if(lpstr != NULL)
+ nLength = (int)::SendMessage(this->m_hWnd, TB_GETSTRING, MAKEWPARAM(nLength + 1, nString), (LPARAM)lpstr);
+ else
+ nLength = -1;
+ str.ReleaseBuffer();
+ }
+ return nLength;
+ }
+#endif // __ATLSTR_H__
+
+ void GetMetrics(LPTBMETRICS lptbm) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_GETMETRICS, 0, (LPARAM)lptbm);
+ }
+
+ void SetMetrics(LPTBMETRICS lptbm)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_SETMETRICS, 0, (LPARAM)lptbm);
+ }
+
+ void SetWindowTheme(LPCWSTR lpstrTheme)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_SETWINDOWTHEME, 0, (LPARAM)lpstrTheme);
+ }
+
+#if (_WIN32_WINNT >= 0x0600)
+ CImageList GetPressedImageList(int nIndex = 0) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_GETPRESSEDIMAGELIST, nIndex, 0L));
+ }
+
+ CImageList SetPressedImageList(HIMAGELIST hImageList, int nIndex = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TB_SETPRESSEDIMAGELIST, nIndex, (LPARAM)hImageList));
+ }
+
+ void GetItemDropDownRect(int nIndex, LPRECT lpRect) const
+ {
+#ifndef TB_GETITEMDROPDOWNRECT
+ const int TB_GETITEMDROPDOWNRECT = WM_USER + 103;
+#endif
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ BOOL bRet = (BOOL)::SendMessage(this->m_hWnd, TB_GETITEMDROPDOWNRECT, nIndex, (LPARAM)lpRect);
+ (void)bRet; // avoid level 4 warning
+ ATLASSERT(bRet != FALSE);
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+
+// Operations
+ BOOL EnableButton(int nID, BOOL bEnable = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_ENABLEBUTTON, nID, MAKELPARAM(bEnable, 0));
+ }
+
+ BOOL CheckButton(int nID, BOOL bCheck = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_CHECKBUTTON, nID, MAKELPARAM(bCheck, 0));
+ }
+
+ BOOL PressButton(int nID, BOOL bPress = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_PRESSBUTTON, nID, MAKELPARAM(bPress, 0));
+ }
+
+ BOOL HideButton(int nID, BOOL bHide = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_HIDEBUTTON, nID, MAKELPARAM(bHide, 0));
+ }
+
+ BOOL Indeterminate(int nID, BOOL bIndeterminate = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_INDETERMINATE, nID, MAKELPARAM(bIndeterminate, 0));
+ }
+
+ int AddBitmap(int nNumButtons, UINT nBitmapID)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TBADDBITMAP tbab = {};
+ tbab.hInst = ModuleHelper::GetResourceInstance();
+ ATLASSERT(tbab.hInst != NULL);
+ tbab.nID = nBitmapID;
+ return (int)::SendMessage(this->m_hWnd, TB_ADDBITMAP, (WPARAM)nNumButtons, (LPARAM)&tbab);
+ }
+
+ int AddBitmap(int nNumButtons, HBITMAP hBitmap)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TBADDBITMAP tbab = {};
+ tbab.hInst = NULL;
+ tbab.nID = (UINT_PTR)hBitmap;
+ return (int)::SendMessage(this->m_hWnd, TB_ADDBITMAP, (WPARAM)nNumButtons, (LPARAM)&tbab);
+ }
+
+ BOOL AddButtons(int nNumButtons, LPCTBBUTTON lpButtons)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_ADDBUTTONS, nNumButtons, (LPARAM)lpButtons);
+ }
+
+ BOOL InsertButton(int nIndex, LPCTBBUTTON lpButton)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_INSERTBUTTON, nIndex, (LPARAM)lpButton);
+ }
+
+ BOOL InsertButton(int nIndex, int iCommand, BYTE Style, BYTE State, int iBitmap,
+ INT_PTR iString, DWORD_PTR lParam)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TBBUTTON tbb = {};
+ tbb.fsStyle = Style;
+ tbb.fsState = State;
+ tbb.idCommand = iCommand;
+ tbb.iBitmap = iBitmap;
+ tbb.iString = iString;
+ tbb.dwData = lParam;
+ return (BOOL)::SendMessage(this->m_hWnd, TB_INSERTBUTTON, nIndex, (LPARAM)&tbb);
+ }
+
+ BOOL InsertButton(int nIndex, int iCommand, BYTE Style, BYTE State, int iBitmap,
+ LPCTSTR lpszItem, DWORD_PTR lParam)
+ {
+ return InsertButton(nIndex, iCommand, Style, State, iBitmap, (INT_PTR)lpszItem, lParam);
+ }
+
+ BOOL AddButton(LPTBBUTTON lpButton)
+ {
+ return InsertButton(-1, lpButton);
+ }
+
+ BOOL AddButton(int iCommand, BYTE Style, BYTE State, int iBitmap, INT_PTR iString, DWORD_PTR lParam)
+ {
+ return InsertButton(-1, iCommand, Style, State, iBitmap, iString, lParam);
+ }
+
+ BOOL AddButton(int iCommand, BYTE Style, BYTE State, int iBitmap, LPCTSTR lpszItem, DWORD_PTR lParam)
+ {
+ return InsertButton(-1, iCommand, Style, State, iBitmap, lpszItem, lParam);
+ }
+
+ BOOL DeleteButton(int nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_DELETEBUTTON, nIndex, 0L);
+ }
+
+ BOOL InsertSeparator(int nIndex, int cxWidth = 8)
+ {
+ return InsertButton(nIndex, 0, BTNS_SEP, 0, cxWidth, (INT_PTR)0, 0);
+ }
+
+ BOOL AddSeparator(int cxWidth = 8)
+ {
+ return AddButton(0, BTNS_SEP, 0, cxWidth, (INT_PTR)0, 0);
+ }
+
+ int CommandToIndex(UINT nID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_COMMANDTOINDEX, nID, 0L);
+ }
+
+ void SaveState(HKEY hKeyRoot, LPCTSTR lpszSubKey, LPCTSTR lpszValueName)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TBSAVEPARAMS tbs = {};
+ tbs.hkr = hKeyRoot;
+ tbs.pszSubKey = lpszSubKey;
+ tbs.pszValueName = lpszValueName;
+ ::SendMessage(this->m_hWnd, TB_SAVERESTORE, (WPARAM)TRUE, (LPARAM)&tbs);
+ }
+
+ void RestoreState(HKEY hKeyRoot, LPCTSTR lpszSubKey, LPCTSTR lpszValueName)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TBSAVEPARAMS tbs = {};
+ tbs.hkr = hKeyRoot;
+ tbs.pszSubKey = lpszSubKey;
+ tbs.pszValueName = lpszValueName;
+ ::SendMessage(this->m_hWnd, TB_SAVERESTORE, (WPARAM)FALSE, (LPARAM)&tbs);
+ }
+
+ void Customize()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_CUSTOMIZE, 0, 0L);
+ }
+
+ int AddString(UINT nStringID)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_ADDSTRING, (WPARAM)ModuleHelper::GetResourceInstance(), (LPARAM)nStringID);
+ }
+
+ int AddStrings(LPCTSTR lpszStrings)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_ADDSTRING, 0, (LPARAM)lpszStrings);
+ }
+
+ void AutoSize()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TB_AUTOSIZE, 0, 0L);
+ }
+
+ BOOL ChangeBitmap(int nID, int nBitmap)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_CHANGEBITMAP, nID, MAKELPARAM(nBitmap, 0));
+ }
+
+ int LoadImages(int nBitmapID)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_LOADIMAGES, nBitmapID, (LPARAM)ModuleHelper::GetResourceInstance());
+ }
+
+ int LoadStdImages(int nBitmapID)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_LOADIMAGES, nBitmapID, (LPARAM)HINST_COMMCTRL);
+ }
+
+ BOOL ReplaceBitmap(LPTBREPLACEBITMAP ptbrb)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_REPLACEBITMAP, 0, (LPARAM)ptbrb);
+ }
+
+ int HitTest(LPPOINT lpPoint) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TB_HITTEST, 0, (LPARAM)lpPoint);
+ }
+
+ BOOL InsertMarkHitTest(LPPOINT lpPoint, LPTBINSERTMARK lptbim) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_INSERTMARKHITTEST, (WPARAM)lpPoint, (LPARAM)lptbim);
+ }
+
+ BOOL InsertMarkHitTest(int x, int y, LPTBINSERTMARK lptbim) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ POINT pt = { x, y };
+ return (BOOL)::SendMessage(this->m_hWnd, TB_INSERTMARKHITTEST, (WPARAM)&pt, (LPARAM)lptbim);
+ }
+
+ BOOL MapAccelerator(TCHAR chAccel, int& nID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_MAPACCELERATOR, (WPARAM)chAccel, (LPARAM)&nID);
+ }
+
+ BOOL MarkButton(int nID, BOOL bHighlight = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_MARKBUTTON, nID, MAKELPARAM(bHighlight, 0));
+ }
+
+ BOOL MoveButton(int nOldPos, int nNewPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TB_MOVEBUTTON, nOldPos, nNewPos);
+ }
+
+ HRESULT GetObject(REFIID iid, LPVOID* ppvObject)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HRESULT)::SendMessage(this->m_hWnd, TB_GETOBJECT, (WPARAM)&iid, (LPARAM)ppvObject);
+ }
+};
+
+typedef CToolBarCtrlT<ATL::CWindow> CToolBarCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CStatusBarCtrl
+
+template <class TBase>
+class CStatusBarCtrlT : public TBase
+{
+public:
+// Constructors
+ CStatusBarCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CStatusBarCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Methods
+ static LPCTSTR GetWndClassName()
+ {
+ return STATUSCLASSNAME;
+ }
+
+ int GetParts(int nParts, int* pParts) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, SB_GETPARTS, nParts, (LPARAM)pParts);
+ }
+
+ BOOL SetParts(int nParts, int* pWidths)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, SB_SETPARTS, nParts, (LPARAM)pWidths);
+ }
+
+ int GetTextLength(int nPane, int* pType = NULL) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(nPane < 256);
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, SB_GETTEXTLENGTH, (WPARAM)nPane, 0L);
+ if (pType != NULL)
+ *pType = (int)(short)HIWORD(dwRet);
+ return (int)(short)LOWORD(dwRet);
+ }
+
+ int GetText(int nPane, LPTSTR lpszText, int* pType = NULL) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(nPane < 256);
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, SB_GETTEXT, (WPARAM)nPane, (LPARAM)lpszText);
+ if(pType != NULL)
+ *pType = (int)(short)HIWORD(dwRet);
+ return (int)(short)LOWORD(dwRet);
+ }
+
+ BOOL GetTextBSTR(int nPane, BSTR& bstrText, int* pType = NULL) const
+ {
+ USES_CONVERSION;
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(nPane < 256);
+ ATLASSERT(bstrText == NULL);
+ int nLength = (int)(short)LOWORD(::SendMessage(this->m_hWnd, SB_GETTEXTLENGTH, (WPARAM)nPane, 0L));
+ if(nLength == 0)
+ return FALSE;
+
+ ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff;
+ LPTSTR lpstrText = buff.Allocate(nLength + 1);
+ if(lpstrText == NULL)
+ return FALSE;
+
+ if(!GetText(nPane, lpstrText, pType))
+ return FALSE;
+
+ bstrText = ::SysAllocString(T2OLE(lpstrText));
+ return (bstrText != NULL) ? TRUE : FALSE;
+ }
+
+#ifdef __ATLSTR_H__
+ int GetText(int nPane, ATL::CString& strText, int* pType = NULL) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(nPane < 256);
+ int nLength = (int)(short)LOWORD(::SendMessage(this->m_hWnd, SB_GETTEXTLENGTH, (WPARAM)nPane, 0L));
+ if(nLength == 0)
+ return 0;
+
+ LPTSTR lpstr = strText.GetBufferSetLength(nLength);
+ if(lpstr == NULL)
+ return 0;
+ return GetText(nPane, lpstr, pType);
+ }
+#endif // __ATLSTR_H__
+
+ BOOL SetText(int nPane, LPCTSTR lpszText, int nType = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(nPane < 256);
+ return (BOOL)::SendMessage(this->m_hWnd, SB_SETTEXT, (nPane | nType), (LPARAM)lpszText);
+ }
+
+ BOOL GetRect(int nPane, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(nPane < 256);
+ return (BOOL)::SendMessage(this->m_hWnd, SB_GETRECT, nPane, (LPARAM)lpRect);
+ }
+
+ BOOL GetBorders(int* pBorders) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, SB_GETBORDERS, 0, (LPARAM)pBorders);
+ }
+
+ BOOL GetBorders(int& nHorz, int& nVert, int& nSpacing) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ int borders[3] = {};
+ BOOL bResult = (BOOL)::SendMessage(this->m_hWnd, SB_GETBORDERS, 0, (LPARAM)&borders);
+ if(bResult)
+ {
+ nHorz = borders[0];
+ nVert = borders[1];
+ nSpacing = borders[2];
+ }
+ return bResult;
+ }
+
+ void SetMinHeight(int nMin)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, SB_SETMINHEIGHT, nMin, 0L);
+ }
+
+ BOOL SetSimple(BOOL bSimple = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, SB_SIMPLE, bSimple, 0L);
+ }
+
+ BOOL IsSimple() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, SB_ISSIMPLE, 0, 0L);
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, SB_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, SB_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+ void GetTipText(int nPane, LPTSTR lpstrText, int nSize) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(nPane < 256);
+ ::SendMessage(this->m_hWnd, SB_GETTIPTEXT, MAKEWPARAM(nPane, nSize), (LPARAM)lpstrText);
+ }
+
+ void SetTipText(int nPane, LPCTSTR lpstrText)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(nPane < 256);
+ ::SendMessage(this->m_hWnd, SB_SETTIPTEXT, nPane, (LPARAM)lpstrText);
+ }
+
+ COLORREF SetBkColor(COLORREF clrBk)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, SB_SETBKCOLOR, 0, (LPARAM)clrBk);
+ }
+
+ HICON GetIcon(int nPane) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(nPane < 256);
+ return (HICON)::SendMessage(this->m_hWnd, SB_GETICON, nPane, 0L);
+ }
+
+ BOOL SetIcon(int nPane, HICON hIcon)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(nPane < 256);
+ return (BOOL)::SendMessage(this->m_hWnd, SB_SETICON, nPane, (LPARAM)hIcon);
+ }
+};
+
+typedef CStatusBarCtrlT<ATL::CWindow> CStatusBarCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CTabCtrl
+
+template <class TBase>
+class CTabCtrlT : public TBase
+{
+public:
+// Constructors
+ CTabCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CTabCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return WC_TABCONTROL;
+ }
+
+ CImageList GetImageList() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TCM_GETIMAGELIST, 0, 0L));
+ }
+
+ CImageList SetImageList(HIMAGELIST hImageList)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, TCM_SETIMAGELIST, 0, (LPARAM)hImageList));
+ }
+
+ int GetItemCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TCM_GETITEMCOUNT, 0, 0L);
+ }
+
+ BOOL GetItem(int nItem, LPTCITEM pTabCtrlItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TCM_GETITEM, nItem, (LPARAM)pTabCtrlItem);
+ }
+
+ BOOL SetItem(int nItem, LPTCITEM pTabCtrlItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TCM_SETITEM, nItem, (LPARAM)pTabCtrlItem);
+ }
+
+ int SetItem(int nItem, UINT mask, LPCTSTR lpszItem, DWORD dwState, DWORD dwStateMask, int iImage, LPARAM lParam)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TCITEM tci = {};
+ tci.mask = mask;
+ tci.pszText = (LPTSTR) lpszItem;
+ tci.dwState = dwState;
+ tci.dwStateMask = dwStateMask;
+ tci.iImage = iImage;
+ tci.lParam = lParam;
+ return (int)::SendMessage(this->m_hWnd, TCM_SETITEM, nItem, (LPARAM)&tci);
+ }
+
+ BOOL GetItemRect(int nItem, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TCM_GETITEMRECT, nItem, (LPARAM)lpRect);
+ }
+
+ int GetCurSel() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TCM_GETCURSEL, 0, 0L);
+ }
+
+ int SetCurSel(int nItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TCM_SETCURSEL, nItem, 0L);
+ }
+
+ SIZE SetItemSize(SIZE size)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwSize = (DWORD)::SendMessage(this->m_hWnd, TCM_SETITEMSIZE, 0, MAKELPARAM(size.cx, size.cy));
+ SIZE sizeRet = { GET_X_LPARAM(dwSize), GET_Y_LPARAM(dwSize) };
+ return sizeRet;
+ }
+
+ void SetItemSize(int cx, int cy)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TCM_SETITEMSIZE, 0, MAKELPARAM(cx, cy));
+ }
+
+ void SetPadding(SIZE size)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TCM_SETPADDING, 0, MAKELPARAM(size.cx, size.cy));
+ }
+
+ int GetRowCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TCM_GETROWCOUNT, 0, 0L);
+ }
+
+ CToolTipCtrl GetToolTips() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, TCM_GETTOOLTIPS, 0, 0L));
+ }
+
+ // this method is deprecated, please use GetToolTips
+ CToolTipCtrl GetTooltips() const { return GetToolTips(); }
+
+ void SetToolTips(HWND hWndToolTip)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TCM_SETTOOLTIPS, (WPARAM)hWndToolTip, 0L);
+ }
+
+ // this method is deprecated, please use SetToolTips
+ void SetTooltips(HWND hWndToolTip) { SetToolTips(hWndToolTip); }
+
+ int GetCurFocus() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TCM_GETCURFOCUS, 0, 0L);
+ }
+
+ void SetCurFocus(int nItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TCM_SETCURFOCUS, nItem, 0L);
+ }
+
+ BOOL SetItemExtra(int cbExtra)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(GetItemCount() == 0); // must be empty
+ return (BOOL)::SendMessage(this->m_hWnd, TCM_SETITEMEXTRA, cbExtra, 0L);
+ }
+
+ int SetMinTabWidth(int nWidth = -1)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TCM_SETMINTABWIDTH, 0, nWidth);
+ }
+
+ DWORD GetExtendedStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TCM_GETEXTENDEDSTYLE, 0, 0L);
+ }
+
+ DWORD SetExtendedStyle(DWORD dwExMask, DWORD dwExStyle)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, TCM_SETEXTENDEDSTYLE, dwExMask, dwExStyle);
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TCM_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TCM_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+// Operations
+ int InsertItem(int nItem, LPTCITEM pTabCtrlItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TCM_INSERTITEM, nItem, (LPARAM)pTabCtrlItem);
+ }
+
+ int InsertItem(int nItem, UINT mask, LPCTSTR lpszItem, int iImage, LPARAM lParam)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TCITEM tci = {};
+ tci.mask = mask;
+ tci.pszText = (LPTSTR) lpszItem;
+ tci.iImage = iImage;
+ tci.lParam = lParam;
+ return (int)::SendMessage(this->m_hWnd, TCM_INSERTITEM, nItem, (LPARAM)&tci);
+ }
+
+ int InsertItem(int nItem, LPCTSTR lpszItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TCITEM tci = {};
+ tci.mask = TCIF_TEXT;
+ tci.pszText = (LPTSTR) lpszItem;
+ return (int)::SendMessage(this->m_hWnd, TCM_INSERTITEM, nItem, (LPARAM)&tci);
+ }
+
+ int AddItem(LPTCITEM pTabCtrlItem)
+ {
+ return InsertItem(GetItemCount(), pTabCtrlItem);
+ }
+
+ int AddItem(UINT mask, LPCTSTR lpszItem, int iImage, LPARAM lParam)
+ {
+ return InsertItem(GetItemCount(), mask, lpszItem, iImage, lParam);
+ }
+
+ int AddItem(LPCTSTR lpszItem)
+ {
+ return InsertItem(GetItemCount(), lpszItem);
+ }
+
+ BOOL DeleteItem(int nItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TCM_DELETEITEM, nItem, 0L);
+ }
+
+ BOOL DeleteAllItems()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TCM_DELETEALLITEMS, 0, 0L);
+ }
+
+ void AdjustRect(BOOL bLarger, LPRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TCM_ADJUSTRECT, bLarger, (LPARAM)lpRect);
+ }
+
+ void RemoveImage(int nImage)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TCM_REMOVEIMAGE, nImage, 0L);
+ }
+
+ int HitTest(TC_HITTESTINFO* pHitTestInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TCM_HITTEST, 0, (LPARAM)pHitTestInfo);
+ }
+
+ void DeselectAll(BOOL bExcludeFocus = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TCM_DESELECTALL, bExcludeFocus, 0L);
+ }
+
+ BOOL HighlightItem(int nIndex, BOOL bHighlight = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TCM_HIGHLIGHTITEM, nIndex, MAKELPARAM(bHighlight, 0));
+ }
+};
+
+typedef CTabCtrlT<ATL::CWindow> CTabCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CTrackBarCtrl
+
+template <class TBase>
+class CTrackBarCtrlT : public TBase
+{
+public:
+// Constructors
+ CTrackBarCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CTrackBarCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return TRACKBAR_CLASS;
+ }
+
+ int GetLineSize() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_GETLINESIZE, 0, 0L);
+ }
+
+ int SetLineSize(int nSize)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_SETLINESIZE, 0, nSize);
+ }
+
+ int GetPageSize() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_GETPAGESIZE, 0, 0L);
+ }
+
+ int SetPageSize(int nSize)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_SETPAGESIZE, 0, nSize);
+ }
+
+ int GetRangeMin() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_GETRANGEMIN, 0, 0L);
+ }
+
+ void SetRangeMin(int nMin, BOOL bRedraw = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_SETRANGEMIN, bRedraw, nMin);
+ }
+
+ int GetRangeMax() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_GETRANGEMAX, 0, 0L);
+ }
+
+ void SetRangeMax(int nMax, BOOL bRedraw = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_SETRANGEMAX, bRedraw, nMax);
+ }
+
+ void GetRange(int& nMin, int& nMax) const
+ {
+ nMin = GetRangeMin();
+ nMax = GetRangeMax();
+ }
+
+ void SetRange(int nMin, int nMax, BOOL bRedraw = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_SETRANGE, bRedraw, MAKELPARAM(nMin, nMax));
+ }
+
+ int GetSelStart() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_GETSELSTART, 0, 0L);
+ }
+
+ void SetSelStart(int nMin, BOOL bRedraw = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_SETSELSTART, bRedraw, (LPARAM)nMin);
+ }
+
+ int GetSelEnd() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_GETSELEND, 0, 0L);
+ }
+
+ void SetSelEnd(int nMax, BOOL bRedraw = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_SETSELEND, bRedraw, (LPARAM)nMax);
+ }
+
+ void GetSelection(int& nMin, int& nMax) const
+ {
+ nMin = GetSelStart();
+ nMax = GetSelEnd();
+ }
+
+ void SetSelection(int nMin, int nMax, BOOL bRedraw = TRUE)
+ {
+ SetSelStart(nMin, FALSE);
+ SetSelEnd(nMax, bRedraw);
+ }
+
+ void GetChannelRect(LPRECT lprc) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_GETCHANNELRECT, 0, (LPARAM)lprc);
+ }
+
+ void GetThumbRect(LPRECT lprc) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_GETTHUMBRECT, 0, (LPARAM)lprc);
+ }
+
+ int GetPos() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_GETPOS, 0, 0L);
+ }
+
+ void SetPos(int nPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_SETPOS, TRUE, nPos);
+ }
+
+ UINT GetNumTics() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, TBM_GETNUMTICS, 0, 0L);
+ }
+
+ DWORD* GetTicArray() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD*)::SendMessage(this->m_hWnd, TBM_GETPTICS, 0, 0L);
+ }
+
+ int GetTic(int nTic) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_GETTIC, nTic, 0L);
+ }
+
+ BOOL SetTic(int nTic)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TBM_SETTIC, 0, nTic);
+ }
+
+ int GetTicPos(int nTic) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_GETTICPOS, nTic, 0L);
+ }
+
+ void SetTicFreq(int nFreq)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_SETTICFREQ, nFreq, 0L);
+ }
+
+ int GetThumbLength() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_GETTHUMBLENGTH, 0, 0L);
+ }
+
+ void SetThumbLength(int nLength)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_SETTHUMBLENGTH, nLength, 0L);
+ }
+
+ void SetSel(int nStart, int nEnd, BOOL bRedraw = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & TBS_ENABLESELRANGE) != 0);
+ ::SendMessage(this->m_hWnd, TBM_SETSEL, bRedraw, MAKELPARAM(nStart, nEnd));
+ }
+
+ ATL::CWindow GetBuddy(BOOL bLeft = TRUE) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ATL::CWindow((HWND)::SendMessage(this->m_hWnd, TBM_GETBUDDY, bLeft, 0L));
+ }
+
+ ATL::CWindow SetBuddy(HWND hWndBuddy, BOOL bLeft = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ATL::CWindow((HWND)::SendMessage(this->m_hWnd, TBM_SETBUDDY, bLeft, (LPARAM)hWndBuddy));
+ }
+
+ CToolTipCtrl GetToolTips() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, TBM_GETTOOLTIPS, 0, 0L));
+ }
+
+ void SetToolTips(HWND hWndTT)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_SETTOOLTIPS, (WPARAM)hWndTT, 0L);
+ }
+
+ int SetTipSide(int nSide)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, TBM_SETTIPSIDE, nSide, 0L);
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TBM_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, TBM_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+// Operations
+ void ClearSel(BOOL bRedraw = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_CLEARSEL, bRedraw, 0L);
+ }
+
+ void VerifyPos()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_SETPOS, FALSE, 0L);
+ }
+
+ void ClearTics(BOOL bRedraw = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, TBM_CLEARTICS, bRedraw, 0L);
+ }
+};
+
+typedef CTrackBarCtrlT<ATL::CWindow> CTrackBarCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CUpDownCtrl
+
+template <class TBase>
+class CUpDownCtrlT : public TBase
+{
+public:
+// Constructors
+ CUpDownCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CUpDownCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return UPDOWN_CLASS;
+ }
+
+ UINT GetAccel(int nAccel, UDACCEL* pAccel) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)LOWORD(::SendMessage(this->m_hWnd, UDM_GETACCEL, nAccel, (LPARAM)pAccel));
+ }
+
+ BOOL SetAccel(int nAccel, UDACCEL* pAccel)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)LOWORD(::SendMessage(this->m_hWnd, UDM_SETACCEL, nAccel, (LPARAM)pAccel));
+ }
+
+ UINT GetBase() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)LOWORD(::SendMessage(this->m_hWnd, UDM_GETBASE, 0, 0L));
+ }
+
+ int SetBase(int nBase)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, UDM_SETBASE, nBase, 0L);
+ }
+
+ ATL::CWindow GetBuddy() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ATL::CWindow((HWND)::SendMessage(this->m_hWnd, UDM_GETBUDDY, 0, 0L));
+ }
+
+ ATL::CWindow SetBuddy(HWND hWndBuddy)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ATL::CWindow((HWND)::SendMessage(this->m_hWnd, UDM_SETBUDDY, (WPARAM)hWndBuddy, 0L));
+ }
+
+ int GetPos(LPBOOL lpbError = NULL) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, UDM_GETPOS, 0, 0L);
+ // Note: Seems that Windows always sets error to TRUE if
+ // UDS_SETBUDDYINT style is not used
+ if(lpbError != NULL)
+ *lpbError = (HIWORD(dwRet) != 0) ? TRUE : FALSE;
+ return (int)(short)LOWORD(dwRet);
+ }
+
+ int SetPos(int nPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)(short)LOWORD(::SendMessage(this->m_hWnd, UDM_SETPOS, 0, MAKELPARAM(nPos, 0)));
+ }
+
+ DWORD GetRange() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, UDM_GETRANGE, 0, 0L);
+ }
+
+ void GetRange(int& nLower, int& nUpper) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, UDM_GETRANGE, 0, 0L);
+ nLower = (int)(short)HIWORD(dwRet);
+ nUpper = (int)(short)LOWORD(dwRet);
+ }
+
+ void SetRange(int nLower, int nUpper)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, UDM_SETRANGE, 0, MAKELPARAM(nUpper, nLower));
+ }
+
+ void SetRange32(int nLower, int nUpper)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, UDM_SETRANGE32, nLower, nUpper);
+ }
+
+ void GetRange32(int& nLower, int& nUpper) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, UDM_GETRANGE32, (WPARAM)&nLower, (LPARAM)&nUpper);
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, UDM_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, UDM_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+ int GetPos32(LPBOOL lpbError = NULL) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ // Note: Seems that Windows always sets error to TRUE if
+ // UDS_SETBUDDYINT style is not used
+ return (int)::SendMessage(this->m_hWnd, UDM_GETPOS32, 0, (LPARAM)lpbError);
+ }
+
+ int SetPos32(int nPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, UDM_SETPOS32, 0, (LPARAM)nPos);
+ }
+};
+
+typedef CUpDownCtrlT<ATL::CWindow> CUpDownCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CProgressBarCtrl
+
+template <class TBase>
+class CProgressBarCtrlT : public TBase
+{
+public:
+// Constructors
+ CProgressBarCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CProgressBarCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return PROGRESS_CLASS;
+ }
+
+ DWORD SetRange(int nLower, int nUpper)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, PBM_SETRANGE, 0, MAKELPARAM(nLower, nUpper));
+ }
+
+ int SetPos(int nPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)(short)LOWORD(::SendMessage(this->m_hWnd, PBM_SETPOS, nPos, 0L));
+ }
+
+ int OffsetPos(int nPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)(short)LOWORD(::SendMessage(this->m_hWnd, PBM_DELTAPOS, nPos, 0L));
+ }
+
+ int SetStep(int nStep)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)(short)LOWORD(::SendMessage(this->m_hWnd, PBM_SETSTEP, nStep, 0L));
+ }
+
+ UINT GetPos() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, PBM_GETPOS, 0, 0L);
+ }
+
+ void GetRange(PPBRANGE pPBRange) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(pPBRange != NULL);
+ ::SendMessage(this->m_hWnd, PBM_GETRANGE, TRUE, (LPARAM)pPBRange);
+ }
+
+ void GetRange(int& nLower, int& nUpper) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ PBRANGE range = {};
+ ::SendMessage(this->m_hWnd, PBM_GETRANGE, TRUE, (LPARAM)&range);
+ nLower = range.iLow;
+ nUpper = range.iHigh;
+ }
+
+ int GetRangeLimit(BOOL bLowLimit) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, PBM_GETRANGE, bLowLimit, (LPARAM)NULL);
+ }
+
+ DWORD SetRange32(int nMin, int nMax)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, PBM_SETRANGE32, nMin, nMax);
+ }
+
+ COLORREF SetBarColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, PBM_SETBARCOLOR, 0, (LPARAM)clr);
+ }
+
+ COLORREF SetBkColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, PBM_SETBKCOLOR, 0, (LPARAM)clr);
+ }
+
+#ifdef PBM_SETMARQUEE
+ BOOL SetMarquee(BOOL bMarquee, UINT uUpdateTime = 0U)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, PBM_SETMARQUEE, (WPARAM)bMarquee, (LPARAM)uUpdateTime);
+ }
+#endif
+
+#if (_WIN32_WINNT >= 0x0600)
+ int GetStep() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, PBM_GETSTEP, 0, 0L);
+ }
+
+ COLORREF GetBkColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, PBM_GETBKCOLOR, 0, 0L);
+ }
+
+ COLORREF GetBarColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, PBM_GETBARCOLOR, 0, 0L);
+ }
+
+ int GetState() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, PBM_GETSTATE, 0, 0L);
+ }
+
+ int SetState(int nState)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, PBM_SETSTATE, nState, 0L);
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+
+// Operations
+ int StepIt()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)(short)LOWORD(::SendMessage(this->m_hWnd, PBM_STEPIT, 0, 0L));
+ }
+};
+
+typedef CProgressBarCtrlT<ATL::CWindow> CProgressBarCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CHotKeyCtrl
+
+template <class TBase>
+class CHotKeyCtrlT : public TBase
+{
+public:
+// Constructors
+ CHotKeyCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CHotKeyCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return HOTKEY_CLASS;
+ }
+
+ DWORD GetHotKey() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, HKM_GETHOTKEY, 0, 0L);
+ }
+
+ void GetHotKey(WORD &wVirtualKeyCode, WORD &wModifiers) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dw = (DWORD)::SendMessage(this->m_hWnd, HKM_GETHOTKEY, 0, 0L);
+ wVirtualKeyCode = LOBYTE(LOWORD(dw));
+ wModifiers = HIBYTE(LOWORD(dw));
+ }
+
+ void SetHotKey(WORD wVirtualKeyCode, WORD wModifiers)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, HKM_SETHOTKEY, MAKEWORD(wVirtualKeyCode, wModifiers), 0L);
+ }
+
+ void SetRules(WORD wInvalidComb, WORD wModifiers)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, HKM_SETRULES, wInvalidComb, MAKELPARAM(wModifiers, 0));
+ }
+};
+
+typedef CHotKeyCtrlT<ATL::CWindow> CHotKeyCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CAnimateCtrl
+
+template <class TBase>
+class CAnimateCtrlT : public TBase
+{
+public:
+// Constructors
+ CAnimateCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CAnimateCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return ANIMATE_CLASS;
+ }
+
+// Operations
+ BOOL Open(ATL::_U_STRINGorID FileName)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, ACM_OPEN, 0, (LPARAM)FileName.m_lpstr);
+ }
+
+ BOOL Play(UINT nFrom, UINT nTo, UINT nRep)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, ACM_PLAY, nRep, MAKELPARAM(nFrom, nTo));
+ }
+
+ BOOL Stop()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, ACM_STOP, 0, 0L);
+ }
+
+ BOOL Close()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, ACM_OPEN, 0, 0L);
+ }
+
+ BOOL Seek(UINT nTo)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, ACM_PLAY, 0, MAKELPARAM(nTo, nTo));
+ }
+
+ // Vista only
+ BOOL IsPlaying() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, ACM_ISPLAYING, 0, 0L);
+ }
+};
+
+typedef CAnimateCtrlT<ATL::CWindow> CAnimateCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CRichEditCtrl
+
+#if !defined(_UNICODE) && (_RICHEDIT_VER >= 0x0500)
+ #undef MSFTEDIT_CLASS
+ #define MSFTEDIT_CLASS "RICHEDIT50W"
+#endif
+
+template <class TBase>
+class CRichEditCtrlT : public TBase
+{
+public:
+// Constructors
+ CRichEditCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CRichEditCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+#if (_RICHEDIT_VER >= 0x0500)
+ return MSFTEDIT_CLASS;
+#else
+ return RICHEDIT_CLASS;
+#endif
+ }
+
+ static LPCTSTR GetLibraryName()
+ {
+#if (_RICHEDIT_VER >= 0x0500)
+ return _T("MSFTEDIT.DLL");
+#else
+ return _T("RICHED20.DLL");
+#endif
+ }
+
+ int GetLineCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_GETLINECOUNT, 0, 0L);
+ }
+
+ BOOL GetModify() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_GETMODIFY, 0, 0L);
+ }
+
+ void SetModify(BOOL bModified = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETMODIFY, bModified, 0L);
+ }
+
+ void GetRect(LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_GETRECT, 0, (LPARAM)lpRect);
+ }
+
+ DWORD GetOptions() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETOPTIONS, 0, 0L);
+ }
+
+ DWORD SetOptions(WORD wOperation, DWORD dwOptions)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_SETOPTIONS, wOperation, dwOptions);
+ }
+
+ // NOTE: first word in lpszBuffer must contain the size of the buffer!
+ int GetLine(int nIndex, LPTSTR lpszBuffer) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_GETLINE, nIndex, (LPARAM)lpszBuffer);
+ }
+
+ int GetLine(int nIndex, LPTSTR lpszBuffer, int nMaxLength) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ *(LPWORD)lpszBuffer = (WORD)nMaxLength;
+ return (int)::SendMessage(this->m_hWnd, EM_GETLINE, nIndex, (LPARAM)lpszBuffer);
+ }
+
+ BOOL CanUndo() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_CANUNDO, 0, 0L);
+ }
+
+ BOOL CanPaste(UINT nFormat = 0) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_CANPASTE, nFormat, 0L);
+ }
+
+ void GetSel(LONG& nStartChar, LONG& nEndChar) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ CHARRANGE cr = {};
+ ::SendMessage(this->m_hWnd, EM_EXGETSEL, 0, (LPARAM)&cr);
+ nStartChar = cr.cpMin;
+ nEndChar = cr.cpMax;
+ }
+
+ void GetSel(CHARRANGE &cr) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_EXGETSEL, 0, (LPARAM)&cr);
+ }
+
+ int SetSel(LONG nStartChar, LONG nEndChar)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ CHARRANGE cr = { nStartChar, nEndChar };
+ return (int)::SendMessage(this->m_hWnd, EM_EXSETSEL, 0, (LPARAM)&cr);
+ }
+
+ int SetSel(CHARRANGE &cr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_EXSETSEL, 0, (LPARAM)&cr);
+ }
+
+ int SetSelAll()
+ {
+ return SetSel(0, -1);
+ }
+
+ int SetSelNone()
+ {
+ return SetSel(-1, 0);
+ }
+
+ DWORD GetDefaultCharFormat(CHARFORMAT& cf) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT);
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETCHARFORMAT, 0, (LPARAM)&cf);
+ }
+
+ DWORD GetSelectionCharFormat(CHARFORMAT& cf) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT);
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETCHARFORMAT, 1, (LPARAM)&cf);
+ }
+
+ DWORD GetEventMask() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETEVENTMASK, 0, 0L);
+ }
+
+ LONG GetLimitText() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (LONG)::SendMessage(this->m_hWnd, EM_GETLIMITTEXT, 0, 0L);
+ }
+
+ DWORD GetParaFormat(PARAFORMAT& pf) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ pf.cbSize = sizeof(PARAFORMAT);
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETPARAFORMAT, 0, (LPARAM)&pf);
+ }
+
+ LONG GetSelText(LPTSTR lpstrBuff) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (LONG)::SendMessage(this->m_hWnd, EM_GETSELTEXT, 0, (LPARAM)lpstrBuff);
+ }
+
+ BOOL GetSelTextBSTR(BSTR& bstrText) const
+ {
+ USES_CONVERSION;
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(bstrText == NULL);
+
+ CHARRANGE cr = {};
+ ::SendMessage(this->m_hWnd, EM_EXGETSEL, 0, (LPARAM)&cr);
+
+ ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff;
+ LPTSTR lpstrText = buff.Allocate(cr.cpMax - cr.cpMin + 1);
+ if(lpstrText == NULL)
+ return FALSE;
+ if(::SendMessage(this->m_hWnd, EM_GETSELTEXT, 0, (LPARAM)lpstrText) == 0)
+ return FALSE;
+
+ bstrText = ::SysAllocString(T2W(lpstrText));
+
+ return (bstrText != NULL) ? TRUE : FALSE;
+ }
+
+#ifdef __ATLSTR_H__
+ LONG GetSelText(ATL::CString& strText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+
+ CHARRANGE cr = {};
+ ::SendMessage(this->m_hWnd, EM_EXGETSEL, 0, (LPARAM)&cr);
+
+ LONG lLen = 0;
+ LPTSTR lpstrText = strText.GetBufferSetLength(cr.cpMax - cr.cpMin);
+ if(lpstrText != NULL)
+ {
+ lLen = (LONG)::SendMessage(this->m_hWnd, EM_GETSELTEXT, 0, (LPARAM)lpstrText);
+ strText.ReleaseBuffer();
+ }
+
+ return lLen;
+ }
+#endif // __ATLSTR_H__
+
+ WORD GetSelectionType() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (WORD)::SendMessage(this->m_hWnd, EM_SELECTIONTYPE, 0, 0L);
+ }
+
+ COLORREF SetBackgroundColor(COLORREF cr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, EM_SETBKGNDCOLOR, 0, cr);
+ }
+
+ COLORREF SetBackgroundColor() // sets to system background
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, EM_SETBKGNDCOLOR, 1, 0);
+ }
+
+ BOOL SetCharFormat(CHARFORMAT& cf, WORD wFlags)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT);
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, (WPARAM)wFlags, (LPARAM)&cf);
+ }
+
+ BOOL SetDefaultCharFormat(CHARFORMAT& cf)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT);
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, 0, (LPARAM)&cf);
+ }
+
+ BOOL SetSelectionCharFormat(CHARFORMAT& cf)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT);
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, SCF_SELECTION, (LPARAM)&cf);
+ }
+
+ BOOL SetWordCharFormat(CHARFORMAT& cf)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT);
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, SCF_SELECTION | SCF_WORD, (LPARAM)&cf);
+ }
+
+ DWORD SetEventMask(DWORD dwEventMask)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_SETEVENTMASK, 0, dwEventMask);
+ }
+
+ BOOL SetParaFormat(PARAFORMAT& pf)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ pf.cbSize = sizeof(PARAFORMAT);
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETPARAFORMAT, 0, (LPARAM)&pf);
+ }
+
+ BOOL SetTargetDevice(HDC hDC, int cxLineWidth)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETTARGETDEVICE, (WPARAM)hDC, cxLineWidth);
+ }
+
+ int GetTextLength() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, WM_GETTEXTLENGTH, 0, 0L);
+ }
+
+ BOOL SetReadOnly(BOOL bReadOnly = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETREADONLY, bReadOnly, 0L);
+ }
+
+ int GetFirstVisibleLine() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_GETFIRSTVISIBLELINE, 0, 0L);
+ }
+
+ int GetTextRange(TEXTRANGE* pTextRange) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_GETTEXTRANGE, 0, (LPARAM)pTextRange);
+ }
+
+ int GetTextRange(LONG nStartChar, LONG nEndChar, LPTSTR lpstrText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ TEXTRANGE tr = {};
+ tr.chrg.cpMin = nStartChar;
+ tr.chrg.cpMax = nEndChar;
+ tr.lpstrText = lpstrText;
+ return (int)::SendMessage(this->m_hWnd, EM_GETTEXTRANGE, 0, (LPARAM)&tr);
+ }
+
+ DWORD GetDefaultCharFormat(CHARFORMAT2& cf) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT2);
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETCHARFORMAT, 0, (LPARAM)&cf);
+ }
+
+ BOOL SetCharFormat(CHARFORMAT2& cf, WORD wFlags)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT2);
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, (WPARAM)wFlags, (LPARAM)&cf);
+ }
+
+ BOOL SetDefaultCharFormat(CHARFORMAT2& cf)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT2);
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, 0, (LPARAM)&cf);
+ }
+
+ DWORD GetSelectionCharFormat(CHARFORMAT2& cf) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT2);
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETCHARFORMAT, 1, (LPARAM)&cf);
+ }
+
+ BOOL SetSelectionCharFormat(CHARFORMAT2& cf)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT2);
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, SCF_SELECTION, (LPARAM)&cf);
+ }
+
+ BOOL SetWordCharFormat(CHARFORMAT2& cf)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ cf.cbSize = sizeof(CHARFORMAT2);
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETCHARFORMAT, SCF_SELECTION | SCF_WORD, (LPARAM)&cf);
+ }
+
+ DWORD GetParaFormat(PARAFORMAT2& pf) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ pf.cbSize = sizeof(PARAFORMAT2);
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETPARAFORMAT, 0, (LPARAM)&pf);
+ }
+
+ BOOL SetParaFormat(PARAFORMAT2& pf)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ pf.cbSize = sizeof(PARAFORMAT2);
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETPARAFORMAT, 0, (LPARAM)&pf);
+ }
+
+ TEXTMODE GetTextMode() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (TEXTMODE)::SendMessage(this->m_hWnd, EM_GETTEXTMODE, 0, 0L);
+ }
+
+ BOOL SetTextMode(TEXTMODE enumTextMode)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return !(BOOL)::SendMessage(this->m_hWnd, EM_SETTEXTMODE, enumTextMode, 0L);
+ }
+
+ UNDONAMEID GetUndoName() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UNDONAMEID)::SendMessage(this->m_hWnd, EM_GETUNDONAME, 0, 0L);
+ }
+
+ UNDONAMEID GetRedoName() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UNDONAMEID)::SendMessage(this->m_hWnd, EM_GETREDONAME, 0, 0L);
+ }
+
+ BOOL CanRedo() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_CANREDO, 0, 0L);
+ }
+
+ BOOL GetAutoURLDetect() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_GETAUTOURLDETECT, 0, 0L);
+ }
+
+ BOOL SetAutoURLDetect(BOOL bAutoDetect = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return !(BOOL)::SendMessage(this->m_hWnd, EM_AUTOURLDETECT, bAutoDetect, 0L);
+ }
+
+ // this method is deprecated, please use SetAutoURLDetect
+ BOOL EnableAutoURLDetect(BOOL bEnable = TRUE) { return SetAutoURLDetect(bEnable); }
+
+ UINT SetUndoLimit(UINT uUndoLimit)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, EM_SETUNDOLIMIT, uUndoLimit, 0L);
+ }
+
+ void SetPalette(HPALETTE hPalette)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETPALETTE, (WPARAM)hPalette, 0L);
+ }
+
+ int GetTextEx(GETTEXTEX* pGetTextEx, LPTSTR lpstrText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_GETTEXTEX, (WPARAM)pGetTextEx, (LPARAM)lpstrText);
+ }
+
+ int GetTextEx(LPTSTR lpstrText, int nTextLen, DWORD dwFlags = GT_DEFAULT, UINT uCodePage = CP_ACP, LPCSTR lpDefaultChar = NULL, LPBOOL lpUsedDefChar = NULL) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ GETTEXTEX gte = {};
+ gte.cb = nTextLen * sizeof(TCHAR);
+ gte.codepage = uCodePage;
+ gte.flags = dwFlags;
+ gte.lpDefaultChar = lpDefaultChar;
+ gte.lpUsedDefChar = lpUsedDefChar;
+ return (int)::SendMessage(this->m_hWnd, EM_GETTEXTEX, (WPARAM)&gte, (LPARAM)lpstrText);
+ }
+
+ int GetTextLengthEx(GETTEXTLENGTHEX* pGetTextLengthEx) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_GETTEXTLENGTHEX, (WPARAM)pGetTextLengthEx, 0L);
+ }
+
+ int GetTextLengthEx(DWORD dwFlags = GTL_DEFAULT, UINT uCodePage = CP_ACP) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ GETTEXTLENGTHEX gtle = {};
+ gtle.codepage = uCodePage;
+ gtle.flags = dwFlags;
+ return (int)::SendMessage(this->m_hWnd, EM_GETTEXTLENGTHEX, (WPARAM)&gtle, 0L);
+ }
+
+ EDITWORDBREAKPROC GetWordBreakProc() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (EDITWORDBREAKPROC)::SendMessage(this->m_hWnd, EM_GETWORDBREAKPROC, 0, 0L);
+ }
+
+ void SetWordBreakProc(EDITWORDBREAKPROC ewbprc)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETWORDBREAKPROC, 0, (LPARAM)ewbprc);
+ }
+
+ int SetTextEx(SETTEXTEX* pSetTextEx, LPCTSTR lpstrText)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_SETTEXTEX, (WPARAM)pSetTextEx, (LPARAM)lpstrText);
+ }
+
+ int SetTextEx(LPCTSTR lpstrText, DWORD dwFlags = ST_DEFAULT, UINT uCodePage = CP_ACP)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ SETTEXTEX ste = {};
+ ste.flags = dwFlags;
+ ste.codepage = uCodePage;
+ return (int)::SendMessage(this->m_hWnd, EM_SETTEXTEX, (WPARAM)&ste, (LPARAM)lpstrText);
+ }
+
+ int GetEditStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_GETEDITSTYLE, 0, 0L);
+ }
+
+ int SetEditStyle(int nStyle, int nMask = -1)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ if(nMask == -1)
+ nMask = nStyle; // set everything specified
+ return (int)::SendMessage(this->m_hWnd, EM_SETEDITSTYLE, nStyle, nMask);
+ }
+
+ BOOL SetFontSize(int nFontSizeDelta)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((nFontSizeDelta >= -1637) && (nFontSizeDelta <= 1638));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETFONTSIZE, nFontSizeDelta, 0L);
+ }
+
+ void GetScrollPos(LPPOINT lpPoint) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(lpPoint != NULL);
+ ::SendMessage(this->m_hWnd, EM_GETSCROLLPOS, 0, (LPARAM)lpPoint);
+ }
+
+ void SetScrollPos(LPPOINT lpPoint)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(lpPoint != NULL);
+ ::SendMessage(this->m_hWnd, EM_SETSCROLLPOS, 0, (LPARAM)lpPoint);
+ }
+
+ BOOL GetZoom(int& nNum, int& nDen) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_GETZOOM, (WPARAM)&nNum, (LPARAM)&nDen);
+ }
+
+ BOOL SetZoom(int nNum, int nDen)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((nNum >= 0) && (nNum <= 64));
+ ATLASSERT((nDen >= 0) && (nDen <= 64));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETZOOM, nNum, nDen);
+ }
+
+ BOOL SetZoomOff()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETZOOM, 0, 0L);
+ }
+
+ void SetMargins(UINT nLeft, UINT nRight, WORD wFlags = EC_LEFTMARGIN | EC_RIGHTMARGIN)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETMARGINS, wFlags, MAKELONG(nLeft, nRight));
+ }
+
+ WORD GetTypographyOptions() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (WORD)::SendMessage(this->m_hWnd, EM_GETTYPOGRAPHYOPTIONS, 0, 0L);
+ }
+
+ BOOL SetTypographyOptions(WORD wOptions, WORD wMask) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETTYPOGRAPHYOPTIONS, wOptions, wMask);
+ }
+
+// Operations
+ void LimitText(LONG nChars = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_EXLIMITTEXT, 0, nChars);
+ }
+
+ int LineFromChar(LONG nIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_EXLINEFROMCHAR, 0, nIndex);
+ }
+
+ POINT PosFromChar(LONG nChar) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ POINT point = {};
+ ::SendMessage(this->m_hWnd, EM_POSFROMCHAR, (WPARAM)&point, nChar);
+ return point;
+ }
+
+ int CharFromPos(POINT pt) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ POINTL ptl = { pt.x, pt.y };
+ return (int)::SendMessage(this->m_hWnd, EM_CHARFROMPOS, 0, (LPARAM)&ptl);
+ }
+
+ void EmptyUndoBuffer()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_EMPTYUNDOBUFFER, 0, 0L);
+ }
+
+ int LineIndex(int nLine = -1) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_LINEINDEX, nLine, 0L);
+ }
+
+ int LineLength(int nLine = -1) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, EM_LINELENGTH, nLine, 0L);
+ }
+
+ BOOL LineScroll(int nLines)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_LINESCROLL, 0, nLines);
+ }
+
+ void ReplaceSel(LPCTSTR lpszNewText, BOOL bCanUndo = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_REPLACESEL, (WPARAM) bCanUndo, (LPARAM)lpszNewText);
+ }
+
+ void SetRect(LPCRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETRECT, 0, (LPARAM)lpRect);
+ }
+
+ BOOL DisplayBand(LPRECT pDisplayRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_DISPLAYBAND, 0, (LPARAM)pDisplayRect);
+ }
+
+ LONG FindText(DWORD dwFlags, FINDTEXT& ft) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+#ifdef _UNICODE
+ return (LONG)::SendMessage(this->m_hWnd, EM_FINDTEXTW, dwFlags, (LPARAM)&ft);
+#else
+ return (LONG)::SendMessage(this->m_hWnd, EM_FINDTEXT, dwFlags, (LPARAM)&ft);
+#endif
+ }
+
+ LONG FindText(DWORD dwFlags, FINDTEXTEX& ft) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+#ifdef _UNICODE
+ return (LONG)::SendMessage(this->m_hWnd, EM_FINDTEXTEXW, dwFlags, (LPARAM)&ft);
+#else
+ return (LONG)::SendMessage(this->m_hWnd, EM_FINDTEXTEX, dwFlags, (LPARAM)&ft);
+#endif
+ }
+
+ LONG FormatRange(FORMATRANGE& fr, BOOL bDisplay = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (LONG)::SendMessage(this->m_hWnd, EM_FORMATRANGE, bDisplay, (LPARAM)&fr);
+ }
+
+ LONG FormatRange(FORMATRANGE* pFormatRange, BOOL bDisplay = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (LONG)::SendMessage(this->m_hWnd, EM_FORMATRANGE, bDisplay, (LPARAM)pFormatRange);
+ }
+
+ void HideSelection(BOOL bHide = TRUE, BOOL bChangeStyle = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_HIDESELECTION, bHide, bChangeStyle);
+ }
+
+ void PasteSpecial(UINT uClipFormat, DWORD dwAspect = 0, HMETAFILE hMF = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ REPASTESPECIAL reps = { dwAspect, (DWORD_PTR)hMF };
+ ::SendMessage(this->m_hWnd, EM_PASTESPECIAL, uClipFormat, (LPARAM)&reps);
+ }
+
+ void RequestResize()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_REQUESTRESIZE, 0, 0L);
+ }
+
+ LONG StreamIn(UINT uFormat, EDITSTREAM& es)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (LONG)::SendMessage(this->m_hWnd, EM_STREAMIN, uFormat, (LPARAM)&es);
+ }
+
+ LONG StreamOut(UINT uFormat, EDITSTREAM& es)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (LONG)::SendMessage(this->m_hWnd, EM_STREAMOUT, uFormat, (LPARAM)&es);
+ }
+
+ DWORD FindWordBreak(int nCode, LONG nStartChar)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_FINDWORDBREAK, nCode, nStartChar);
+ }
+
+ // Additional operations
+ void ScrollCaret()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SCROLLCARET, 0, 0L);
+ }
+
+ int InsertText(long nInsertAfterChar, LPCTSTR lpstrText, BOOL bCanUndo = FALSE)
+ {
+ int nRet = SetSel(nInsertAfterChar, nInsertAfterChar);
+ ReplaceSel(lpstrText, bCanUndo);
+ return nRet;
+ }
+
+ int AppendText(LPCTSTR lpstrText, BOOL bCanUndo = FALSE)
+ {
+ return InsertText(this->GetWindowTextLength(), lpstrText, bCanUndo);
+ }
+
+ // Clipboard operations
+ BOOL Undo()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_UNDO, 0, 0L);
+ }
+
+ void Clear()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_CLEAR, 0, 0L);
+ }
+
+ void Copy()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_COPY, 0, 0L);
+ }
+
+ void Cut()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_CUT, 0, 0L);
+ }
+
+ void Paste()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, WM_PASTE, 0, 0L);
+ }
+
+ // OLE support
+ IRichEditOle* GetOleInterface() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ IRichEditOle *pRichEditOle = NULL;
+ ::SendMessage(this->m_hWnd, EM_GETOLEINTERFACE, 0, (LPARAM)&pRichEditOle);
+ return pRichEditOle;
+ }
+
+ BOOL SetOleCallback(IRichEditOleCallback* pCallback)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETOLECALLBACK, 0, (LPARAM)pCallback);
+ }
+
+ BOOL Redo()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_REDO, 0, 0L);
+ }
+
+ void StopGroupTyping()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_STOPGROUPTYPING, 0, 0L);
+ }
+
+ void ShowScrollBar(int nBarType, BOOL bVisible = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SHOWSCROLLBAR, nBarType, bVisible);
+ }
+
+ BOOL SetTabStops(int nTabStops, LPINT rgTabStops)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, nTabStops, (LPARAM)rgTabStops);
+ }
+
+ BOOL SetTabStops()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, 0, 0L);
+ }
+
+ BOOL SetTabStops(const int& cxEachStop) // takes an 'int'
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETTABSTOPS, 1, (LPARAM)(LPINT)&cxEachStop);
+ }
+
+#if (_RICHEDIT_VER >= 0x0800)
+ AutoCorrectProc GetAutoCorrectProc() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (AutoCorrectProc)::SendMessage(this->m_hWnd, EM_GETAUTOCORRECTPROC, 0, 0L);
+ }
+
+ BOOL SetAutoCorrectProc(AutoCorrectProc pfn)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETAUTOCORRECTPROC, (WPARAM)pfn, 0L);
+ }
+
+ BOOL CallAutoCorrectProc(WCHAR ch)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_CALLAUTOCORRECTPROC, (WPARAM)ch, 0L);
+ }
+
+ DWORD GetEditStyleEx() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETEDITSTYLEEX, 0, 0L);
+ }
+
+ DWORD SetEditStyleEx(DWORD dwStyleEx, DWORD dwMask)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_SETEDITSTYLEEX, dwStyleEx, dwMask);
+ }
+
+ DWORD GetStoryType(int nStoryIndex) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_GETSTORYTYPE, nStoryIndex, 0L);
+ }
+
+ DWORD SetStoryType(int nStoryIndex, DWORD dwStoryType)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, EM_SETSTORYTYPE, nStoryIndex, dwStoryType);
+ }
+
+ DWORD GetEllipsisMode() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+
+ DWORD dwMode = 0;
+ BOOL bRet = (BOOL)::SendMessage(this->m_hWnd, EM_GETELLIPSISMODE, 0, (LPARAM)&dwMode);
+ (void)bRet; // avoid level 4 warning
+ ATLASSERT(bRet != FALSE);
+
+ return dwMode;
+ }
+
+ BOOL SetEllipsisMode(DWORD dwEllipsisMode)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETELLIPSISMODE, 0, dwEllipsisMode);
+ }
+
+ BOOL GetEllipsisState() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_GETELLIPSISSTATE, 0, 0L);
+ }
+
+ BOOL GetTouchOptions(int nTouchOptions) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_GETTOUCHOPTIONS, nTouchOptions, 0L);
+ }
+
+ void SetTouchOptions(int nTouchOptions, BOOL bEnable)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, EM_SETTOUCHOPTIONS, nTouchOptions, bEnable);
+ }
+
+ HRESULT InsertTable(TABLEROWPARMS* pRowParams, TABLECELLPARMS* pCellParams)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HRESULT)::SendMessage(this->m_hWnd, EM_INSERTTABLE, (WPARAM)pRowParams, (LPARAM)pCellParams);
+ }
+
+ HRESULT GetTableParams(TABLEROWPARMS* pRowParams, TABLECELLPARMS* pCellParams) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HRESULT)::SendMessage(this->m_hWnd, EM_GETTABLEPARMS, (WPARAM)pRowParams, (LPARAM)pCellParams);
+ }
+
+ HRESULT SetTableParams(TABLEROWPARMS* pRowParams, TABLECELLPARMS* pCellParams)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HRESULT)::SendMessage(this->m_hWnd, EM_SETTABLEPARMS, (WPARAM)pRowParams, (LPARAM)pCellParams);
+ }
+
+ HRESULT InsertImage(RICHEDIT_IMAGE_PARAMETERS* pParams)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HRESULT)::SendMessage(this->m_hWnd, EM_INSERTIMAGE, 0, (LPARAM)pParams);
+ }
+
+ BOOL SetUiaName(LPCTSTR lpstrName)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, EM_SETUIANAME, 0, (LPARAM)lpstrName);
+ }
+#endif // (_RICHEDIT_VER >= 0x0800)
+};
+
+typedef CRichEditCtrlT<ATL::CWindow> CRichEditCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CRichEditCommands - message handlers for standard EDIT commands
+
+// Chain to CRichEditCommands message map. Your class must also derive from CRichEditCtrl.
+// Example:
+// class CMyRichEdit : public CWindowImpl<CMyRichEdit, CRichEditCtrl>,
+// public CRichEditCommands<CMyRichEdit>
+// {
+// public:
+// BEGIN_MSG_MAP(CMyRichEdit)
+// // your handlers...
+// CHAIN_MSG_MAP_ALT(CRichEditCommands<CMyRichEdit>, 1)
+// END_MSG_MAP()
+// // other stuff...
+// };
+
+template <class T>
+class CRichEditCommands : public CEditCommands< T >
+{
+public:
+ BEGIN_MSG_MAP(CRichEditCommands< T >)
+ ALT_MSG_MAP(1)
+ COMMAND_ID_HANDLER(ID_EDIT_CLEAR, CEditCommands< T >::OnEditClear)
+ COMMAND_ID_HANDLER(ID_EDIT_CLEAR_ALL, CEditCommands< T >::OnEditClearAll)
+ COMMAND_ID_HANDLER(ID_EDIT_COPY, CEditCommands< T >::OnEditCopy)
+ COMMAND_ID_HANDLER(ID_EDIT_CUT, CEditCommands< T >::OnEditCut)
+ COMMAND_ID_HANDLER(ID_EDIT_PASTE, CEditCommands< T >::OnEditPaste)
+ COMMAND_ID_HANDLER(ID_EDIT_SELECT_ALL, CEditCommands< T >::OnEditSelectAll)
+ COMMAND_ID_HANDLER(ID_EDIT_UNDO, CEditCommands< T >::OnEditUndo)
+ COMMAND_ID_HANDLER(ID_EDIT_REDO, OnEditRedo)
+ END_MSG_MAP()
+
+ LRESULT OnEditRedo(WORD /*wNotifyCode*/, WORD /*wID*/, HWND /*hWndCtl*/, BOOL& /*bHandled*/)
+ {
+ T* pT = static_cast<T*>(this);
+ pT->Redo();
+ return 0;
+ }
+
+// State (update UI) helpers
+ BOOL CanCut() const
+ { return HasSelection(); }
+
+ BOOL CanCopy() const
+ { return HasSelection(); }
+
+ BOOL CanClear() const
+ { return HasSelection(); }
+
+// Implementation
+ BOOL HasSelection() const
+ {
+ const T* pT = static_cast<const T*>(this);
+ return (pT->GetSelectionType() != SEL_EMPTY);
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CDragListBox
+
+template <class TBase>
+class CDragListBoxT : public CListBoxT< TBase >
+{
+public:
+// Constructors
+ CDragListBoxT(HWND hWnd = NULL) : CListBoxT< TBase >(hWnd)
+ { }
+
+ CDragListBoxT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ HWND hWnd = TBase::Create(TBase::GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ if(hWnd != NULL)
+ MakeDragList();
+ return hWnd;
+ }
+
+// Operations
+ BOOL MakeDragList()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((this->GetStyle() & (LBS_MULTIPLESEL | LBS_EXTENDEDSEL)) == 0);
+ return ::MakeDragList(this->m_hWnd);
+ }
+
+ int LBItemFromPt(POINT pt, BOOL bAutoScroll = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ::LBItemFromPt(this->m_hWnd, pt, bAutoScroll);
+ }
+
+ void DrawInsert(int nItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::DrawInsert(this->GetParent(), this->m_hWnd, nItem);
+ }
+
+ static UINT GetDragListMessage()
+ {
+ static UINT uDragListMessage = 0;
+ if(uDragListMessage == 0)
+ {
+ CStaticDataInitCriticalSectionLock lock;
+ if(FAILED(lock.Lock()))
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("ERROR : Unable to lock critical section in CDragListBox::GetDragListMessage.\n"));
+ ATLASSERT(FALSE);
+ return 0;
+ }
+
+ if(uDragListMessage == 0)
+ uDragListMessage = ::RegisterWindowMessage(DRAGLISTMSGSTRING);
+
+ lock.Unlock();
+ }
+ ATLASSERT(uDragListMessage != 0);
+ return uDragListMessage;
+ }
+};
+
+typedef CDragListBoxT<ATL::CWindow> CDragListBox;
+
+template <class T>
+class CDragListNotifyImpl
+{
+public:
+ BEGIN_MSG_MAP(CDragListNotifyImpl< T >)
+ MESSAGE_HANDLER(CDragListBox::GetDragListMessage(), OnDragListNotify)
+ END_MSG_MAP()
+
+ LRESULT OnDragListNotify(UINT uMsg, WPARAM wParam, LPARAM lParam, BOOL& bHandled)
+ {
+ (void)uMsg; // avoid level 4 warning
+ ATLASSERT(uMsg == CDragListBox::GetDragListMessage());
+ T* pT = static_cast<T*>(this);
+ LPDRAGLISTINFO lpDragListInfo = (LPDRAGLISTINFO)lParam;
+ LRESULT lRet = 0;
+ switch(lpDragListInfo->uNotification)
+ {
+ case DL_BEGINDRAG:
+ lRet = (LPARAM)pT->OnBeginDrag((int)wParam, lpDragListInfo->hWnd, lpDragListInfo->ptCursor);
+ break;
+ case DL_CANCELDRAG:
+ pT->OnCancelDrag((int)wParam, lpDragListInfo->hWnd, lpDragListInfo->ptCursor);
+ break;
+ case DL_DRAGGING:
+ lRet = (LPARAM)pT->OnDragging((int)wParam, lpDragListInfo->hWnd, lpDragListInfo->ptCursor);
+ break;
+ case DL_DROPPED:
+ pT->OnDropped((int)wParam, lpDragListInfo->hWnd, lpDragListInfo->ptCursor);
+ break;
+ default:
+ ATLTRACE2(atlTraceUI, 0, _T("Unknown DragListBox notification\n"));
+ bHandled = FALSE; // don't handle it
+ break;
+ }
+ return lRet;
+ }
+
+// Overrideables
+ BOOL OnBeginDrag(int /*nCtlID*/, HWND /*hWndDragList*/, POINT /*ptCursor*/)
+ {
+ return TRUE; // allow dragging
+ }
+
+ void OnCancelDrag(int /*nCtlID*/, HWND /*hWndDragList*/, POINT /*ptCursor*/)
+ {
+ // nothing to do
+ }
+
+ int OnDragging(int /*nCtlID*/, HWND /*hWndDragList*/, POINT /*ptCursor*/)
+ {
+ return 0; // don't change cursor
+ }
+
+ void OnDropped(int /*nCtlID*/, HWND /*hWndDragList*/, POINT /*ptCursor*/)
+ {
+ // nothing to do
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CReBarCtrl
+
+template <class TBase>
+class CReBarCtrlT : public TBase
+{
+public:
+// Constructors
+ CReBarCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CReBarCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return REBARCLASSNAME;
+ }
+
+ UINT GetBandCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, RB_GETBANDCOUNT, 0, 0L);
+ }
+
+ BOOL GetBandInfo(int nBand, LPREBARBANDINFO lprbbi) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_GETBANDINFO, nBand, (LPARAM)lprbbi);
+ }
+
+ BOOL SetBandInfo(int nBand, LPREBARBANDINFO lprbbi)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_SETBANDINFO, nBand, (LPARAM)lprbbi);
+ }
+
+ BOOL GetBarInfo(LPREBARINFO lprbi) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_GETBARINFO, 0, (LPARAM)lprbi);
+ }
+
+ BOOL SetBarInfo(LPREBARINFO lprbi)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_SETBARINFO, 0, (LPARAM)lprbi);
+ }
+
+ CImageList GetImageList() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ REBARINFO rbi = {};
+ rbi.cbSize = sizeof(REBARINFO);
+ rbi.fMask = RBIM_IMAGELIST;
+ BOOL bRet = (BOOL)::SendMessage(this->m_hWnd, RB_GETBARINFO, 0, (LPARAM)&rbi);
+ return CImageList((bRet != FALSE) ? rbi.himl : NULL);
+ }
+
+ BOOL SetImageList(HIMAGELIST hImageList)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ REBARINFO rbi = {};
+ rbi.cbSize = sizeof(REBARINFO);
+ rbi.fMask = RBIM_IMAGELIST;
+ rbi.himl = hImageList;
+ return (BOOL)::SendMessage(this->m_hWnd, RB_SETBARINFO, 0, (LPARAM)&rbi);
+ }
+
+ UINT GetRowCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, RB_GETROWCOUNT, 0, 0L);
+ }
+
+ UINT GetRowHeight(int nBand) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, RB_GETROWHEIGHT, nBand, 0L);
+ }
+
+ COLORREF GetTextColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, RB_GETTEXTCOLOR, 0, 0L);
+ }
+
+ COLORREF SetTextColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, RB_SETTEXTCOLOR, 0, (LPARAM)clr);
+ }
+
+ COLORREF GetBkColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, RB_GETBKCOLOR, 0, 0L);
+ }
+
+ COLORREF SetBkColor(COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, RB_SETBKCOLOR, 0, (LPARAM)clr);
+ }
+
+ UINT GetBarHeight() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (UINT)::SendMessage(this->m_hWnd, RB_GETBARHEIGHT, 0, 0L);
+ }
+
+ BOOL GetRect(int nBand, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_GETRECT, nBand, (LPARAM)lpRect);
+ }
+
+ CToolTipCtrl GetToolTips() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CToolTipCtrl((HWND)::SendMessage(this->m_hWnd, RB_GETTOOLTIPS, 0, 0L));
+ }
+
+ void SetToolTips(HWND hwndToolTip)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_SETTOOLTIPS, (WPARAM)hwndToolTip, 0L);
+ }
+
+ void GetBandBorders(int nBand, LPRECT lpRect) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(lpRect != NULL);
+ ::SendMessage(this->m_hWnd, RB_GETBANDBORDERS, nBand, (LPARAM)lpRect);
+ }
+
+ BOOL GetColorScheme(LPCOLORSCHEME lpColorScheme) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(lpColorScheme != NULL);
+ return (BOOL)::SendMessage(this->m_hWnd, RB_GETCOLORSCHEME, 0, (LPARAM)lpColorScheme);
+ }
+
+ void SetColorScheme(LPCOLORSCHEME lpColorScheme)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(lpColorScheme != NULL);
+ ::SendMessage(this->m_hWnd, RB_SETCOLORSCHEME, 0, (LPARAM)lpColorScheme);
+ }
+
+ HPALETTE GetPalette() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HPALETTE)::SendMessage(this->m_hWnd, RB_GETPALETTE, 0, 0L);
+ }
+
+ HPALETTE SetPalette(HPALETTE hPalette)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (HPALETTE)::SendMessage(this->m_hWnd, RB_SETPALETTE, 0, (LPARAM)hPalette);
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+ // requires uxtheme.h to be included to use MARGINS struct
+#ifndef _UXTHEME_H_
+ typedef struct _MARGINS* PMARGINS;
+#endif // !_UXTHEME_H_
+ void GetBandMargins(PMARGINS pMargins) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_GETBANDMARGINS, 0, (LPARAM)pMargins);
+ }
+
+ void SetWindowTheme(LPCWSTR lpstrTheme)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_SETWINDOWTHEME, 0, (LPARAM)lpstrTheme);
+ }
+
+ DWORD GetExtendedStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, RB_GETEXTENDEDSTYLE, 0, 0L);
+ }
+
+ DWORD SetExtendedStyle(DWORD dwStyle, DWORD dwMask)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, RB_SETEXTENDEDSTYLE, dwMask, dwStyle);
+ }
+
+// Operations
+ BOOL InsertBand(int nBand, LPREBARBANDINFO lprbbi)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_INSERTBAND, nBand, (LPARAM)lprbbi);
+ }
+
+ BOOL AddBand(LPREBARBANDINFO lprbbi)
+ {
+ return InsertBand(-1, lprbbi);
+ }
+
+ BOOL DeleteBand(int nBand)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_DELETEBAND, nBand, 0L);
+ }
+
+ ATL::CWindow SetNotifyWnd(HWND hWnd)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return ATL::CWindow((HWND)::SendMessage(this->m_hWnd, RB_SETPARENT, (WPARAM)hWnd, 0L));
+ }
+
+ void BeginDrag(int nBand, DWORD dwPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_BEGINDRAG, nBand, dwPos);
+ }
+
+ void BeginDrag(int nBand, int xPos, int yPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_BEGINDRAG, nBand, MAKELPARAM(xPos, yPos));
+ }
+
+ void EndDrag()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_ENDDRAG, 0, 0L);
+ }
+
+ void DragMove(DWORD dwPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_DRAGMOVE, 0, dwPos);
+ }
+
+ void DragMove(int xPos, int yPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_DRAGMOVE, 0, MAKELPARAM(xPos, yPos));
+ }
+
+ void GetDropTarget(IDropTarget** ppDropTarget) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_GETDROPTARGET, 0, (LPARAM)ppDropTarget);
+ }
+
+ void MaximizeBand(int nBand, BOOL bIdeal = FALSE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_MAXIMIZEBAND, nBand, bIdeal);
+ }
+
+ void MinimizeBand(int nBand)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_MINIMIZEBAND, nBand, 0L);
+ }
+
+ BOOL SizeToRect(LPRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_SIZETORECT, 0, (LPARAM)lpRect);
+ }
+
+ int IdToIndex(UINT uBandID) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, RB_IDTOINDEX, uBandID, 0L);
+ }
+
+ int HitTest(LPRBHITTESTINFO lprbht) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, RB_HITTEST, 0, (LPARAM)lprbht);
+ }
+
+ BOOL ShowBand(int nBand, BOOL bShow)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_SHOWBAND, nBand, bShow);
+ }
+
+ BOOL MoveBand(int nBand, int nNewPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((nNewPos >= 0) && (nNewPos <= ((int)GetBandCount() - 1)));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_MOVEBAND, nBand, nNewPos);
+ }
+
+ void PushChevron(int nBand, LPARAM lAppValue)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, RB_PUSHCHEVRON, nBand, lAppValue);
+ }
+
+// Extra operations
+ void LockBands(bool bLock)
+ {
+ int nBandCount = GetBandCount();
+ for(int i =0; i < nBandCount; i++)
+ {
+ REBARBANDINFO rbbi = { RunTimeHelper::SizeOf_REBARBANDINFO() };
+ rbbi.fMask = RBBIM_STYLE;
+ BOOL bRet = GetBandInfo(i, &rbbi);
+ ATLASSERT(bRet);
+
+ if((rbbi.fStyle & RBBS_GRIPPERALWAYS) == 0)
+ {
+ rbbi.fStyle |= RBBS_GRIPPERALWAYS;
+ bRet = SetBandInfo(i, &rbbi);
+ ATLASSERT(bRet);
+ rbbi.fStyle &= ~RBBS_GRIPPERALWAYS;
+ }
+
+ if(bLock)
+ rbbi.fStyle |= RBBS_NOGRIPPER;
+ else
+ rbbi.fStyle &= ~RBBS_NOGRIPPER;
+
+ bRet = SetBandInfo(i, &rbbi);
+ ATLASSERT(bRet);
+ }
+ }
+
+#if (_WIN32_WINNT >= 0x0600)
+ BOOL SetBandWidth(int nBand, int cxWidth)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, RB_SETBANDWIDTH, nBand, cxWidth);
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+};
+
+typedef CReBarCtrlT<ATL::CWindow> CReBarCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CComboBoxEx
+
+template <class TBase>
+class CComboBoxExT : public CComboBoxT< TBase >
+{
+public:
+// Constructors
+ CComboBoxExT(HWND hWnd = NULL) : CComboBoxT< TBase >(hWnd)
+ { }
+
+ CComboBoxExT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return WC_COMBOBOXEX;
+ }
+
+ CImageList GetImageList() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, CBEM_GETIMAGELIST, 0, 0L));
+ }
+
+ CImageList SetImageList(HIMAGELIST hImageList)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CImageList((HIMAGELIST)::SendMessage(this->m_hWnd, CBEM_SETIMAGELIST, 0, (LPARAM)hImageList));
+ }
+
+ DWORD GetExtendedStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, CBEM_GETEXTENDEDSTYLE, 0, 0L);
+ }
+
+ DWORD SetExtendedStyle(DWORD dwExMask, DWORD dwExStyle)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, CBEM_SETEXTENDEDSTYLE, dwExMask, dwExStyle);
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CBEM_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CBEM_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+ void SetWindowTheme(LPCWSTR lpstrTheme)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, CBEM_SETWINDOWTHEME, 0, (LPARAM)lpstrTheme);
+ }
+
+// Operations
+ int InsertItem(const COMBOBOXEXITEM* lpcCBItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CBEM_INSERTITEM, 0, (LPARAM)lpcCBItem);
+ }
+
+ int InsertItem(UINT nMask, int nIndex, LPCTSTR lpszItem, int nImage, int nSelImage,
+ int iIndent, int iOverlay, LPARAM lParam)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ COMBOBOXEXITEM cbex = {};
+ cbex.mask = nMask;
+ cbex.iItem = nIndex;
+ cbex.pszText = (LPTSTR) lpszItem;
+ cbex.iImage = nImage;
+ cbex.iSelectedImage = nSelImage;
+ cbex.iIndent = iIndent;
+ cbex.iOverlay = iOverlay;
+ cbex.lParam = lParam;
+ return (int)::SendMessage(this->m_hWnd, CBEM_INSERTITEM, 0, (LPARAM)&cbex);
+ }
+
+ int InsertItem(int nIndex, LPCTSTR lpszItem, int nImage, int nSelImage, int iIndent, LPARAM lParam = 0)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ COMBOBOXEXITEM cbex = {};
+ cbex.mask = CBEIF_TEXT | CBEIF_IMAGE | CBEIF_SELECTEDIMAGE | CBEIF_INDENT | CBEIF_LPARAM;
+ cbex.iItem = nIndex;
+ cbex.pszText = (LPTSTR) lpszItem;
+ cbex.iImage = nImage;
+ cbex.iSelectedImage = nSelImage;
+ cbex.iIndent = iIndent;
+ cbex.lParam = lParam;
+ return (int)::SendMessage(this->m_hWnd, CBEM_INSERTITEM, 0, (LPARAM)&cbex);
+ }
+
+ int AddItem(UINT nMask, LPCTSTR lpszItem, int nImage, int nSelImage, int iIndent, int iOverlay, LPARAM lParam)
+ {
+ return InsertItem(nMask, -1, lpszItem, nImage, nSelImage, iIndent, iOverlay, lParam);
+ }
+
+ int AddItem(LPCTSTR lpszItem, int nImage, int nSelImage, int iIndent, LPARAM lParam = 0)
+ {
+ return InsertItem(-1, lpszItem, nImage, nSelImage, iIndent, lParam);
+ }
+
+ int DeleteItem(int nIndex)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, CBEM_DELETEITEM, nIndex, 0L);
+ }
+
+ BOOL GetItem(PCOMBOBOXEXITEM pCBItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CBEM_GETITEM, 0, (LPARAM)pCBItem);
+ }
+
+ BOOL SetItem(const COMBOBOXEXITEM* lpcCBItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CBEM_SETITEM, 0, (LPARAM)lpcCBItem);
+ }
+
+ int SetItem(int nIndex, UINT nMask, LPCTSTR lpszItem, int nImage, int nSelImage,
+ int iIndent, int iOverlay, LPARAM lParam)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ COMBOBOXEXITEM cbex = {};
+ cbex.mask = nMask;
+ cbex.iItem = nIndex;
+ cbex.pszText = (LPTSTR) lpszItem;
+ cbex.iImage = nImage;
+ cbex.iSelectedImage = nSelImage;
+ cbex.iIndent = iIndent;
+ cbex.iOverlay = iOverlay;
+ cbex.lParam = lParam;
+ return (int)::SendMessage(this->m_hWnd, CBEM_SETITEM, 0, (LPARAM)&cbex);
+ }
+
+ BOOL GetItemText(int nIndex, LPTSTR lpszItem, int nLen) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(lpszItem != NULL);
+
+ COMBOBOXEXITEM cbex = {};
+ cbex.mask = CBEIF_TEXT;
+ cbex.iItem = nIndex;
+ cbex.pszText = lpszItem;
+ cbex.cchTextMax = nLen;
+
+ return (BOOL)::SendMessage(this->m_hWnd, CBEM_GETITEM, 0, (LPARAM)&cbex);
+ }
+
+ BOOL GetItemText(int nIndex, BSTR& bstrText) const
+ {
+ USES_CONVERSION;
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(bstrText == NULL);
+
+ COMBOBOXEXITEM cbex = {};
+ cbex.mask = CBEIF_TEXT;
+ cbex.iItem = nIndex;
+
+ LPTSTR lpstrText = NULL;
+ BOOL bRet = FALSE;
+ for(int nLen = 256; ; nLen *= 2)
+ {
+ ATLTRY(lpstrText = new TCHAR[nLen]);
+ if(lpstrText == NULL)
+ break;
+ lpstrText[0] = NULL;
+ cbex.pszText = lpstrText;
+ cbex.cchTextMax = nLen;
+ bRet = (BOOL)::SendMessage(this->m_hWnd, CBEM_GETITEM, 0, (LPARAM)&cbex);
+ if(!bRet || (lstrlen(cbex.pszText) < (nLen - 1)))
+ break;
+ delete [] lpstrText;
+ lpstrText = NULL;
+ }
+
+ if(lpstrText != NULL)
+ {
+ if(bRet)
+ bstrText = ::SysAllocString(T2OLE(lpstrText));
+ delete [] lpstrText;
+ }
+
+ return (bstrText != NULL) ? TRUE : FALSE;
+ }
+
+#ifdef __ATLSTR_H__
+ BOOL GetItemText(int nIndex, ATL::CString& strText) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+
+ COMBOBOXEXITEM cbex = {};
+ cbex.mask = CBEIF_TEXT;
+ cbex.iItem = nIndex;
+
+ strText.Empty();
+ BOOL bRet = FALSE;
+ for(int nLen = 256; ; nLen *= 2)
+ {
+ cbex.pszText = strText.GetBufferSetLength(nLen);
+ if(cbex.pszText == NULL)
+ {
+ bRet = FALSE;
+ break;
+ }
+ cbex.cchTextMax = nLen;
+ bRet = (BOOL)::SendMessage(this->m_hWnd, CBEM_GETITEM, 0, (LPARAM)&cbex);
+ if(!bRet || (lstrlen(cbex.pszText) < (nLen - 1)))
+ break;
+ }
+ strText.ReleaseBuffer();
+ return bRet;
+ }
+#endif // __ATLSTR_H__
+
+ BOOL SetItemText(int nIndex, LPCTSTR lpszItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return SetItem(nIndex, CBEIF_TEXT, lpszItem, 0, 0, 0, 0, 0);
+ }
+
+ CComboBox GetComboCtrl() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CComboBox((HWND)::SendMessage(this->m_hWnd, CBEM_GETCOMBOCONTROL, 0, 0L));
+ }
+
+ CEdit GetEditCtrl() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CEdit((HWND)::SendMessage(this->m_hWnd, CBEM_GETEDITCONTROL, 0, 0L));
+ }
+
+ BOOL HasEditChanged() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, CBEM_HASEDITCHANGED, 0, 0L);
+ }
+
+// Non-functional
+ int AddString(LPCTSTR /*lpszItem*/)
+ {
+ ATLASSERT(FALSE); // Not available in CComboBoxEx; use InsertItem
+ return 0;
+ }
+
+ int InsertString(int /*nIndex*/, LPCTSTR /*lpszString*/)
+ {
+ ATLASSERT(FALSE); // Not available in CComboBoxEx; use InsertItem
+ return 0;
+ }
+
+ int Dir(UINT /*attr*/, LPCTSTR /*lpszWildCard*/)
+ {
+ ATLASSERT(FALSE); // Not available in CComboBoxEx
+ return 0;
+ }
+
+ int FindString(int /*nStartAfter*/, LPCTSTR /*lpszString*/) const
+ {
+ ATLASSERT(FALSE); // Not available in CComboBoxEx; try FindStringExact
+ return 0;
+ }
+};
+
+typedef CComboBoxExT<ATL::CWindow> CComboBoxEx;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CMonthCalendarCtrl
+
+template <class TBase>
+class CMonthCalendarCtrlT : public TBase
+{
+public:
+// Constructors
+ CMonthCalendarCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CMonthCalendarCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return MONTHCAL_CLASS;
+ }
+
+ COLORREF GetColor(int nColorType) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, MCM_GETCOLOR, nColorType, 0L);
+ }
+
+ COLORREF SetColor(int nColorType, COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, MCM_SETCOLOR, nColorType, clr);
+ }
+
+ BOOL GetCurSel(LPSYSTEMTIME lpSysTime) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_GETCURSEL, 0, (LPARAM)lpSysTime);
+ }
+
+ BOOL SetCurSel(LPSYSTEMTIME lpSysTime)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_SETCURSEL, 0, (LPARAM)lpSysTime);
+ }
+
+ int GetFirstDayOfWeek(BOOL* pbLocaleVal = NULL) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, MCM_GETFIRSTDAYOFWEEK, 0, 0L);
+ if(pbLocaleVal != NULL)
+ *pbLocaleVal = (BOOL)HIWORD(dwRet);
+ return (int)(short)LOWORD(dwRet);
+ }
+
+ int SetFirstDayOfWeek(int nDay, BOOL* pbLocaleVal = NULL)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ DWORD dwRet = (DWORD)::SendMessage(this->m_hWnd, MCM_SETFIRSTDAYOFWEEK, 0, nDay);
+ if(pbLocaleVal != NULL)
+ *pbLocaleVal = (BOOL)HIWORD(dwRet);
+ return (int)(short)LOWORD(dwRet);
+ }
+
+ int GetMaxSelCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, MCM_GETMAXSELCOUNT, 0, 0L);
+ }
+
+ BOOL SetMaxSelCount(int nMax)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_SETMAXSELCOUNT, nMax, 0L);
+ }
+
+ int GetMonthDelta() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, MCM_GETMONTHDELTA, 0, 0L);
+ }
+
+ int SetMonthDelta(int nDelta)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, MCM_SETMONTHDELTA, nDelta, 0L);
+ }
+
+ DWORD GetRange(LPSYSTEMTIME lprgSysTimeArray) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, MCM_GETRANGE, 0, (LPARAM)lprgSysTimeArray);
+ }
+
+ BOOL SetRange(DWORD dwFlags, LPSYSTEMTIME lprgSysTimeArray)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_SETRANGE, dwFlags, (LPARAM)lprgSysTimeArray);
+ }
+
+ BOOL GetSelRange(LPSYSTEMTIME lprgSysTimeArray) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_GETSELRANGE, 0, (LPARAM)lprgSysTimeArray);
+ }
+
+ BOOL SetSelRange(LPSYSTEMTIME lprgSysTimeArray)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_SETSELRANGE, 0, (LPARAM)lprgSysTimeArray);
+ }
+
+ BOOL GetToday(LPSYSTEMTIME lpSysTime) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_GETTODAY, 0, (LPARAM)lpSysTime);
+ }
+
+ void SetToday(LPSYSTEMTIME lpSysTime)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, MCM_SETTODAY, 0, (LPARAM)lpSysTime);
+ }
+
+ BOOL GetMinReqRect(LPRECT lpRectInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_GETMINREQRECT, 0, (LPARAM)lpRectInfo);
+ }
+
+ int GetMaxTodayWidth() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, MCM_GETMAXTODAYWIDTH, 0, 0L);
+ }
+
+ BOOL GetUnicodeFormat() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_GETUNICODEFORMAT, 0, 0L);
+ }
+
+ BOOL SetUnicodeFormat(BOOL bUnicode = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_SETUNICODEFORMAT, bUnicode, 0L);
+ }
+
+#if defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN)
+ DWORD GetCurrentView() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, MCM_GETCURRENTVIEW, 0, 0L);
+ }
+
+ BOOL SetCurrentView(DWORD dwView)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_SETCURRENTVIEW, 0, dwView);
+ }
+
+ DWORD GetCalendarCount() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, MCM_GETCALENDARCOUNT, 0, 0L);
+ }
+
+ BOOL GetCalendarGridInfo(PMCGRIDINFO pGridInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_GETCALENDARGRIDINFO, 0, (LPARAM)pGridInfo);
+ }
+
+ CALID GetCALID() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (CALID)::SendMessage(this->m_hWnd, MCM_GETCALID, 0, 0L);
+ }
+
+ void SetCALID(CALID calid)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, MCM_SETCALID, (LPARAM)calid, 0L);
+ }
+
+ int GetCalendarBorder() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, MCM_GETCALENDARBORDER, 0, 0L);
+ }
+
+ void SetCalendarBorder(int cxyBorder, BOOL bSet = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, MCM_SETCALENDARBORDER, (WPARAM)bSet, (LPARAM)cxyBorder);
+ }
+#endif // defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN)
+
+// Operations
+ int GetMonthRange(DWORD dwFlags, LPSYSTEMTIME lprgSysTimeArray) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, MCM_GETMONTHRANGE, dwFlags, (LPARAM)lprgSysTimeArray);
+ }
+
+ BOOL SetDayState(int nMonths, LPMONTHDAYSTATE lpDayStateArray)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, MCM_SETDAYSTATE, nMonths, (LPARAM)lpDayStateArray);
+ }
+
+ DWORD HitTest(PMCHITTESTINFO pMCHitTest) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, MCM_HITTEST, 0, (LPARAM)pMCHitTest);
+ }
+
+#if defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN)
+ void SizeRectToMin(LPRECT lpRect)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, MCM_SIZERECTTOMIN, 0, (LPARAM)lpRect);
+ }
+#endif // defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN)
+};
+
+typedef CMonthCalendarCtrlT<ATL::CWindow> CMonthCalendarCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CDateTimePickerCtrl
+
+template <class TBase>
+class CDateTimePickerCtrlT : public TBase
+{
+public:
+// Constructors
+ CDateTimePickerCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CDateTimePickerCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Operations
+ static LPCTSTR GetWndClassName()
+ {
+ return DATETIMEPICK_CLASS;
+ }
+
+ BOOL SetFormat(LPCTSTR lpszFormat)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, DTM_SETFORMAT, 0, (LPARAM)lpszFormat);
+ }
+
+ COLORREF GetMonthCalColor(int nColorType) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, DTM_GETMCCOLOR, nColorType, 0L);
+ }
+
+ COLORREF SetMonthCalColor(int nColorType, COLORREF clr)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, DTM_SETMCCOLOR, nColorType, clr);
+ }
+
+ DWORD GetRange(LPSYSTEMTIME lpSysTimeArray) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, DTM_GETRANGE, 0, (LPARAM)lpSysTimeArray);
+ }
+
+ BOOL SetRange(DWORD dwFlags, LPSYSTEMTIME lpSysTimeArray)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, DTM_SETRANGE, dwFlags, (LPARAM)lpSysTimeArray);
+ }
+
+ DWORD GetSystemTime(LPSYSTEMTIME lpSysTime) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, DTM_GETSYSTEMTIME, 0, (LPARAM)lpSysTime);
+ }
+
+ BOOL SetSystemTime(DWORD dwFlags, LPSYSTEMTIME lpSysTime)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, DTM_SETSYSTEMTIME, dwFlags, (LPARAM)lpSysTime);
+ }
+
+ CMonthCalendarCtrl GetMonthCal() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CMonthCalendarCtrl((HWND)::SendMessage(this->m_hWnd, DTM_GETMONTHCAL, 0, 0L));
+ }
+
+ CFontHandle GetMonthCalFont() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return CFontHandle((HFONT)::SendMessage(this->m_hWnd, DTM_GETMCFONT, 0, 0L));
+ }
+
+ void SetMonthCalFont(HFONT hFont, BOOL bRedraw = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, DTM_SETMCFONT, (WPARAM)hFont, MAKELPARAM(bRedraw, 0));
+ }
+
+#if defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN)
+ DWORD GetMonthCalStyle() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, DTM_GETMCSTYLE, 0, 0L);
+ }
+
+ DWORD SetMonthCalStyle(DWORD dwStyle)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (DWORD)::SendMessage(this->m_hWnd, DTM_SETMCSTYLE, 0, (LPARAM)dwStyle);
+ }
+
+ void GetDateTimePickerInfo(LPDATETIMEPICKERINFO lpPickerInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, DTM_GETDATETIMEPICKERINFO, 0, (LPARAM)lpPickerInfo);
+ }
+
+ BOOL GetIdealSize(LPSIZE lpSize) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, DTM_GETIDEALSIZE, 0, (LPARAM)lpSize);
+ }
+
+ void CloseMonthCal()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, DTM_CLOSEMONTHCAL, 0, 0L);
+ }
+#endif // defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN)
+};
+
+typedef CDateTimePickerCtrlT<ATL::CWindow> CDateTimePickerCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CFlatScrollBarImpl - support for flat scroll bars
+
+template <class T>
+class CFlatScrollBarImpl
+{
+public:
+// Initialization
+ BOOL FlatSB_Initialize()
+ {
+ T* pT = static_cast<T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::InitializeFlatSB(pT->m_hWnd);
+ }
+
+ HRESULT FlatSB_Uninitialize()
+ {
+ T* pT = static_cast<T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::UninitializeFlatSB(pT->m_hWnd);
+ }
+
+// Flat scroll bar properties
+ BOOL FlatSB_GetScrollProp(UINT uIndex, LPINT lpnValue) const
+ {
+ const T* pT = static_cast<const T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::FlatSB_GetScrollProp(pT->m_hWnd, uIndex, lpnValue);
+ }
+
+ BOOL FlatSB_SetScrollProp(UINT uIndex, int nValue, BOOL bRedraw = TRUE)
+ {
+ T* pT = static_cast<T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::FlatSB_SetScrollProp(pT->m_hWnd, uIndex, nValue, bRedraw);
+ }
+
+// Attributes
+ int FlatSB_GetScrollPos(int nBar) const
+ {
+ const T* pT = static_cast<const T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::FlatSB_GetScrollPos(pT->m_hWnd, nBar);
+ }
+
+ int FlatSB_SetScrollPos(int nBar, int nPos, BOOL bRedraw = TRUE)
+ {
+ T* pT = static_cast<T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::FlatSB_SetScrollPos(pT->m_hWnd, nBar, nPos, bRedraw);
+ }
+
+ BOOL FlatSB_GetScrollRange(int nBar, LPINT lpMinPos, LPINT lpMaxPos) const
+ {
+ const T* pT = static_cast<const T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::FlatSB_GetScrollRange(pT->m_hWnd, nBar, lpMinPos, lpMaxPos);
+ }
+
+ BOOL FlatSB_SetScrollRange(int nBar, int nMinPos, int nMaxPos, BOOL bRedraw = TRUE)
+ {
+ T* pT = static_cast<T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::FlatSB_SetScrollRange(pT->m_hWnd, nBar, nMinPos, nMaxPos, bRedraw);
+ }
+
+ BOOL FlatSB_GetScrollInfo(int nBar, LPSCROLLINFO lpScrollInfo) const
+ {
+ const T* pT = static_cast<const T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::FlatSB_GetScrollInfo(pT->m_hWnd, nBar, lpScrollInfo);
+ }
+
+ int FlatSB_SetScrollInfo(int nBar, LPSCROLLINFO lpScrollInfo, BOOL bRedraw = TRUE)
+ {
+ T* pT = static_cast<T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::FlatSB_SetScrollInfo(pT->m_hWnd, nBar, lpScrollInfo, bRedraw);
+ }
+
+// Operations
+ BOOL FlatSB_ShowScrollBar(UINT nBar, BOOL bShow = TRUE)
+ {
+ T* pT = static_cast<T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::FlatSB_ShowScrollBar(pT->m_hWnd, nBar, bShow);
+ }
+
+ BOOL FlatSB_EnableScrollBar(UINT uSBFlags, UINT uArrowFlags = ESB_ENABLE_BOTH)
+ {
+ T* pT = static_cast<T*>(this);
+ ATLASSERT(::IsWindow(pT->m_hWnd));
+ return ::FlatSB_EnableScrollBar(pT->m_hWnd, uSBFlags, uArrowFlags);
+ }
+};
+
+template <class TBase>
+class CFlatScrollBarT : public TBase, public CFlatScrollBarImpl<CFlatScrollBarT< TBase > >
+{
+public:
+ CFlatScrollBarT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CFlatScrollBarT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+};
+
+typedef CFlatScrollBarT<ATL::CWindow> CFlatScrollBar;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CIPAddressCtrl
+
+template <class TBase>
+class CIPAddressCtrlT : public TBase
+{
+public:
+// Constructors
+ CIPAddressCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CIPAddressCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Atteributes
+ static LPCTSTR GetWndClassName()
+ {
+ return WC_IPADDRESS;
+ }
+
+ BOOL IsBlank() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, IPM_ISBLANK, 0, 0L);
+ }
+
+ int GetAddress(LPDWORD lpdwAddress) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, IPM_GETADDRESS, 0, (LPARAM)lpdwAddress);
+ }
+
+ void SetAddress(DWORD dwAddress)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, IPM_SETADDRESS, 0, dwAddress);
+ }
+
+ void ClearAddress()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, IPM_CLEARADDRESS, 0, 0L);
+ }
+
+ void SetRange(int nField, WORD wRange)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, IPM_SETRANGE, nField, wRange);
+ }
+
+ void SetRange(int nField, BYTE nMin, BYTE nMax)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, IPM_SETRANGE, nField, MAKEIPRANGE(nMin, nMax));
+ }
+
+ void SetFocus(int nField)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, IPM_SETFOCUS, nField, 0L);
+ }
+};
+
+typedef CIPAddressCtrlT<ATL::CWindow> CIPAddressCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CPagerCtrl
+
+template <class TBase>
+class CPagerCtrlT : public TBase
+{
+public:
+// Constructors
+ CPagerCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CPagerCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+ return WC_PAGESCROLLER;
+ }
+
+ int GetButtonSize() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, PGM_GETBUTTONSIZE, 0, 0L);
+ }
+
+ int SetButtonSize(int nButtonSize)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, PGM_SETBUTTONSIZE, 0, nButtonSize);
+ }
+
+ DWORD GetButtonState(int nButton) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT((nButton == PGB_TOPORLEFT) || (nButton == PGB_BOTTOMORRIGHT));
+ return (DWORD)::SendMessage(this->m_hWnd, PGM_GETBUTTONSTATE, 0, nButton);
+ }
+
+ COLORREF GetBkColor() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, PGM_GETBKCOLOR, 0, 0L);
+ }
+
+ COLORREF SetBkColor(COLORREF clrBk)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (COLORREF)::SendMessage(this->m_hWnd, PGM_SETBKCOLOR, 0, (LPARAM)clrBk);
+ }
+
+ int GetBorder() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, PGM_GETBORDER, 0, 0L);
+ }
+
+ int SetBorder(int nBorderSize)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, PGM_SETBORDER, 0, nBorderSize);
+ }
+
+ int GetPos() const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, PGM_GETPOS, 0, 0L);
+ }
+
+ int SetPos(int nPos)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, PGM_SETPOS, 0, nPos);
+ }
+
+// Operations
+ void SetChild(HWND hWndChild)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, PGM_SETCHILD, 0, (LPARAM)hWndChild);
+ }
+
+ void ForwardMouse(BOOL bForward = TRUE)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, PGM_FORWARDMOUSE, bForward, 0L);
+ }
+
+ void RecalcSize()
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ::SendMessage(this->m_hWnd, PGM_RECALCSIZE, 0, 0L);
+ }
+
+ void GetDropTarget(IDropTarget** ppDropTarget)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ ATLASSERT(ppDropTarget != NULL);
+ ::SendMessage(this->m_hWnd, PGM_GETDROPTARGET, 0, (LPARAM)ppDropTarget);
+ }
+};
+
+typedef CPagerCtrlT<ATL::CWindow> CPagerCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CLinkCtrl - Windows SYSLINK control
+
+template <class TBase>
+class CLinkCtrlT : public TBase
+{
+public:
+// Constructors
+ CLinkCtrlT(HWND hWnd = NULL) : TBase(hWnd)
+ { }
+
+ CLinkCtrlT< TBase >& operator =(HWND hWnd)
+ {
+ this->m_hWnd = hWnd;
+ return *this;
+ }
+
+ HWND Create(HWND hWndParent, ATL::_U_RECT rect = NULL, LPCTSTR szWindowName = NULL,
+ DWORD dwStyle = 0, DWORD dwExStyle = 0,
+ ATL::_U_MENUorID MenuOrID = 0U, LPVOID lpCreateParam = NULL)
+ {
+ return TBase::Create(GetWndClassName(), hWndParent, rect.m_lpRect, szWindowName, dwStyle, dwExStyle, MenuOrID.m_hMenu, lpCreateParam);
+ }
+
+// Attributes
+ static LPCTSTR GetWndClassName()
+ {
+#ifdef _UNICODE
+ return WC_LINK;
+#else // !_UNICODE
+ return "SysLink";
+#endif // !_UNICODE
+ }
+
+ int GetIdealHeight(int cxMaxWidth = 0) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LM_GETIDEALHEIGHT, cxMaxWidth, 0L);
+ }
+
+ BOOL GetItem(PLITEM pLItem) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LM_GETITEM, 0, (LPARAM)pLItem);
+ }
+
+ BOOL SetItem(PLITEM pLItem)
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LM_SETITEM, 0, (LPARAM)pLItem);
+ }
+
+ // Vista only
+ int GetIdealSize(SIZE& size, int cxMaxWidth = 0) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (int)::SendMessage(this->m_hWnd, LM_GETIDEALSIZE, cxMaxWidth, (LPARAM)&size);
+ }
+
+// Operations
+ BOOL HitTest(PLHITTESTINFO pLHitTestInfo) const
+ {
+ ATLASSERT(::IsWindow(this->m_hWnd));
+ return (BOOL)::SendMessage(this->m_hWnd, LM_HITTEST, 0, (LPARAM)pLHitTestInfo);
+ }
+};
+
+typedef CLinkCtrlT<ATL::CWindow> CLinkCtrl;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CCustomDraw - MI class for custom-draw support
+
+template <class T>
+class CCustomDraw
+{
+public:
+// Message map and handlers
+ BEGIN_MSG_MAP(CCustomDraw< T >)
+ NOTIFY_CODE_HANDLER(NM_CUSTOMDRAW, OnCustomDraw)
+ ALT_MSG_MAP(1)
+ REFLECTED_NOTIFY_CODE_HANDLER(NM_CUSTOMDRAW, OnCustomDraw)
+ END_MSG_MAP()
+
+// message handler
+ LRESULT OnCustomDraw(int idCtrl, LPNMHDR pnmh, BOOL& bHandled)
+ {
+ T* pT = static_cast<T*>(this);
+ pT->SetMsgHandled(TRUE);
+ LPNMCUSTOMDRAW lpNMCustomDraw = (LPNMCUSTOMDRAW)pnmh;
+ DWORD dwRet = 0;
+ switch(lpNMCustomDraw->dwDrawStage)
+ {
+ case CDDS_PREPAINT:
+ dwRet = pT->OnPrePaint(idCtrl, lpNMCustomDraw);
+ break;
+ case CDDS_POSTPAINT:
+ dwRet = pT->OnPostPaint(idCtrl, lpNMCustomDraw);
+ break;
+ case CDDS_PREERASE:
+ dwRet = pT->OnPreErase(idCtrl, lpNMCustomDraw);
+ break;
+ case CDDS_POSTERASE:
+ dwRet = pT->OnPostErase(idCtrl, lpNMCustomDraw);
+ break;
+ case CDDS_ITEMPREPAINT:
+ dwRet = pT->OnItemPrePaint(idCtrl, lpNMCustomDraw);
+ break;
+ case CDDS_ITEMPOSTPAINT:
+ dwRet = pT->OnItemPostPaint(idCtrl, lpNMCustomDraw);
+ break;
+ case CDDS_ITEMPREERASE:
+ dwRet = pT->OnItemPreErase(idCtrl, lpNMCustomDraw);
+ break;
+ case CDDS_ITEMPOSTERASE:
+ dwRet = pT->OnItemPostErase(idCtrl, lpNMCustomDraw);
+ break;
+ case (CDDS_ITEMPREPAINT | CDDS_SUBITEM):
+ dwRet = pT->OnSubItemPrePaint(idCtrl, lpNMCustomDraw);
+ break;
+ default:
+ pT->SetMsgHandled(FALSE);
+ break;
+ }
+ bHandled = pT->IsMsgHandled();
+ return dwRet;
+ }
+
+// Overrideables
+ DWORD OnPrePaint(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/)
+ {
+ return CDRF_DODEFAULT;
+ }
+
+ DWORD OnPostPaint(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/)
+ {
+ return CDRF_DODEFAULT;
+ }
+
+ DWORD OnPreErase(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/)
+ {
+ return CDRF_DODEFAULT;
+ }
+
+ DWORD OnPostErase(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/)
+ {
+ return CDRF_DODEFAULT;
+ }
+
+ DWORD OnItemPrePaint(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/)
+ {
+ return CDRF_DODEFAULT;
+ }
+
+ DWORD OnItemPostPaint(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/)
+ {
+ return CDRF_DODEFAULT;
+ }
+
+ DWORD OnItemPreErase(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/)
+ {
+ return CDRF_DODEFAULT;
+ }
+
+ DWORD OnItemPostErase(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/)
+ {
+ return CDRF_DODEFAULT;
+ }
+
+ DWORD OnSubItemPrePaint(int /*idCtrl*/, LPNMCUSTOMDRAW /*lpNMCustomDraw*/)
+ {
+ return CDRF_DODEFAULT;
+ }
+};
+
+} // namespace WTL
+
+#endif // __ATLCTRLS_H__
diff --git a/Examples/WhisperDesktop/Utils/WTL/atlddx.h b/Examples/WhisperDesktop/Utils/WTL/atlddx.h
new file mode 100644
index 0000000..7723b99
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/WTL/atlddx.h
@@ -0,0 +1,667 @@
+// Windows Template Library - WTL version 10.0
+// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved.
+//
+// This file is a part of the Windows Template Library.
+// The use and distribution terms for this software are covered by the
+// Microsoft Public License (http://opensource.org/licenses/MS-PL)
+// which can be found in the file MS-PL.txt at the root folder.
+
+#ifndef __ATLDDX_H__
+#define __ATLDDX_H__
+
+#pragma once
+
+#ifndef __ATLAPP_H__
+ #error atlddx.h requires atlapp.h to be included first
+#endif
+
+#include <float.h>
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Classes in this file:
+//
+// CWinDataExchange<T>
+
+
+namespace WTL
+{
+
+// Constants
+#define DDX_LOAD FALSE
+#define DDX_SAVE TRUE
+
+// DDX map macros
+#define BEGIN_DDX_MAP(thisClass) \
+ BOOL DoDataExchange(BOOL bSaveAndValidate = FALSE, UINT nCtlID = (UINT)-1) \
+ { \
+ (bSaveAndValidate); \
+ (nCtlID);
+
+#define DDX_TEXT(nID, var) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ { \
+ if(!DDX_Text(nID, var, sizeof(var), bSaveAndValidate)) \
+ return FALSE; \
+ }
+
+#define DDX_TEXT_LEN(nID, var, len) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ { \
+ if(!DDX_Text(nID, var, sizeof(var), bSaveAndValidate, TRUE, len)) \
+ return FALSE; \
+ }
+
+#define DDX_INT(nID, var) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ { \
+ if(!DDX_Int(nID, var, TRUE, bSaveAndValidate)) \
+ return FALSE; \
+ }
+
+#define DDX_INT_RANGE(nID, var, min, max) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ { \
+ if(!DDX_Int(nID, var, TRUE, bSaveAndValidate, TRUE, min, max)) \
+ return FALSE; \
+ }
+
+#define DDX_UINT(nID, var) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ { \
+ if(!DDX_Int(nID, var, FALSE, bSaveAndValidate)) \
+ return FALSE; \
+ }
+
+#define DDX_UINT_RANGE(nID, var, min, max) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ { \
+ if(!DDX_Int(nID, var, FALSE, bSaveAndValidate, TRUE, min, max)) \
+ return FALSE; \
+ }
+
+#define DDX_FLOAT(nID, var) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ { \
+ if(!DDX_Float(nID, var, bSaveAndValidate)) \
+ return FALSE; \
+ }
+
+#define DDX_FLOAT_RANGE(nID, var, min, max) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ { \
+ if(!DDX_Float(nID, var, bSaveAndValidate, TRUE, min, max)) \
+ return FALSE; \
+ }
+#define DDX_FLOAT_P(nID, var, precision) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ { \
+ if(!DDX_Float(nID, var, bSaveAndValidate, FALSE, 0, 0, precision)) \
+ return FALSE; \
+ }
+
+#define DDX_FLOAT_P_RANGE(nID, var, min, max, precision) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ { \
+ if(!DDX_Float(nID, var, bSaveAndValidate, TRUE, min, max, precision)) \
+ return FALSE; \
+ }
+
+#define DDX_CONTROL(nID, obj) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ DDX_Control(nID, obj, bSaveAndValidate);
+
+#define DDX_CONTROL_HANDLE(nID, obj) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ DDX_Control_Handle(nID, obj, bSaveAndValidate);
+
+#define DDX_CHECK(nID, var) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ DDX_Check(nID, var, bSaveAndValidate);
+
+#define DDX_RADIO(nID, var) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ DDX_Radio(nID, var, bSaveAndValidate);
+
+#define END_DDX_MAP() \
+ return TRUE; \
+ }
+
+// DDX support for Tab, Combo, ListBox and ListView selection index
+// Note: Specialized versions require atlctrls.h to be included first
+
+#define DDX_INDEX(CtrlClass, nID, var) \
+ if((nCtlID == (UINT)-1) || (nCtlID == nID)) \
+ DDX_Index<CtrlClass>(nID, var, bSaveAndValidate);
+
+#ifdef __ATLCTRLS_H__
+ #define DDX_TAB_INDEX(nID, var) DDX_INDEX(WTL::CTabCtrl, nID, var)
+ #define DDX_COMBO_INDEX(nID, var) DDX_INDEX(WTL::CComboBox, nID, var)
+ #define DDX_LISTBOX_INDEX(nID, var) DDX_INDEX(WTL::CListBox, nID, var)
+ #define DDX_LISTVIEW_INDEX(nID, var) DDX_INDEX(WTL::CListViewCtrl, nID, var)
+#endif // __ATLCTRLS_H__
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CWinDataExchange - provides support for DDX
+
+template <class T>
+class CWinDataExchange
+{
+public:
+// Data exchange method - override in your derived class
+ BOOL DoDataExchange(BOOL /*bSaveAndValidate*/ = FALSE, UINT /*nCtlID*/ = (UINT)-1)
+ {
+ // this one should never be called, override it in
+ // your derived class by implementing DDX map
+ ATLASSERT(FALSE);
+ return FALSE;
+ }
+
+// Helpers for validation error reporting
+ enum _XDataType
+ {
+ ddxDataNull = 0,
+ ddxDataText = 1,
+ ddxDataInt = 2,
+ ddxDataFloat = 3,
+ ddxDataDouble = 4
+ };
+
+ struct _XTextData
+ {
+ int nLength;
+ int nMaxLength;
+ };
+
+ struct _XIntData
+ {
+ long nVal;
+ long nMin;
+ long nMax;
+ };
+
+ struct _XFloatData
+ {
+ double nVal;
+ double nMin;
+ double nMax;
+ };
+
+ struct _XData
+ {
+ _XDataType nDataType;
+ union
+ {
+ _XTextData textData;
+ _XIntData intData;
+ _XFloatData floatData;
+ };
+ };
+
+// Text exchange
+ BOOL DDX_Text(UINT nID, LPTSTR lpstrText, int cbSize, BOOL bSave, BOOL bValidate = FALSE, int nLength = 0)
+ {
+ T* pT = static_cast<T*>(this);
+ BOOL bSuccess = TRUE;
+
+ if(bSave)
+ {
+ HWND hWndCtrl = pT->GetDlgItem(nID);
+ int nRetLen = ::GetWindowText(hWndCtrl, lpstrText, cbSize / sizeof(TCHAR));
+ if(nRetLen < ::GetWindowTextLength(hWndCtrl))
+ bSuccess = FALSE;
+ }
+ else
+ {
+ ATLASSERT(!bValidate || (lstrlen(lpstrText) <= nLength));
+ bSuccess = pT->SetDlgItemText(nID, lpstrText);
+ }
+
+ if(!bSuccess)
+ {
+ pT->OnDataExchangeError(nID, bSave);
+ }
+ else if(bSave && bValidate) // validation
+ {
+ ATLASSERT(nLength > 0);
+ if(lstrlen(lpstrText) > nLength)
+ {
+ _XData data = { ddxDataText };
+ data.textData.nLength = lstrlen(lpstrText);
+ data.textData.nMaxLength = nLength;
+ pT->OnDataValidateError(nID, bSave, data);
+ bSuccess = FALSE;
+ }
+ }
+ return bSuccess;
+ }
+
+ BOOL DDX_Text(UINT nID, BSTR& bstrText, int /*cbSize*/, BOOL bSave, BOOL bValidate = FALSE, int nLength = 0)
+ {
+ T* pT = static_cast<T*>(this);
+ BOOL bSuccess = TRUE;
+
+ if(bSave)
+ {
+ bSuccess = pT->GetDlgItemText(nID, bstrText);
+ }
+ else
+ {
+ USES_CONVERSION;
+ LPTSTR lpstrText = OLE2T(bstrText);
+ ATLASSERT(!bValidate || (lstrlen(lpstrText) <= nLength));
+ bSuccess = pT->SetDlgItemText(nID, lpstrText);
+ }
+
+ if(!bSuccess)
+ {
+ pT->OnDataExchangeError(nID, bSave);
+ }
+ else if(bSave && bValidate) // validation
+ {
+ ATLASSERT(nLength > 0);
+ if((int)::SysStringLen(bstrText) > nLength)
+ {
+ _XData data = { ddxDataText };
+ data.textData.nLength = (int)::SysStringLen(bstrText);
+ data.textData.nMaxLength = nLength;
+ pT->OnDataValidateError(nID, bSave, data);
+ bSuccess = FALSE;
+ }
+ }
+ return bSuccess;
+ }
+
+ BOOL DDX_Text(UINT nID, ATL::CComBSTR& bstrText, int /*cbSize*/, BOOL bSave, BOOL bValidate = FALSE, int nLength = 0)
+ {
+ T* pT = static_cast<T*>(this);
+ BOOL bSuccess = TRUE;
+
+ if(bSave)
+ {
+ bSuccess = pT->GetDlgItemText(nID, (BSTR&)bstrText);
+ }
+ else
+ {
+ USES_CONVERSION;
+ LPTSTR lpstrText = OLE2T(bstrText);
+ ATLASSERT(!bValidate || (lstrlen(lpstrText) <= nLength));
+ bSuccess = pT->SetDlgItemText(nID, lpstrText);
+ }
+
+ if(!bSuccess)
+ {
+ pT->OnDataExchangeError(nID, bSave);
+ }
+ else if(bSave && bValidate) // validation
+ {
+ ATLASSERT(nLength > 0);
+ if((int)bstrText.Length() > nLength)
+ {
+ _XData data = { ddxDataText };
+ data.textData.nLength = (int)bstrText.Length();
+ data.textData.nMaxLength = nLength;
+ pT->OnDataValidateError(nID, bSave, data);
+ bSuccess = FALSE;
+ }
+ }
+ return bSuccess;
+ }
+
+#ifdef __ATLSTR_H__
+ BOOL DDX_Text(UINT nID, ATL::CString& strText, int /*cbSize*/, BOOL bSave, BOOL bValidate = FALSE, int nLength = 0)
+ {
+ T* pT = static_cast<T*>(this);
+ BOOL bSuccess = TRUE;
+
+ if(bSave)
+ {
+ HWND hWndCtrl = pT->GetDlgItem(nID);
+ int nLen = ::GetWindowTextLength(hWndCtrl);
+ int nRetLen = -1;
+ LPTSTR lpstr = strText.GetBufferSetLength(nLen);
+ if(lpstr != NULL)
+ {
+ nRetLen = ::GetWindowText(hWndCtrl, lpstr, nLen + 1);
+ strText.ReleaseBuffer();
+ }
+ if(nRetLen < nLen)
+ bSuccess = FALSE;
+ }
+ else
+ {
+ bSuccess = pT->SetDlgItemText(nID, strText);
+ }
+
+ if(!bSuccess)
+ {
+ pT->OnDataExchangeError(nID, bSave);
+ }
+ else if(bSave && bValidate) // validation
+ {
+ ATLASSERT(nLength > 0);
+ if(strText.GetLength() > nLength)
+ {
+ _XData data = { ddxDataText };
+ data.textData.nLength = strText.GetLength();
+ data.textData.nMaxLength = nLength;
+ pT->OnDataValidateError(nID, bSave, data);
+ bSuccess = FALSE;
+ }
+ }
+ return bSuccess;
+ }
+#endif // __ATLSTR_H__
+
+// Numeric exchange
+ template <class Type>
+ BOOL DDX_Int(UINT nID, Type& nVal, BOOL bSigned, BOOL bSave, BOOL bValidate = FALSE, Type nMin = 0, Type nMax = 0)
+ {
+ T* pT = static_cast<T*>(this);
+ BOOL bSuccess = TRUE;
+
+ if(bSave)
+ {
+ nVal = (Type)pT->GetDlgItemInt(nID, &bSuccess, bSigned);
+ }
+ else
+ {
+ ATLASSERT(!bValidate || ((nVal >= nMin) && (nVal <= nMax)));
+ bSuccess = pT->SetDlgItemInt(nID, nVal, bSigned);
+ }
+
+ if(!bSuccess)
+ {
+ pT->OnDataExchangeError(nID, bSave);
+ }
+ else if(bSave && bValidate) // validation
+ {
+ ATLASSERT(nMin != nMax);
+ if((nVal < nMin) || (nVal > nMax))
+ {
+ _XData data = { ddxDataInt };
+ data.intData.nVal = (long)nVal;
+ data.intData.nMin = (long)nMin;
+ data.intData.nMax = (long)nMax;
+ pT->OnDataValidateError(nID, bSave, data);
+ bSuccess = FALSE;
+ }
+ }
+ return bSuccess;
+ }
+
+// Float exchange
+ static BOOL _AtlSimpleFloatParse(LPCTSTR lpszText, double& d)
+ {
+ ATLASSERT(lpszText != NULL);
+ while ((*lpszText == _T(' ')) || (*lpszText == _T('\t')))
+ lpszText++;
+
+ TCHAR chFirst = lpszText[0];
+ d = _tcstod(lpszText, (LPTSTR*)&lpszText);
+ if ((d == 0.0) && (chFirst != _T('0')))
+ return FALSE; // could not convert
+ while ((*lpszText == _T(' ')) || (*lpszText == _T('\t')))
+ lpszText++;
+
+ if (*lpszText != _T('\0'))
+ return FALSE; // not terminated properly
+
+ return TRUE;
+ }
+
+ BOOL DDX_Float(UINT nID, float& nVal, BOOL bSave, BOOL bValidate = FALSE, float nMin = 0.F, float nMax = 0.F, int nPrecision = FLT_DIG)
+ {
+ T* pT = static_cast<T*>(this);
+ BOOL bSuccess = TRUE;
+ const int cchBuff = 32;
+ TCHAR szBuff[cchBuff] = {};
+
+ if(bSave)
+ {
+ pT->GetDlgItemText(nID, szBuff, cchBuff);
+ double d = 0;
+ if(_AtlSimpleFloatParse(szBuff, d))
+ nVal = (float)d;
+ else
+ bSuccess = FALSE;
+ }
+ else
+ {
+ ATLASSERT(!bValidate || ((nVal >= nMin) && (nVal <= nMax)));
+ _stprintf_s(szBuff, cchBuff, _T("%.*g"), nPrecision, nVal);
+ bSuccess = pT->SetDlgItemText(nID, szBuff);
+ }
+
+ if(!bSuccess)
+ {
+ pT->OnDataExchangeError(nID, bSave);
+ }
+ else if(bSave && bValidate) // validation
+ {
+ ATLASSERT(nMin != nMax);
+ if((nVal < nMin) || (nVal > nMax))
+ {
+ _XData data = { ddxDataFloat };
+ data.floatData.nVal = (double)nVal;
+ data.floatData.nMin = (double)nMin;
+ data.floatData.nMax = (double)nMax;
+ pT->OnDataValidateError(nID, bSave, data);
+ bSuccess = FALSE;
+ }
+ }
+ return bSuccess;
+ }
+
+ BOOL DDX_Float(UINT nID, double& nVal, BOOL bSave, BOOL bValidate = FALSE, double nMin = 0., double nMax = 0., int nPrecision = DBL_DIG)
+ {
+ T* pT = static_cast<T*>(this);
+ BOOL bSuccess = TRUE;
+ const int cchBuff = 32;
+ TCHAR szBuff[cchBuff] = {};
+
+ if(bSave)
+ {
+ pT->GetDlgItemText(nID, szBuff, cchBuff);
+ double d = 0;
+ if(_AtlSimpleFloatParse(szBuff, d))
+ nVal = d;
+ else
+ bSuccess = FALSE;
+ }
+ else
+ {
+ ATLASSERT(!bValidate || ((nVal >= nMin) && (nVal <= nMax)));
+ _stprintf_s(szBuff, cchBuff, _T("%.*g"), nPrecision, nVal);
+ bSuccess = pT->SetDlgItemText(nID, szBuff);
+ }
+
+ if(!bSuccess)
+ {
+ pT->OnDataExchangeError(nID, bSave);
+ }
+ else if(bSave && bValidate) // validation
+ {
+ ATLASSERT(nMin != nMax);
+ if((nVal < nMin) || (nVal > nMax))
+ {
+ _XData data = { ddxDataFloat };
+ data.floatData.nVal = nVal;
+ data.floatData.nMin = nMin;
+ data.floatData.nMax = nMax;
+ pT->OnDataValidateError(nID, bSave, data);
+ bSuccess = FALSE;
+ }
+ }
+ return bSuccess;
+ }
+
+// Full control subclassing (for CWindowImpl derived controls)
+ template <class TControl>
+ void DDX_Control(UINT nID, TControl& ctrl, BOOL bSave)
+ {
+ if(!bSave && (ctrl.m_hWnd == NULL))
+ {
+ T* pT = static_cast<T*>(this);
+ ctrl.SubclassWindow(pT->GetDlgItem(nID));
+ }
+ }
+
+// Simple control attaching (for HWND wrapper controls)
+ template <class TControl>
+ void DDX_Control_Handle(UINT nID, TControl& ctrl, BOOL bSave)
+ {
+ if(!bSave && (ctrl.m_hWnd == NULL))
+ {
+ T* pT = static_cast<T*>(this);
+ ctrl = pT->GetDlgItem(nID);
+ }
+ }
+
+// Control state
+ void DDX_Check(UINT nID, int& nValue, BOOL bSave)
+ {
+ T* pT = static_cast<T*>(this);
+ HWND hWndCtrl = pT->GetDlgItem(nID);
+ if(bSave)
+ {
+ nValue = (int)::SendMessage(hWndCtrl, BM_GETCHECK, 0, 0L);
+ ATLASSERT((nValue >= 0) && (nValue <= 2));
+ }
+ else
+ {
+ if((nValue < 0) || (nValue > 2))
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("ATL: Warning - dialog data checkbox value (%d) out of range.\n"), nValue);
+ nValue = 0; // default to off
+ }
+ ::SendMessage(hWndCtrl, BM_SETCHECK, nValue, 0L);
+ }
+ }
+
+ // variant that supports bool (checked/not-checked, no intermediate state)
+ void DDX_Check(UINT nID, bool& bCheck, BOOL bSave)
+ {
+ int nValue = bCheck ? 1 : 0;
+ DDX_Check(nID, nValue, bSave);
+
+ if(bSave)
+ {
+ if(nValue == 2)
+ ATLTRACE2(atlTraceUI, 0, _T("ATL: Warning - checkbox state (%d) out of supported range.\n"), nValue);
+ bCheck = (nValue == 1);
+ }
+ }
+
+ void DDX_Radio(UINT nID, int& nValue, BOOL bSave)
+ {
+ T* pT = static_cast<T*>(this);
+ HWND hWndCtrl = pT->GetDlgItem(nID);
+ ATLASSERT(hWndCtrl != NULL);
+
+ // must be first in a group of auto radio buttons
+ ATLASSERT(::GetWindowLong(hWndCtrl, GWL_STYLE) & WS_GROUP);
+ ATLASSERT(::SendMessage(hWndCtrl, WM_GETDLGCODE, 0, 0L) & DLGC_RADIOBUTTON);
+
+ if(bSave)
+ nValue = -1; // value if none found
+
+ // walk all children in group
+ int nButton = 0;
+ do
+ {
+ if(::SendMessage(hWndCtrl, WM_GETDLGCODE, 0, 0L) & DLGC_RADIOBUTTON)
+ {
+ // control in group is a radio button
+ if(bSave)
+ {
+ if(::SendMessage(hWndCtrl, BM_GETCHECK, 0, 0L) != 0)
+ {
+ ATLASSERT(nValue == -1); // only set once
+ nValue = nButton;
+ }
+ }
+ else
+ {
+ // select button
+ ::SendMessage(hWndCtrl, BM_SETCHECK, (nButton == nValue), 0L);
+ }
+ nButton++;
+ }
+ else
+ {
+ ATLTRACE2(atlTraceUI, 0, _T("ATL: Warning - skipping non-radio button in group.\n"));
+ }
+ hWndCtrl = ::GetWindow(hWndCtrl, GW_HWNDNEXT);
+ }
+ while ((hWndCtrl != NULL) && !(GetWindowLong(hWndCtrl, GWL_STYLE) & WS_GROUP));
+ }
+
+// DDX support for Tab, Combo, ListBox and ListView selection index
+ template <class TCtrl>
+ INT _getSel(TCtrl& tCtrl)
+ {
+ return tCtrl.GetCurSel();
+ }
+
+ template <class TCtrl>
+ void _setSel(TCtrl& tCtrl, INT iSel)
+ {
+ if(iSel < 0)
+ tCtrl.SetCurSel(-1);
+ else
+ tCtrl.SetCurSel(iSel);
+ }
+
+#ifdef __ATLCTRLS_H__
+ // ListViewCtrl specialization
+ template <>
+ INT _getSel(WTL::CListViewCtrl& tCtrl)
+ {
+ return tCtrl.GetSelectedIndex();
+ }
+
+ template <>
+ void _setSel(WTL::CListViewCtrl& tCtrl, INT iSel)
+ {
+ if(iSel < 0)
+ tCtrl.SelectItem(-1);
+ else
+ tCtrl.SelectItem(iSel);
+ }
+#endif // __ATLCTRLS_H__
+
+ template <class TCtrl>
+ void DDX_Index(UINT nID, INT& nVal, BOOL bSave)
+ {
+ T* pT = static_cast<T*>(this);
+ TCtrl ctrl(pT->GetDlgItem(nID));
+
+ if(bSave)
+ nVal = _getSel(ctrl);
+ else
+ _setSel(ctrl, nVal);
+ }
+
+// Overrideables
+ void OnDataExchangeError(UINT nCtrlID, BOOL /*bSave*/)
+ {
+ // Override to display an error message
+ ::MessageBeep((UINT)-1);
+ T* pT = static_cast<T*>(this);
+ ::SetFocus(pT->GetDlgItem(nCtrlID));
+ }
+
+ void OnDataValidateError(UINT nCtrlID, BOOL /*bSave*/, _XData& /*data*/)
+ {
+ // Override to display an error message
+ ::MessageBeep((UINT)-1);
+ T* pT = static_cast<T*>(this);
+ ::SetFocus(pT->GetDlgItem(nCtrlID));
+ }
+};
+
+} // namespace WTL
+
+#endif // __ATLDDX_H__
diff --git a/Examples/WhisperDesktop/Utils/WTL/atlgdi.h b/Examples/WhisperDesktop/Utils/WTL/atlgdi.h
new file mode 100644
index 0000000..14e1518
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/WTL/atlgdi.h
@@ -0,0 +1,3445 @@
+// Windows Template Library - WTL version 10.0
+// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved.
+//
+// This file is a part of the Windows Template Library.
+// The use and distribution terms for this software are covered by the
+// Microsoft Public License (http://opensource.org/licenses/MS-PL)
+// which can be found in the file MS-PL.txt at the root folder.
+
+#ifndef __ATLGDI_H__
+#define __ATLGDI_H__
+
+#pragma once
+
+#ifndef __ATLAPP_H__
+ #error atlgdi.h requires atlapp.h to be included first
+#endif
+
+
+// protect template members from windowsx.h macros
+#ifdef _INC_WINDOWSX
+ #undef CopyRgn
+ #undef CreateBrush
+ #undef CreatePen
+ #undef SelectBrush
+ #undef SelectPen
+ #undef SelectFont
+ #undef SelectBitmap
+#endif // _INC_WINDOWSX
+
+// required libraries
+#pragma comment(lib, "msimg32.lib")
+#if !defined(_ATL_NO_OPENGL)
+ #pragma comment(lib, "opengl32.lib")
+#endif
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Classes in this file:
+//
+// CPenT<t_bManaged>
+// CBrushT<t_bManaged>
+// CLogFont
+// CFontT<t_bManaged>
+// CBitmapT<t_bManaged>
+// CPaletteT<t_bManaged>
+// CRgnT<t_bManaged>
+// CDCT<t_bManaged>
+// CPaintDC
+// CClientDC
+// CWindowDC
+// CMemoryDC
+// CEnhMetaFileInfo
+// CEnhMetaFileT<t_bManaged>
+// CEnhMetaFileDC
+
+
+namespace WTL
+{
+
+///////////////////////////////////////////////////////////////////////////////
+// Bitmap resource helpers to extract bitmap information for a bitmap resource
+
+inline LPBITMAPINFOHEADER AtlGetBitmapResourceInfo(HMODULE hModule, ATL::_U_STRINGorID image)
+{
+ HRSRC hResource = ::FindResource(hModule, image.m_lpstr, RT_BITMAP);
+ ATLASSERT(hResource != NULL);
+ HGLOBAL hGlobal = ::LoadResource(hModule, hResource);
+ ATLASSERT(hGlobal != NULL);
+ LPBITMAPINFOHEADER pBitmapInfoHeader = (LPBITMAPINFOHEADER)::LockResource(hGlobal);
+ ATLASSERT(pBitmapInfoHeader != NULL);
+ return pBitmapInfoHeader;
+}
+
+inline WORD AtlGetBitmapResourceBitsPerPixel(HMODULE hModule, ATL::_U_STRINGorID image)
+{
+ LPBITMAPINFOHEADER pBitmapInfoHeader = AtlGetBitmapResourceInfo(hModule, image);
+ ATLASSERT(pBitmapInfoHeader != NULL);
+ return pBitmapInfoHeader->biBitCount;
+}
+
+inline WORD AtlGetBitmapResourceBitsPerPixel(ATL::_U_STRINGorID image)
+{
+ return AtlGetBitmapResourceBitsPerPixel(ModuleHelper::GetResourceInstance(), image);
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// 32-bit (alpha channel) bitmap resource helper
+
+// Note: 32-bit (alpha channel) images work only on Windows XP with Common Controls version 6.
+// If you want your app to work on older version of Windows, load non-alpha images if Common
+// Controls version is less than 6.
+
+inline bool AtlIsAlphaBitmapResource(ATL::_U_STRINGorID image)
+{
+ return (AtlGetBitmapResourceBitsPerPixel(image) == 32);
+}
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CPen
+
+template <bool t_bManaged>
+class CPenT
+{
+public:
+// Data members
+ HPEN m_hPen;
+
+// Constructor/destructor/operators
+ CPenT(HPEN hPen = NULL) : m_hPen(hPen)
+ { }
+
+ ~CPenT()
+ {
+ if(t_bManaged && (m_hPen != NULL))
+ DeleteObject();
+ }
+
+ CPenT<t_bManaged>& operator =(HPEN hPen)
+ {
+ Attach(hPen);
+ return *this;
+ }
+
+ void Attach(HPEN hPen)
+ {
+ if(t_bManaged && (m_hPen != NULL) && (m_hPen != hPen))
+ ::DeleteObject(m_hPen);
+ m_hPen = hPen;
+ }
+
+ HPEN Detach()
+ {
+ HPEN hPen = m_hPen;
+ m_hPen = NULL;
+ return hPen;
+ }
+
+ operator HPEN() const { return m_hPen; }
+
+ bool IsNull() const { return (m_hPen == NULL); }
+
+// Create methods
+ HPEN CreatePen(int nPenStyle, int nWidth, COLORREF crColor)
+ {
+ ATLASSERT(m_hPen == NULL);
+ m_hPen = ::CreatePen(nPenStyle, nWidth, crColor);
+ return m_hPen;
+ }
+
+ HPEN CreatePen(int nPenStyle, int nWidth, const LOGBRUSH* pLogBrush, int nStyleCount = 0, const DWORD* lpStyle = NULL)
+ {
+ ATLASSERT(m_hPen == NULL);
+ m_hPen = ::ExtCreatePen(nPenStyle, nWidth, pLogBrush, nStyleCount, lpStyle);
+ return m_hPen;
+ }
+
+ HPEN CreatePenIndirect(LPLOGPEN lpLogPen)
+ {
+ ATLASSERT(m_hPen == NULL);
+ m_hPen = ::CreatePenIndirect(lpLogPen);
+ return m_hPen;
+ }
+
+ BOOL DeleteObject()
+ {
+ ATLASSERT(m_hPen != NULL);
+ BOOL bRet = ::DeleteObject(m_hPen);
+ if(bRet)
+ m_hPen = NULL;
+ return bRet;
+ }
+
+// Attributes
+ int GetLogPen(LOGPEN* pLogPen) const
+ {
+ ATLASSERT(m_hPen != NULL);
+ return ::GetObject(m_hPen, sizeof(LOGPEN), pLogPen);
+ }
+
+ bool GetLogPen(LOGPEN& LogPen) const
+ {
+ ATLASSERT(m_hPen != NULL);
+ return (::GetObject(m_hPen, sizeof(LOGPEN), &LogPen) == sizeof(LOGPEN));
+ }
+
+ int GetExtLogPen(EXTLOGPEN* pLogPen, int nSize = sizeof(EXTLOGPEN)) const
+ {
+ ATLASSERT(m_hPen != NULL);
+ return ::GetObject(m_hPen, nSize, pLogPen);
+ }
+
+ bool GetExtLogPen(EXTLOGPEN& ExtLogPen, int nSize = sizeof(EXTLOGPEN)) const
+ {
+ ATLASSERT(m_hPen != NULL);
+ int nRet = ::GetObject(m_hPen, nSize, &ExtLogPen);
+ return ((nRet > 0) && (nRet <= nSize));
+ }
+};
+
+typedef CPenT<false> CPenHandle;
+typedef CPenT<true> CPen;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CBrush
+
+template <bool t_bManaged>
+class CBrushT
+{
+public:
+// Data members
+ HBRUSH m_hBrush;
+
+// Constructor/destructor/operators
+ CBrushT(HBRUSH hBrush = NULL) : m_hBrush(hBrush)
+ { }
+
+ ~CBrushT()
+ {
+ if(t_bManaged && (m_hBrush != NULL))
+ DeleteObject();
+ }
+
+ CBrushT<t_bManaged>& operator =(HBRUSH hBrush)
+ {
+ Attach(hBrush);
+ return *this;
+ }
+
+ void Attach(HBRUSH hBrush)
+ {
+ if(t_bManaged && (m_hBrush != NULL) && (m_hBrush != hBrush))
+ ::DeleteObject(m_hBrush);
+ m_hBrush = hBrush;
+ }
+
+ HBRUSH Detach()
+ {
+ HBRUSH hBrush = m_hBrush;
+ m_hBrush = NULL;
+ return hBrush;
+ }
+
+ operator HBRUSH() const { return m_hBrush; }
+
+ bool IsNull() const { return (m_hBrush == NULL); }
+
+// Create methods
+ HBRUSH CreateSolidBrush(COLORREF crColor)
+ {
+ ATLASSERT(m_hBrush == NULL);
+ m_hBrush = ::CreateSolidBrush(crColor);
+ return m_hBrush;
+ }
+
+ HBRUSH CreateHatchBrush(int nIndex, COLORREF crColor)
+ {
+ ATLASSERT(m_hBrush == NULL);
+ m_hBrush = ::CreateHatchBrush(nIndex, crColor);
+ return m_hBrush;
+ }
+
+ HBRUSH CreateBrushIndirect(const LOGBRUSH* lpLogBrush)
+ {
+ ATLASSERT(m_hBrush == NULL);
+ m_hBrush = ::CreateBrushIndirect(lpLogBrush);
+ return m_hBrush;
+ }
+
+ HBRUSH CreatePatternBrush(HBITMAP hBitmap)
+ {
+ ATLASSERT(m_hBrush == NULL);
+ m_hBrush = ::CreatePatternBrush(hBitmap);
+ return m_hBrush;
+ }
+
+ HBRUSH CreateDIBPatternBrush(HGLOBAL hPackedDIB, UINT nUsage)
+ {
+ ATLASSERT(hPackedDIB != NULL);
+ const void* lpPackedDIB = GlobalLock(hPackedDIB);
+ ATLASSERT(lpPackedDIB != NULL);
+ m_hBrush = ::CreateDIBPatternBrushPt(lpPackedDIB, nUsage);
+ GlobalUnlock(hPackedDIB);
+ return m_hBrush;
+ }
+
+ HBRUSH CreateDIBPatternBrush(const void* lpPackedDIB, UINT nUsage)
+ {
+ ATLASSERT(m_hBrush == NULL);
+ m_hBrush = ::CreateDIBPatternBrushPt(lpPackedDIB, nUsage);
+ return m_hBrush;
+ }
+
+ HBRUSH CreateSysColorBrush(int nIndex)
+ {
+ ATLASSERT(m_hBrush == NULL);
+ m_hBrush = ::GetSysColorBrush(nIndex);
+ return m_hBrush;
+ }
+
+ BOOL DeleteObject()
+ {
+ ATLASSERT(m_hBrush != NULL);
+ BOOL bRet = ::DeleteObject(m_hBrush);
+ if(bRet)
+ m_hBrush = NULL;
+ return bRet;
+ }
+
+// Attributes
+ int GetLogBrush(LOGBRUSH* pLogBrush) const
+ {
+ ATLASSERT(m_hBrush != NULL);
+ return ::GetObject(m_hBrush, sizeof(LOGBRUSH), pLogBrush);
+ }
+
+ bool GetLogBrush(LOGBRUSH& LogBrush) const
+ {
+ ATLASSERT(m_hBrush != NULL);
+ return (::GetObject(m_hBrush, sizeof(LOGBRUSH), &LogBrush) == sizeof(LOGBRUSH));
+ }
+};
+
+typedef CBrushT<false> CBrushHandle;
+typedef CBrushT<true> CBrush;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CFont
+
+class CLogFont : public LOGFONT
+{
+public:
+ CLogFont()
+ {
+ memset(this, 0, sizeof(LOGFONT));
+ }
+
+ CLogFont(const LOGFONT& lf)
+ {
+ Copy(&lf);
+ }
+
+ CLogFont(HFONT hFont)
+ {
+ ATLASSERT(::GetObjectType(hFont) == OBJ_FONT);
+ ::GetObject(hFont, sizeof(LOGFONT), (LOGFONT*)this);
+ }
+
+ HFONT CreateFontIndirect()
+ {
+ return ::CreateFontIndirect(this);
+ }
+
+ void SetBold()
+ {
+ lfWeight = FW_BOLD;
+ }
+
+ bool IsBold() const
+ {
+ return (lfWeight >= FW_BOLD);
+ }
+
+ void MakeBolder(int iScale = 1)
+ {
+ lfWeight += FW_BOLD * iScale;
+ }
+
+ void MakeLarger(int iScale)
+ {
+ if(lfHeight > 0)
+ lfHeight += iScale;
+ else
+ lfHeight -= iScale;
+ }
+
+ void SetHeight(LONG nPointSize, HDC hDC = NULL)
+ {
+ HDC hDC1 = (hDC != NULL) ? hDC : ::GetDC(NULL);
+ // For MM_TEXT mapping mode
+ lfHeight = -::MulDiv(nPointSize, ::GetDeviceCaps(hDC1, LOGPIXELSY), 72);
+ if(hDC == NULL)
+ ::ReleaseDC(NULL, hDC1);
+ }
+
+ LONG GetHeight(HDC hDC = NULL) const
+ {
+ HDC hDC1 = (hDC != NULL) ? hDC : ::GetDC(NULL);
+ // For MM_TEXT mapping mode
+ LONG nPointSize = ::MulDiv(-lfHeight, 72, ::GetDeviceCaps(hDC1, LOGPIXELSY));
+ if(hDC == NULL)
+ ::ReleaseDC(NULL, hDC1);
+
+ return nPointSize;
+ }
+
+ LONG GetDeciPointHeight(HDC hDC = NULL) const
+ {
+ HDC hDC1 = (hDC != NULL) ? hDC : ::GetDC(NULL);
+ POINT ptOrg = { 0, 0 };
+ ::DPtoLP(hDC1, &ptOrg, 1);
+ POINT pt = { 0, 0 };
+ pt.y = abs(lfHeight) + ptOrg.y;
+ ::LPtoDP(hDC1, &pt, 1);
+ LONG nDeciPoint = ::MulDiv(pt.y, 720, ::GetDeviceCaps(hDC1, LOGPIXELSY)); // 72 points/inch, 10 decipoints/point
+ if(hDC == NULL)
+ ::ReleaseDC(NULL, hDC1);
+
+ return nDeciPoint;
+ }
+
+ void SetHeightFromDeciPoint(LONG nDeciPtHeight, HDC hDC = NULL)
+ {
+ HDC hDC1 = (hDC != NULL) ? hDC : ::GetDC(NULL);
+ POINT pt = { 0, 0 };
+ pt.y = ::MulDiv(::GetDeviceCaps(hDC1, LOGPIXELSY), nDeciPtHeight, 720); // 72 points/inch, 10 decipoints/point
+ ::DPtoLP(hDC1, &pt, 1);
+ POINT ptOrg = { 0, 0 };
+ ::DPtoLP(hDC1, &ptOrg, 1);
+ lfHeight = -abs(pt.y - ptOrg.y);
+ if(hDC == NULL)
+ ::ReleaseDC(NULL, hDC1);
+ }
+
+ void SetCaptionFont()
+ {
+ NONCLIENTMETRICS ncm = { RunTimeHelper::SizeOf_NONCLIENTMETRICS() };
+ ATLVERIFY(::SystemParametersInfo(SPI_GETNONCLIENTMETRICS, sizeof(ncm), &ncm, 0));
+ Copy(&ncm.lfCaptionFont);
+ }
+
+ void SetMenuFont()
+ {
+ NONCLIENTMETRICS ncm = { RunTimeHelper::SizeOf_NONCLIENTMETRICS() };
+ ATLVERIFY(::SystemParametersInfo(SPI_GETNONCLIENTMETRICS, sizeof(ncm), &ncm, 0));
+ Copy(&ncm.lfMenuFont);
+ }
+
+ void SetStatusFont()
+ {
+ NONCLIENTMETRICS ncm = { RunTimeHelper::SizeOf_NONCLIENTMETRICS() };
+ ATLVERIFY(::SystemParametersInfo(SPI_GETNONCLIENTMETRICS, sizeof(ncm), &ncm, 0));
+ Copy(&ncm.lfStatusFont);
+ }
+
+ void SetMessageBoxFont()
+ {
+ NONCLIENTMETRICS ncm = { RunTimeHelper::SizeOf_NONCLIENTMETRICS() };
+ ATLVERIFY(::SystemParametersInfo(SPI_GETNONCLIENTMETRICS, sizeof(ncm), &ncm, 0));
+ Copy(&ncm.lfMessageFont);
+ }
+
+ void Copy(const LOGFONT* pLogFont)
+ {
+ ATLASSERT(pLogFont != NULL);
+ *(LOGFONT*)this = *pLogFont;
+ }
+
+ CLogFont& operator =(const CLogFont& src)
+ {
+ Copy(&src);
+ return *this;
+ }
+
+ CLogFont& operator =(const LOGFONT& src)
+ {
+ Copy(&src);
+ return *this;
+ }
+
+ CLogFont& operator =(HFONT hFont)
+ {
+ ATLASSERT(::GetObjectType(hFont) == OBJ_FONT);
+ ::GetObject(hFont, sizeof(LOGFONT), (LOGFONT*)this);
+ return *this;
+ }
+
+ bool operator ==(const LOGFONT& logfont) const
+ {
+ return((logfont.lfHeight == lfHeight) &&
+ (logfont.lfWidth == lfWidth) &&
+ (logfont.lfEscapement == lfEscapement) &&
+ (logfont.lfOrientation == lfOrientation) &&
+ (logfont.lfWeight == lfWeight) &&
+ (logfont.lfItalic == lfItalic) &&
+ (logfont.lfUnderline == lfUnderline) &&
+ (logfont.lfStrikeOut == lfStrikeOut) &&
+ (logfont.lfCharSet == lfCharSet) &&
+ (logfont.lfOutPrecision == lfOutPrecision) &&
+ (logfont.lfClipPrecision == lfClipPrecision) &&
+ (logfont.lfQuality == lfQuality) &&
+ (logfont.lfPitchAndFamily == lfPitchAndFamily) &&
+ (lstrcmp(logfont.lfFaceName, lfFaceName) == 0));
+ }
+};
+
+
+template <bool t_bManaged>
+class CFontT
+{
+public:
+// Data members
+ HFONT m_hFont;
+
+// Constructor/destructor/operators
+ CFontT(HFONT hFont = NULL) : m_hFont(hFont)
+ { }
+
+ ~CFontT()
+ {
+ if(t_bManaged && (m_hFont != NULL))
+ DeleteObject();
+ }
+
+ CFontT<t_bManaged>& operator =(HFONT hFont)
+ {
+ Attach(hFont);
+ return *this;
+ }
+
+ void Attach(HFONT hFont)
+ {
+ if(t_bManaged && (m_hFont != NULL) && (m_hFont != hFont))
+ ::DeleteObject(m_hFont);
+ m_hFont = hFont;
+ }
+
+ HFONT Detach()
+ {
+ HFONT hFont = m_hFont;
+ m_hFont = NULL;
+ return hFont;
+ }
+
+ operator HFONT() const { return m_hFont; }
+
+ bool IsNull() const { return (m_hFont == NULL); }
+
+// Create methods
+ HFONT CreateFontIndirect(const LOGFONT* lpLogFont)
+ {
+ ATLASSERT(m_hFont == NULL);
+ m_hFont = ::CreateFontIndirect(lpLogFont);
+ return m_hFont;
+ }
+
+ HFONT CreateFontIndirectEx(CONST ENUMLOGFONTEXDV* penumlfex)
+ {
+ ATLASSERT(m_hFont == NULL);
+ m_hFont = ::CreateFontIndirectEx(penumlfex);
+ return m_hFont;
+ }
+
+ HFONT CreateFont(int nHeight, int nWidth, int nEscapement,
+ int nOrientation, int nWeight, BYTE bItalic, BYTE bUnderline,
+ BYTE cStrikeOut, BYTE nCharSet, BYTE nOutPrecision,
+ BYTE nClipPrecision, BYTE nQuality, BYTE nPitchAndFamily,
+ LPCTSTR lpszFacename)
+ {
+ ATLASSERT(m_hFont == NULL);
+ m_hFont = ::CreateFont(nHeight, nWidth, nEscapement,
+ nOrientation, nWeight, bItalic, bUnderline, cStrikeOut,
+ nCharSet, nOutPrecision, nClipPrecision, nQuality,
+ nPitchAndFamily, lpszFacename);
+ return m_hFont;
+ }
+
+ HFONT CreatePointFont(int nPointSize, LPCTSTR lpszFaceName, HDC hDC = NULL, bool bBold = false, bool bItalic = false)
+ {
+ LOGFONT logFont = {};
+ logFont.lfCharSet = DEFAULT_CHARSET;
+ logFont.lfHeight = nPointSize;
+ ATL::Checked::tcsncpy_s(logFont.lfFaceName, _countof(logFont.lfFaceName), lpszFaceName, _TRUNCATE);
+
+ if(bBold)
+ logFont.lfWeight = FW_BOLD;
+ if(bItalic)
+ logFont.lfItalic = (BYTE)TRUE;
+
+ return CreatePointFontIndirect(&logFont, hDC);
+ }
+
+ HFONT CreatePointFontIndirect(const LOGFONT* lpLogFont, HDC hDC = NULL)
+ {
+ HDC hDC1 = (hDC != NULL) ? hDC : ::GetDC(NULL);
+
+ // convert nPointSize to logical units based on hDC
+ LOGFONT logFont = *lpLogFont;
+ POINT pt = { 0, 0 };
+ pt.y = ::MulDiv(::GetDeviceCaps(hDC1, LOGPIXELSY), logFont.lfHeight, 720); // 72 points/inch, 10 decipoints/point
+ ::DPtoLP(hDC1, &pt, 1);
+ POINT ptOrg = { 0, 0 };
+ ::DPtoLP(hDC1, &ptOrg, 1);
+ logFont.lfHeight = -abs(pt.y - ptOrg.y);
+
+ if(hDC == NULL)
+ ::ReleaseDC(NULL, hDC1);
+
+ return CreateFontIndirect(&logFont);
+ }
+
+ BOOL DeleteObject()
+ {
+ ATLASSERT(m_hFont != NULL);
+ BOOL bRet = ::DeleteObject(m_hFont);
+ if(bRet)
+ m_hFont = NULL;
+ return bRet;
+ }
+
+// Attributes
+ int GetLogFont(LOGFONT* pLogFont) const
+ {
+ ATLASSERT(m_hFont != NULL);
+ return ::GetObject(m_hFont, sizeof(LOGFONT), pLogFont);
+ }
+
+ bool GetLogFont(LOGFONT& LogFont) const
+ {
+ ATLASSERT(m_hFont != NULL);
+ return (::GetObject(m_hFont, sizeof(LOGFONT), &LogFont) == sizeof(LOGFONT));
+ }
+};
+
+typedef CFontT<false> CFontHandle;
+typedef CFontT<true> CFont;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CBitmap
+
+template <bool t_bManaged>
+class CBitmapT
+{
+public:
+// Data members
+ HBITMAP m_hBitmap;
+
+// Constructor/destructor/operators
+ CBitmapT(HBITMAP hBitmap = NULL) : m_hBitmap(hBitmap)
+ { }
+
+ ~CBitmapT()
+ {
+ if(t_bManaged && (m_hBitmap != NULL))
+ DeleteObject();
+ }
+
+ CBitmapT<t_bManaged>& operator =(HBITMAP hBitmap)
+ {
+ Attach(hBitmap);
+ return *this;
+ }
+
+ void Attach(HBITMAP hBitmap)
+ {
+ if(t_bManaged && (m_hBitmap != NULL) && (m_hBitmap != hBitmap))
+ ::DeleteObject(m_hBitmap);
+ m_hBitmap = hBitmap;
+ }
+
+ HBITMAP Detach()
+ {
+ HBITMAP hBitmap = m_hBitmap;
+ m_hBitmap = NULL;
+ return hBitmap;
+ }
+
+ operator HBITMAP() const { return m_hBitmap; }
+
+ bool IsNull() const { return (m_hBitmap == NULL); }
+
+// Create and load methods
+ HBITMAP LoadBitmap(ATL::_U_STRINGorID bitmap)
+ {
+ ATLASSERT(m_hBitmap == NULL);
+ m_hBitmap = ::LoadBitmap(ModuleHelper::GetResourceInstance(), bitmap.m_lpstr);
+ return m_hBitmap;
+ }
+
+ HBITMAP LoadOEMBitmap(UINT nIDBitmap) // for OBM_/OCR_/OIC_
+ {
+ ATLASSERT(m_hBitmap == NULL);
+ m_hBitmap = ::LoadBitmap(NULL, MAKEINTRESOURCE(nIDBitmap));
+ return m_hBitmap;
+ }
+
+ HBITMAP LoadMappedBitmap(UINT nIDBitmap, UINT nFlags = 0, LPCOLORMAP lpColorMap = NULL, int nMapSize = 0)
+ {
+ ATLASSERT(m_hBitmap == NULL);
+ m_hBitmap = ::CreateMappedBitmap(ModuleHelper::GetResourceInstance(), nIDBitmap, (WORD)nFlags, lpColorMap, nMapSize);
+ return m_hBitmap;
+ }
+
+ HBITMAP CreateBitmap(int nWidth, int nHeight, UINT nPlanes, UINT nBitsPerPixel, const void* lpBits)
+ {
+ ATLASSERT(m_hBitmap == NULL);
+ m_hBitmap = ::CreateBitmap(nWidth, nHeight, nPlanes, nBitsPerPixel, lpBits);
+ return m_hBitmap;
+ }
+
+ HBITMAP CreateBitmapIndirect(LPBITMAP lpBitmap)
+ {
+ ATLASSERT(m_hBitmap == NULL);
+ m_hBitmap = ::CreateBitmapIndirect(lpBitmap);
+ return m_hBitmap;
+ }
+
+ HBITMAP CreateCompatibleBitmap(HDC hDC, int nWidth, int nHeight)
+ {
+ ATLASSERT(m_hBitmap == NULL);
+ m_hBitmap = ::CreateCompatibleBitmap(hDC, nWidth, nHeight);
+ return m_hBitmap;
+ }
+
+ HBITMAP CreateDiscardableBitmap(HDC hDC, int nWidth, int nHeight)
+ {
+ ATLASSERT(m_hBitmap == NULL);
+ m_hBitmap = ::CreateDiscardableBitmap(hDC, nWidth, nHeight);
+ return m_hBitmap;
+ }
+
+ BOOL DeleteObject()
+ {
+ ATLASSERT(m_hBitmap != NULL);
+ BOOL bRet = ::DeleteObject(m_hBitmap);
+ if(bRet)
+ m_hBitmap = NULL;
+ return bRet;
+ }
+
+// Attributes
+ int GetBitmap(BITMAP* pBitMap) const
+ {
+ ATLASSERT(m_hBitmap != NULL);
+ return ::GetObject(m_hBitmap, sizeof(BITMAP), pBitMap);
+ }
+
+ bool GetBitmap(BITMAP& bm) const
+ {
+ ATLASSERT(m_hBitmap != NULL);
+ return (::GetObject(m_hBitmap, sizeof(BITMAP), &bm) == sizeof(BITMAP));
+ }
+
+ bool GetSize(SIZE& size) const
+ {
+ ATLASSERT(m_hBitmap != NULL);
+ BITMAP bm = {};
+ if(!GetBitmap(&bm))
+ return false;
+ size.cx = bm.bmWidth;
+ size.cy = bm.bmHeight;
+ return true;
+ }
+
+ DWORD GetBitmapBits(DWORD dwCount, LPVOID lpBits) const
+ {
+ ATLASSERT(m_hBitmap != NULL);
+ return ::GetBitmapBits(m_hBitmap, dwCount, lpBits);
+ }
+
+ DWORD SetBitmapBits(DWORD dwCount, const void* lpBits)
+ {
+ ATLASSERT(m_hBitmap != NULL);
+ return ::SetBitmapBits(m_hBitmap, dwCount, lpBits);
+ }
+
+ BOOL GetBitmapDimension(LPSIZE lpSize) const
+ {
+ ATLASSERT(m_hBitmap != NULL);
+ return ::GetBitmapDimensionEx(m_hBitmap, lpSize);
+ }
+
+ BOOL SetBitmapDimension(int nWidth, int nHeight, LPSIZE lpSize = NULL)
+ {
+ ATLASSERT(m_hBitmap != NULL);
+ return ::SetBitmapDimensionEx(m_hBitmap, nWidth, nHeight, lpSize);
+ }
+
+// DIB support
+ HBITMAP CreateDIBitmap(HDC hDC, CONST BITMAPINFOHEADER* lpbmih, DWORD dwInit, CONST VOID* lpbInit, CONST BITMAPINFO* lpbmi, UINT uColorUse)
+ {
+ ATLASSERT(m_hBitmap == NULL);
+ m_hBitmap = ::CreateDIBitmap(hDC, lpbmih, dwInit, lpbInit, lpbmi, uColorUse);
+ return m_hBitmap;
+ }
+
+ HBITMAP CreateDIBSection(HDC hDC, CONST BITMAPINFO* lpbmi, UINT uColorUse, VOID** ppvBits, HANDLE hSection, DWORD dwOffset)
+ {
+ ATLASSERT(m_hBitmap == NULL);
+ m_hBitmap = ::CreateDIBSection(hDC, lpbmi, uColorUse, ppvBits, hSection, dwOffset);
+ return m_hBitmap;
+ }
+
+ int GetDIBits(HDC hDC, UINT uStartScan, UINT cScanLines, LPVOID lpvBits, LPBITMAPINFO lpbmi, UINT uColorUse) const
+ {
+ ATLASSERT(m_hBitmap != NULL);
+ return ::GetDIBits(hDC, m_hBitmap, uStartScan, cScanLines, lpvBits, lpbmi, uColorUse);
+ }
+
+ int SetDIBits(HDC hDC, UINT uStartScan, UINT cScanLines, CONST VOID* lpvBits, CONST BITMAPINFO* lpbmi, UINT uColorUse)
+ {
+ ATLASSERT(m_hBitmap != NULL);
+ return ::SetDIBits(hDC, m_hBitmap, uStartScan, cScanLines, lpvBits, lpbmi, uColorUse);
+ }
+};
+
+typedef CBitmapT<false> CBitmapHandle;
+typedef CBitmapT<true> CBitmap;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CPalette
+
+template <bool t_bManaged>
+class CPaletteT
+{
+public:
+// Data members
+ HPALETTE m_hPalette;
+
+// Constructor/destructor/operators
+ CPaletteT(HPALETTE hPalette = NULL) : m_hPalette(hPalette)
+ { }
+
+ ~CPaletteT()
+ {
+ if(t_bManaged && (m_hPalette != NULL))
+ DeleteObject();
+ }
+
+ CPaletteT<t_bManaged>& operator =(HPALETTE hPalette)
+ {
+ Attach(hPalette);
+ return *this;
+ }
+
+ void Attach(HPALETTE hPalette)
+ {
+ if(t_bManaged && (m_hPalette != NULL) && (m_hPalette != hPalette))
+ ::DeleteObject(m_hPalette);
+ m_hPalette = hPalette;
+ }
+
+ HPALETTE Detach()
+ {
+ HPALETTE hPalette = m_hPalette;
+ m_hPalette = NULL;
+ return hPalette;
+ }
+
+ operator HPALETTE() const { return m_hPalette; }
+
+ bool IsNull() const { return (m_hPalette == NULL); }
+
+// Create methods
+ HPALETTE CreatePalette(LPLOGPALETTE lpLogPalette)
+ {
+ ATLASSERT(m_hPalette == NULL);
+ m_hPalette = ::CreatePalette(lpLogPalette);
+ return m_hPalette;
+ }
+
+ HPALETTE CreateHalftonePalette(HDC hDC)
+ {
+ ATLASSERT(m_hPalette == NULL);
+ ATLASSERT(hDC != NULL);
+ m_hPalette = ::CreateHalftonePalette(hDC);
+ return m_hPalette;
+ }
+
+ BOOL DeleteObject()
+ {
+ ATLASSERT(m_hPalette != NULL);
+ BOOL bRet = ::DeleteObject(m_hPalette);
+ if(bRet)
+ m_hPalette = NULL;
+ return bRet;
+ }
+
+// Attributes
+ int GetEntryCount() const
+ {
+ ATLASSERT(m_hPalette != NULL);
+ WORD nEntries = 0;
+ ::GetObject(m_hPalette, sizeof(WORD), &nEntries);
+ return (int)nEntries;
+ }
+
+ UINT GetPaletteEntries(UINT nStartIndex, UINT nNumEntries, LPPALETTEENTRY lpPaletteColors) const
+ {
+ ATLASSERT(m_hPalette != NULL);
+ return ::GetPaletteEntries(m_hPalette, nStartIndex, nNumEntries, lpPaletteColors);
+ }
+
+ UINT SetPaletteEntries(UINT nStartIndex, UINT nNumEntries, LPPALETTEENTRY lpPaletteColors)
+ {
+ ATLASSERT(m_hPalette != NULL);
+ return ::SetPaletteEntries(m_hPalette, nStartIndex, nNumEntries, lpPaletteColors);
+ }
+
+// Operations
+ void AnimatePalette(UINT nStartIndex, UINT nNumEntries, LPPALETTEENTRY lpPaletteColors)
+ {
+ ATLASSERT(m_hPalette != NULL);
+ ::AnimatePalette(m_hPalette, nStartIndex, nNumEntries, lpPaletteColors);
+ }
+
+ BOOL ResizePalette(UINT nNumEntries)
+ {
+ ATLASSERT(m_hPalette != NULL);
+ return ::ResizePalette(m_hPalette, nNumEntries);
+ }
+
+ UINT GetNearestPaletteIndex(COLORREF crColor) const
+ {
+ ATLASSERT(m_hPalette != NULL);
+ return ::GetNearestPaletteIndex(m_hPalette, crColor);
+ }
+};
+
+typedef CPaletteT<false> CPaletteHandle;
+typedef CPaletteT<true> CPalette;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CRgn
+
+template <bool t_bManaged>
+class CRgnT
+{
+public:
+// Data members
+ HRGN m_hRgn;
+
+// Constructor/destructor/operators
+ CRgnT(HRGN hRgn = NULL) : m_hRgn(hRgn)
+ { }
+
+ ~CRgnT()
+ {
+ if(t_bManaged && (m_hRgn != NULL))
+ DeleteObject();
+ }
+
+ CRgnT<t_bManaged>& operator =(HRGN hRgn)
+ {
+ Attach(hRgn);
+ return *this;
+ }
+
+ void Attach(HRGN hRgn)
+ {
+ if(t_bManaged && (m_hRgn != NULL) && (m_hRgn != hRgn))
+ ::DeleteObject(m_hRgn);
+ m_hRgn = hRgn;
+ }
+
+ HRGN Detach()
+ {
+ HRGN hRgn = m_hRgn;
+ m_hRgn = NULL;
+ return hRgn;
+ }
+
+ operator HRGN() const { return m_hRgn; }
+
+ bool IsNull() const { return (m_hRgn == NULL); }
+
+// Create methods
+ HRGN CreateRectRgn(int x1, int y1, int x2, int y2)
+ {
+ ATLASSERT(m_hRgn == NULL);
+ m_hRgn = ::CreateRectRgn(x1, y1, x2, y2);
+ return m_hRgn;
+ }
+
+ HRGN CreateRectRgnIndirect(LPCRECT lpRect)
+ {
+ ATLASSERT(m_hRgn == NULL);
+ m_hRgn = ::CreateRectRgnIndirect(lpRect);
+ return m_hRgn;
+ }
+
+ HRGN CreateEllipticRgn(int x1, int y1, int x2, int y2)
+ {
+ ATLASSERT(m_hRgn == NULL);
+ m_hRgn = ::CreateEllipticRgn(x1, y1, x2, y2);
+ return m_hRgn;
+ }
+
+ HRGN CreateEllipticRgnIndirect(LPCRECT lpRect)
+ {
+ ATLASSERT(m_hRgn == NULL);
+ m_hRgn = ::CreateEllipticRgnIndirect(lpRect);
+ return m_hRgn;
+ }
+
+ HRGN CreatePolygonRgn(const POINT* lpPoints, int nCount, int nMode)
+ {
+ ATLASSERT(m_hRgn == NULL);
+ m_hRgn = ::CreatePolygonRgn(lpPoints, nCount, nMode);
+ return m_hRgn;
+ }
+
+ HRGN CreatePolyPolygonRgn(const POINT* lpPoints, const INT* lpPolyCounts, int nCount, int nPolyFillMode)
+ {
+ ATLASSERT(m_hRgn == NULL);
+ m_hRgn = ::CreatePolyPolygonRgn(lpPoints, lpPolyCounts, nCount, nPolyFillMode);
+ return m_hRgn;
+ }
+
+ HRGN CreateRoundRectRgn(int x1, int y1, int x2, int y2, int x3, int y3)
+ {
+ ATLASSERT(m_hRgn == NULL);
+ m_hRgn = ::CreateRoundRectRgn(x1, y1, x2, y2, x3, y3);
+ return m_hRgn;
+ }
+
+ HRGN CreateFromPath(HDC hDC)
+ {
+ ATLASSERT(m_hRgn == NULL);
+ ATLASSERT(hDC != NULL);
+ m_hRgn = ::PathToRegion(hDC);
+ return m_hRgn;
+ }
+
+ HRGN CreateFromData(const XFORM* lpXForm, int nCount, const RGNDATA* pRgnData)
+ {
+ ATLASSERT(m_hRgn == NULL);
+ m_hRgn = ::ExtCreateRegion(lpXForm, nCount, pRgnData);
+ return m_hRgn;
+ }
+
+ BOOL DeleteObject()
+ {
+ ATLASSERT(m_hRgn != NULL);
+ BOOL bRet = ::DeleteObject(m_hRgn);
+ if(bRet)
+ m_hRgn = NULL;
+ return bRet;
+ }
+
+// Operations
+ void SetRectRgn(int x1, int y1, int x2, int y2)
+ {
+ ATLASSERT(m_hRgn != NULL);
+ ::SetRectRgn(m_hRgn, x1, y1, x2, y2);
+ }
+
+ void SetRectRgn(LPCRECT lpRect)
+ {
+ ATLASSERT(m_hRgn != NULL);
+ ::SetRectRgn(m_hRgn, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom);
+ }
+
+ int CombineRgn(HRGN hRgnSrc1, HRGN hRgnSrc2, int nCombineMode)
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return ::CombineRgn(m_hRgn, hRgnSrc1, hRgnSrc2, nCombineMode);
+ }
+
+ int CombineRgn(HRGN hRgnSrc, int nCombineMode)
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return ::CombineRgn(m_hRgn, m_hRgn, hRgnSrc, nCombineMode);
+ }
+
+ int CopyRgn(HRGN hRgnSrc)
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return ::CombineRgn(m_hRgn, hRgnSrc, NULL, RGN_COPY);
+ }
+
+ BOOL EqualRgn(HRGN hRgn) const
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return ::EqualRgn(m_hRgn, hRgn);
+ }
+
+ int OffsetRgn(int x, int y)
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return ::OffsetRgn(m_hRgn, x, y);
+ }
+
+ int OffsetRgn(POINT point)
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return ::OffsetRgn(m_hRgn, point.x, point.y);
+ }
+
+ int GetRgnBox(LPRECT lpRect) const
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return ::GetRgnBox(m_hRgn, lpRect);
+ }
+
+ BOOL PtInRegion(int x, int y) const
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return ::PtInRegion(m_hRgn, x, y);
+ }
+
+ BOOL PtInRegion(POINT point) const
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return ::PtInRegion(m_hRgn, point.x, point.y);
+ }
+
+ BOOL RectInRegion(LPCRECT lpRect) const
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return ::RectInRegion(m_hRgn, lpRect);
+ }
+
+ int GetRegionData(LPRGNDATA lpRgnData, int nDataSize) const
+ {
+ ATLASSERT(m_hRgn != NULL);
+ return (int)::GetRegionData(m_hRgn, nDataSize, lpRgnData);
+ }
+};
+
+typedef CRgnT<false> CRgnHandle;
+typedef CRgnT<true> CRgn;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CDC - The device context class
+
+template <bool t_bManaged>
+class CDCT;
+typedef CDCT<false> CDCHandle;
+typedef CDCT<true> CDC;
+
+template <bool t_bManaged>
+class CDCT
+{
+public:
+// Data members
+ HDC m_hDC;
+
+// Constructor/destructor/operators
+ CDCT(HDC hDC = NULL) : m_hDC(hDC)
+ {
+ }
+
+ ~CDCT()
+ {
+ if(t_bManaged && (m_hDC != NULL))
+ ::DeleteDC(Detach());
+ }
+
+ CDCT<t_bManaged>& operator =(HDC hDC)
+ {
+ Attach(hDC);
+ return *this;
+ }
+
+ void Attach(HDC hDC)
+ {
+ if(t_bManaged && (m_hDC != NULL) && (m_hDC != hDC))
+ ::DeleteDC(m_hDC);
+ m_hDC = hDC;
+ }
+
+ HDC Detach()
+ {
+ HDC hDC = m_hDC;
+ m_hDC = NULL;
+ return hDC;
+ }
+
+ operator HDC() const { return m_hDC; }
+
+ bool IsNull() const { return (m_hDC == NULL); }
+
+// Operations
+ HWND WindowFromDC() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::WindowFromDC(m_hDC);
+ }
+
+ CPenHandle GetCurrentPen() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return CPenHandle((HPEN)::GetCurrentObject(m_hDC, OBJ_PEN));
+ }
+
+ CBrushHandle GetCurrentBrush() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return CBrushHandle((HBRUSH)::GetCurrentObject(m_hDC, OBJ_BRUSH));
+ }
+
+ CPaletteHandle GetCurrentPalette() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return CPaletteHandle((HPALETTE)::GetCurrentObject(m_hDC, OBJ_PAL));
+ }
+
+ CFontHandle GetCurrentFont() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return CFontHandle((HFONT)::GetCurrentObject(m_hDC, OBJ_FONT));
+ }
+
+ CBitmapHandle GetCurrentBitmap() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return CBitmapHandle((HBITMAP)::GetCurrentObject(m_hDC, OBJ_BITMAP));
+ }
+
+ HDC CreateDC(LPCTSTR lpszDriverName, LPCTSTR lpszDeviceName, LPCTSTR lpszOutput, const DEVMODE* lpInitData)
+ {
+ ATLASSERT(m_hDC == NULL);
+ m_hDC = ::CreateDC(lpszDriverName, lpszDeviceName, lpszOutput, lpInitData);
+ return m_hDC;
+ }
+
+ HDC CreateCompatibleDC(HDC hDC = NULL)
+ {
+ ATLASSERT(m_hDC == NULL);
+ m_hDC = ::CreateCompatibleDC(hDC);
+ return m_hDC;
+ }
+
+ BOOL DeleteDC()
+ {
+ if(m_hDC == NULL)
+ return FALSE;
+ BOOL bRet = ::DeleteDC(m_hDC);
+ if(bRet)
+ m_hDC = NULL;
+ return bRet;
+ }
+
+// Device-Context Functions
+ int SaveDC()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SaveDC(m_hDC);
+ }
+
+ BOOL RestoreDC(int nSavedDC)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::RestoreDC(m_hDC, nSavedDC);
+ }
+
+ int GetDeviceCaps(int nIndex) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetDeviceCaps(m_hDC, nIndex);
+ }
+
+ UINT SetBoundsRect(LPCRECT lpRectBounds, UINT flags)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetBoundsRect(m_hDC, lpRectBounds, flags);
+ }
+
+ UINT GetBoundsRect(LPRECT lpRectBounds, UINT flags) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetBoundsRect(m_hDC, lpRectBounds, flags);
+ }
+
+ BOOL ResetDC(const DEVMODE* lpDevMode)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ResetDC(m_hDC, lpDevMode) != NULL;
+ }
+
+// Drawing-Tool Functions
+ BOOL GetBrushOrg(LPPOINT lpPoint) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetBrushOrgEx(m_hDC, lpPoint);
+ }
+
+ BOOL SetBrushOrg(int x, int y, LPPOINT lpPoint = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetBrushOrgEx(m_hDC, x, y, lpPoint);
+ }
+
+ BOOL SetBrushOrg(POINT point, LPPOINT lpPointRet = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetBrushOrgEx(m_hDC, point.x, point.y, lpPointRet);
+ }
+
+ int EnumObjects(int nObjectType, int (CALLBACK* lpfn)(LPVOID, LPARAM), LPARAM lpData)
+ {
+ ATLASSERT(m_hDC != NULL);
+#ifdef STRICT
+ return ::EnumObjects(m_hDC, nObjectType, (GOBJENUMPROC)lpfn, lpData);
+#else
+ return ::EnumObjects(m_hDC, nObjectType, (GOBJENUMPROC)lpfn, (LPVOID)lpData);
+#endif
+ }
+
+// Type-safe selection helpers
+ HPEN SelectPen(HPEN hPen)
+ {
+ ATLASSERT(m_hDC != NULL);
+ ATLASSERT((hPen == NULL) || (::GetObjectType(hPen) == OBJ_PEN) || (::GetObjectType(hPen) == OBJ_EXTPEN));
+ return (HPEN)::SelectObject(m_hDC, hPen);
+ }
+
+ HBRUSH SelectBrush(HBRUSH hBrush)
+ {
+ ATLASSERT(m_hDC != NULL);
+ ATLASSERT((hBrush == NULL) || (::GetObjectType(hBrush) == OBJ_BRUSH));
+ return (HBRUSH)::SelectObject(m_hDC, hBrush);
+ }
+
+ HFONT SelectFont(HFONT hFont)
+ {
+ ATLASSERT(m_hDC != NULL);
+ ATLASSERT((hFont == NULL) || (::GetObjectType(hFont) == OBJ_FONT));
+ return (HFONT)::SelectObject(m_hDC, hFont);
+ }
+
+ HBITMAP SelectBitmap(HBITMAP hBitmap)
+ {
+ ATLASSERT(m_hDC != NULL);
+ ATLASSERT((hBitmap == NULL) || (::GetObjectType(hBitmap) == OBJ_BITMAP));
+ return (HBITMAP)::SelectObject(m_hDC, hBitmap);
+ }
+
+ int SelectRgn(HRGN hRgn) // special return for regions
+ {
+ ATLASSERT(m_hDC != NULL);
+ ATLASSERT((hRgn == NULL) || (::GetObjectType(hRgn) == OBJ_REGION));
+ return PtrToInt(::SelectObject(m_hDC, hRgn));
+ }
+
+// Type-safe selection helpers for stock objects
+ HPEN SelectStockPen(int nPen)
+ {
+ ATLASSERT(m_hDC != NULL);
+ ATLASSERT((nPen == WHITE_PEN) || (nPen == BLACK_PEN) || (nPen == NULL_PEN) || (nPen == DC_PEN));
+ return SelectPen((HPEN)::GetStockObject(nPen));
+ }
+
+ HBRUSH SelectStockBrush(int nBrush)
+ {
+ ATLASSERT(((nBrush >= WHITE_BRUSH) && (nBrush <= HOLLOW_BRUSH)) || (nBrush == DC_BRUSH));
+ return SelectBrush((HBRUSH)::GetStockObject(nBrush));
+ }
+
+ HFONT SelectStockFont(int nFont)
+ {
+ ATLASSERT(((nFont >= OEM_FIXED_FONT) && (nFont <= SYSTEM_FIXED_FONT)) || (nFont == DEFAULT_GUI_FONT));
+ return SelectFont((HFONT)::GetStockObject(nFont));
+ }
+
+ HPALETTE SelectStockPalette(int nPalette, BOOL bForceBackground)
+ {
+ ATLASSERT(nPalette == DEFAULT_PALETTE); // the only one supported
+ return SelectPalette((HPALETTE)::GetStockObject(nPalette), bForceBackground);
+ }
+
+// Color and Color Palette Functions
+ COLORREF GetNearestColor(COLORREF crColor) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetNearestColor(m_hDC, crColor);
+ }
+
+ HPALETTE SelectPalette(HPALETTE hPalette, BOOL bForceBackground)
+ {
+ ATLASSERT(m_hDC != NULL);
+
+ return ::SelectPalette(m_hDC, hPalette, bForceBackground);
+ }
+
+ UINT RealizePalette()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::RealizePalette(m_hDC);
+ }
+
+ void UpdateColors()
+ {
+ ATLASSERT(m_hDC != NULL);
+ ::UpdateColors(m_hDC);
+ }
+
+// Drawing-Attribute Functions
+ COLORREF GetBkColor() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetBkColor(m_hDC);
+ }
+
+ int GetBkMode() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetBkMode(m_hDC);
+ }
+
+ int GetPolyFillMode() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetPolyFillMode(m_hDC);
+ }
+
+ int GetROP2() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetROP2(m_hDC);
+ }
+
+ int GetStretchBltMode() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetStretchBltMode(m_hDC);
+ }
+
+ COLORREF GetTextColor() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetTextColor(m_hDC);
+ }
+
+ COLORREF SetBkColor(COLORREF crColor)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetBkColor(m_hDC, crColor);
+ }
+
+ int SetBkMode(int nBkMode)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetBkMode(m_hDC, nBkMode);
+ }
+
+ int SetPolyFillMode(int nPolyFillMode)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetPolyFillMode(m_hDC, nPolyFillMode);
+ }
+
+ int SetROP2(int nDrawMode)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetROP2(m_hDC, nDrawMode);
+ }
+
+ int SetStretchBltMode(int nStretchMode)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetStretchBltMode(m_hDC, nStretchMode);
+ }
+
+ COLORREF SetTextColor(COLORREF crColor)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetTextColor(m_hDC, crColor);
+ }
+
+ BOOL GetColorAdjustment(LPCOLORADJUSTMENT lpColorAdjust) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetColorAdjustment(m_hDC, lpColorAdjust);
+ }
+
+ BOOL SetColorAdjustment(const COLORADJUSTMENT* lpColorAdjust)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetColorAdjustment(m_hDC, lpColorAdjust);
+ }
+
+// Mapping Functions
+ int GetMapMode() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetMapMode(m_hDC);
+ }
+
+ BOOL GetViewportOrg(LPPOINT lpPoint) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetViewportOrgEx(m_hDC, lpPoint);
+ }
+
+ int SetMapMode(int nMapMode)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetMapMode(m_hDC, nMapMode);
+ }
+
+ // Viewport Origin
+ BOOL SetViewportOrg(int x, int y, LPPOINT lpPoint = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetViewportOrgEx(m_hDC, x, y, lpPoint);
+ }
+
+ BOOL SetViewportOrg(POINT point, LPPOINT lpPointRet = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return SetViewportOrg(point.x, point.y, lpPointRet);
+ }
+
+ BOOL OffsetViewportOrg(int nWidth, int nHeight, LPPOINT lpPoint = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::OffsetViewportOrgEx(m_hDC, nWidth, nHeight, lpPoint);
+ }
+
+ // Viewport Extent
+ BOOL GetViewportExt(LPSIZE lpSize) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetViewportExtEx(m_hDC, lpSize);
+ }
+
+ BOOL SetViewportExt(int x, int y, LPSIZE lpSize = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetViewportExtEx(m_hDC, x, y, lpSize);
+ }
+
+ BOOL SetViewportExt(SIZE size, LPSIZE lpSizeRet = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return SetViewportExt(size.cx, size.cy, lpSizeRet);
+ }
+
+ BOOL ScaleViewportExt(int xNum, int xDenom, int yNum, int yDenom, LPSIZE lpSize = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ScaleViewportExtEx(m_hDC, xNum, xDenom, yNum, yDenom, lpSize);
+ }
+
+ // Window Origin
+ BOOL GetWindowOrg(LPPOINT lpPoint) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetWindowOrgEx(m_hDC, lpPoint);
+ }
+
+ BOOL SetWindowOrg(int x, int y, LPPOINT lpPoint = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetWindowOrgEx(m_hDC, x, y, lpPoint);
+ }
+
+ BOOL SetWindowOrg(POINT point, LPPOINT lpPointRet = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return SetWindowOrg(point.x, point.y, lpPointRet);
+ }
+
+ BOOL OffsetWindowOrg(int nWidth, int nHeight, LPPOINT lpPoint = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::OffsetWindowOrgEx(m_hDC, nWidth, nHeight, lpPoint);
+ }
+
+ // Window extent
+ BOOL GetWindowExt(LPSIZE lpSize) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetWindowExtEx(m_hDC, lpSize);
+ }
+
+ BOOL SetWindowExt(int x, int y, LPSIZE lpSize = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetWindowExtEx(m_hDC, x, y, lpSize);
+ }
+
+ BOOL SetWindowExt(SIZE size, LPSIZE lpSizeRet = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return SetWindowExt(size.cx, size.cy, lpSizeRet);
+ }
+
+ BOOL ScaleWindowExt(int xNum, int xDenom, int yNum, int yDenom, LPSIZE lpSize = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ScaleWindowExtEx(m_hDC, xNum, xDenom, yNum, yDenom, lpSize);
+ }
+
+// Coordinate Functions
+ BOOL DPtoLP(LPPOINT lpPoints, int nCount = 1) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DPtoLP(m_hDC, lpPoints, nCount);
+ }
+
+ BOOL DPtoLP(LPRECT lpRect) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DPtoLP(m_hDC, (LPPOINT)lpRect, 2);
+ }
+
+ BOOL DPtoLP(LPSIZE lpSize) const
+ {
+ SIZE sizeWinExt = {};
+ if(!GetWindowExt(&sizeWinExt))
+ return FALSE;
+ SIZE sizeVpExt = {};
+ if(!GetViewportExt(&sizeVpExt))
+ return FALSE;
+ lpSize->cx = ::MulDiv(lpSize->cx, abs(sizeWinExt.cx), abs(sizeVpExt.cx));
+ lpSize->cy = ::MulDiv(lpSize->cy, abs(sizeWinExt.cy), abs(sizeVpExt.cy));
+ return TRUE;
+ }
+
+ BOOL LPtoDP(LPPOINT lpPoints, int nCount = 1) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::LPtoDP(m_hDC, lpPoints, nCount);
+ }
+
+ BOOL LPtoDP(LPRECT lpRect) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::LPtoDP(m_hDC, (LPPOINT)lpRect, 2);
+ }
+
+ BOOL LPtoDP(LPSIZE lpSize) const
+ {
+ SIZE sizeWinExt = {};
+ if(!GetWindowExt(&sizeWinExt))
+ return FALSE;
+ SIZE sizeVpExt = {};
+ if(!GetViewportExt(&sizeVpExt))
+ return FALSE;
+ lpSize->cx = ::MulDiv(lpSize->cx, abs(sizeVpExt.cx), abs(sizeWinExt.cx));
+ lpSize->cy = ::MulDiv(lpSize->cy, abs(sizeVpExt.cy), abs(sizeWinExt.cy));
+ return TRUE;
+ }
+
+// Special Coordinate Functions (useful for dealing with metafiles and OLE)
+ #define HIMETRIC_INCH 2540 // HIMETRIC units per inch
+
+ void DPtoHIMETRIC(LPSIZE lpSize)
+ {
+ ATLASSERT(m_hDC != NULL);
+ int nMapMode = GetMapMode();
+ if((nMapMode < MM_ISOTROPIC) && (nMapMode != MM_TEXT))
+ {
+ // when using a constrained map mode, map against physical inch
+ SetMapMode(MM_HIMETRIC);
+ DPtoLP(lpSize);
+ SetMapMode(nMapMode);
+ }
+ else
+ {
+ // map against logical inch for non-constrained mapping modes
+ int cxPerInch = GetDeviceCaps(LOGPIXELSX);
+ int cyPerInch = GetDeviceCaps(LOGPIXELSY);
+ ATLASSERT((cxPerInch != 0) && (cyPerInch != 0));
+ lpSize->cx = ::MulDiv(lpSize->cx, HIMETRIC_INCH, cxPerInch);
+ lpSize->cy = ::MulDiv(lpSize->cy, HIMETRIC_INCH, cyPerInch);
+ }
+ }
+
+ void HIMETRICtoDP(LPSIZE lpSize)
+ {
+ ATLASSERT(m_hDC != NULL);
+ int nMapMode = GetMapMode();
+ if((nMapMode < MM_ISOTROPIC) && (nMapMode != MM_TEXT))
+ {
+ // when using a constrained map mode, map against physical inch
+ SetMapMode(MM_HIMETRIC);
+ LPtoDP(lpSize);
+ SetMapMode(nMapMode);
+ }
+ else
+ {
+ // map against logical inch for non-constrained mapping modes
+ int cxPerInch = GetDeviceCaps(LOGPIXELSX);
+ int cyPerInch = GetDeviceCaps(LOGPIXELSY);
+ ATLASSERT((cxPerInch != 0) && (cyPerInch != 0));
+ lpSize->cx = ::MulDiv(lpSize->cx, cxPerInch, HIMETRIC_INCH);
+ lpSize->cy = ::MulDiv(lpSize->cy, cyPerInch, HIMETRIC_INCH);
+ }
+ }
+
+ void LPtoHIMETRIC(LPSIZE lpSize)
+ {
+ LPtoDP(lpSize);
+ DPtoHIMETRIC(lpSize);
+ }
+
+ void HIMETRICtoLP(LPSIZE lpSize)
+ {
+ HIMETRICtoDP(lpSize);
+ DPtoLP(lpSize);
+ }
+
+// Region Functions
+ BOOL FillRgn(HRGN hRgn, HBRUSH hBrush)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::FillRgn(m_hDC, hRgn, hBrush);
+ }
+
+ BOOL FrameRgn(HRGN hRgn, HBRUSH hBrush, int nWidth, int nHeight)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::FrameRgn(m_hDC, hRgn, hBrush, nWidth, nHeight);
+ }
+
+ BOOL InvertRgn(HRGN hRgn)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::InvertRgn(m_hDC, hRgn);
+ }
+
+ BOOL PaintRgn(HRGN hRgn)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PaintRgn(m_hDC, hRgn);
+ }
+
+// Clipping Functions
+ int GetClipBox(LPRECT lpRect) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetClipBox(m_hDC, lpRect);
+ }
+
+ int GetClipRgn(CRgn& region) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ if(region.IsNull())
+ region.CreateRectRgn(0, 0, 0, 0);
+
+ int nRet = ::GetClipRgn(m_hDC, region);
+ if(nRet != 1)
+ region.DeleteObject();
+
+ return nRet;
+ }
+
+ BOOL PtVisible(int x, int y) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PtVisible(m_hDC, x, y);
+ }
+
+ BOOL PtVisible(POINT point) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PtVisible(m_hDC, point.x, point.y);
+ }
+
+ BOOL RectVisible(LPCRECT lpRect) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::RectVisible(m_hDC, lpRect);
+ }
+
+ int SelectClipRgn(HRGN hRgn)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SelectClipRgn(m_hDC, (HRGN)hRgn);
+ }
+
+ int ExcludeClipRect(int x1, int y1, int x2, int y2)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ExcludeClipRect(m_hDC, x1, y1, x2, y2);
+ }
+
+ int ExcludeClipRect(LPCRECT lpRect)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ExcludeClipRect(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom);
+ }
+
+ int ExcludeUpdateRgn(HWND hWnd)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ExcludeUpdateRgn(m_hDC, hWnd);
+ }
+
+ int IntersectClipRect(int x1, int y1, int x2, int y2)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::IntersectClipRect(m_hDC, x1, y1, x2, y2);
+ }
+
+ int IntersectClipRect(LPCRECT lpRect)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::IntersectClipRect(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom);
+ }
+
+ int OffsetClipRgn(int x, int y)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::OffsetClipRgn(m_hDC, x, y);
+ }
+
+ int OffsetClipRgn(SIZE size)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::OffsetClipRgn(m_hDC, size.cx, size.cy);
+ }
+
+ int SelectClipRgn(HRGN hRgn, int nMode)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ExtSelectClipRgn(m_hDC, hRgn, nMode);
+ }
+
+// Line-Output Functions
+ BOOL GetCurrentPosition(LPPOINT lpPoint) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetCurrentPositionEx(m_hDC, lpPoint);
+ }
+
+ BOOL MoveTo(int x, int y, LPPOINT lpPoint = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::MoveToEx(m_hDC, x, y, lpPoint);
+ }
+
+ BOOL MoveTo(POINT point, LPPOINT lpPointRet = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return MoveTo(point.x, point.y, lpPointRet);
+ }
+
+ BOOL LineTo(int x, int y)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::LineTo(m_hDC, x, y);
+ }
+
+ BOOL LineTo(POINT point)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return LineTo(point.x, point.y);
+ }
+
+ BOOL Arc(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Arc(m_hDC, x1, y1, x2, y2, x3, y3, x4, y4);
+ }
+
+ BOOL Arc(LPCRECT lpRect, POINT ptStart, POINT ptEnd)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Arc(m_hDC, lpRect->left, lpRect->top,
+ lpRect->right, lpRect->bottom, ptStart.x, ptStart.y,
+ ptEnd.x, ptEnd.y);
+ }
+
+ BOOL Polyline(const POINT* lpPoints, int nCount)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Polyline(m_hDC, lpPoints, nCount);
+ }
+
+ BOOL AngleArc(int x, int y, int nRadius, float fStartAngle, float fSweepAngle)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::AngleArc(m_hDC, x, y, nRadius, fStartAngle, fSweepAngle);
+ }
+
+ BOOL ArcTo(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ArcTo(m_hDC, x1, y1, x2, y2, x3, y3, x4, y4);
+ }
+
+ BOOL ArcTo(LPCRECT lpRect, POINT ptStart, POINT ptEnd)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ArcTo(lpRect->left, lpRect->top, lpRect->right,
+ lpRect->bottom, ptStart.x, ptStart.y, ptEnd.x, ptEnd.y);
+ }
+
+ int GetArcDirection() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetArcDirection(m_hDC);
+ }
+
+ int SetArcDirection(int nArcDirection)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetArcDirection(m_hDC, nArcDirection);
+ }
+
+ BOOL PolyDraw(const POINT* lpPoints, const BYTE* lpTypes, int nCount)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PolyDraw(m_hDC, lpPoints, lpTypes, nCount);
+ }
+
+ BOOL PolylineTo(const POINT* lpPoints, int nCount)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PolylineTo(m_hDC, lpPoints, nCount);
+ }
+
+ BOOL PolyPolyline(const POINT* lpPoints,
+ const DWORD* lpPolyPoints, int nCount)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PolyPolyline(m_hDC, lpPoints, lpPolyPoints, nCount);
+ }
+
+ BOOL PolyBezier(const POINT* lpPoints, int nCount)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PolyBezier(m_hDC, lpPoints, nCount);
+ }
+
+ BOOL PolyBezierTo(const POINT* lpPoints, int nCount)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PolyBezierTo(m_hDC, lpPoints, nCount);
+ }
+
+// Simple Drawing Functions
+ BOOL FillRect(LPCRECT lpRect, HBRUSH hBrush)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::FillRect(m_hDC, lpRect, hBrush);
+ }
+
+ BOOL FillRect(LPCRECT lpRect, int nColorIndex)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::FillRect(m_hDC, lpRect, (HBRUSH)LongToPtr(nColorIndex + 1));
+ }
+
+ BOOL FrameRect(LPCRECT lpRect, HBRUSH hBrush)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::FrameRect(m_hDC, lpRect, hBrush);
+ }
+
+ BOOL InvertRect(LPCRECT lpRect)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::InvertRect(m_hDC, lpRect);
+ }
+
+ BOOL DrawIcon(int x, int y, HICON hIcon)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawIcon(m_hDC, x, y, hIcon);
+ }
+
+ BOOL DrawIcon(POINT point, HICON hIcon)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawIcon(m_hDC, point.x, point.y, hIcon);
+ }
+
+ BOOL DrawIconEx(int x, int y, HICON hIcon, int cxWidth, int cyWidth, UINT uStepIfAniCur = 0, HBRUSH hbrFlickerFreeDraw = NULL, UINT uFlags = DI_NORMAL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawIconEx(m_hDC, x, y, hIcon, cxWidth, cyWidth, uStepIfAniCur, hbrFlickerFreeDraw, uFlags);
+ }
+
+ BOOL DrawIconEx(POINT point, HICON hIcon, SIZE size, UINT uStepIfAniCur = 0, HBRUSH hbrFlickerFreeDraw = NULL, UINT uFlags = DI_NORMAL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawIconEx(m_hDC, point.x, point.y, hIcon, size.cx, size.cy, uStepIfAniCur, hbrFlickerFreeDraw, uFlags);
+ }
+
+ BOOL DrawState(POINT pt, SIZE size, HBITMAP hBitmap, UINT nFlags, HBRUSH hBrush = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawState(m_hDC, hBrush, NULL, (LPARAM)hBitmap, 0, pt.x, pt.y, size.cx, size.cy, nFlags | DST_BITMAP);
+ }
+
+ BOOL DrawState(POINT pt, SIZE size, HICON hIcon, UINT nFlags, HBRUSH hBrush = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawState(m_hDC, hBrush, NULL, (LPARAM)hIcon, 0, pt.x, pt.y, size.cx, size.cy, nFlags | DST_ICON);
+ }
+
+ BOOL DrawState(POINT pt, SIZE size, LPCTSTR lpszText, UINT nFlags, BOOL bPrefixText = TRUE, int nTextLen = 0, HBRUSH hBrush = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawState(m_hDC, hBrush, NULL, (LPARAM)lpszText, (WPARAM)nTextLen, pt.x, pt.y, size.cx, size.cy, nFlags | (bPrefixText ? DST_PREFIXTEXT : DST_TEXT));
+ }
+
+ BOOL DrawState(POINT pt, SIZE size, DRAWSTATEPROC lpDrawProc, LPARAM lData, UINT nFlags, HBRUSH hBrush = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawState(m_hDC, hBrush, lpDrawProc, lData, 0, pt.x, pt.y, size.cx, size.cy, nFlags | DST_COMPLEX);
+ }
+
+// Ellipse and Polygon Functions
+ BOOL Chord(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Chord(m_hDC, x1, y1, x2, y2, x3, y3, x4, y4);
+ }
+
+ BOOL Chord(LPCRECT lpRect, POINT ptStart, POINT ptEnd)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Chord(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom, ptStart.x, ptStart.y, ptEnd.x, ptEnd.y);
+ }
+
+ void DrawFocusRect(LPCRECT lpRect)
+ {
+ ATLASSERT(m_hDC != NULL);
+ ::DrawFocusRect(m_hDC, lpRect);
+ }
+
+ BOOL Ellipse(int x1, int y1, int x2, int y2)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Ellipse(m_hDC, x1, y1, x2, y2);
+ }
+
+ BOOL Ellipse(LPCRECT lpRect)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Ellipse(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom);
+ }
+
+ BOOL Pie(int x1, int y1, int x2, int y2, int x3, int y3, int x4, int y4)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Pie(m_hDC, x1, y1, x2, y2, x3, y3, x4, y4);
+ }
+
+ BOOL Pie(LPCRECT lpRect, POINT ptStart, POINT ptEnd)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Pie(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom, ptStart.x, ptStart.y, ptEnd.x, ptEnd.y);
+ }
+
+ BOOL Polygon(const POINT* lpPoints, int nCount)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Polygon(m_hDC, lpPoints, nCount);
+ }
+
+ BOOL PolyPolygon(const POINT* lpPoints, const INT* lpPolyCounts, int nCount)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PolyPolygon(m_hDC, lpPoints, lpPolyCounts, nCount);
+ }
+
+ BOOL Rectangle(int x1, int y1, int x2, int y2)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Rectangle(m_hDC, x1, y1, x2, y2);
+ }
+
+ BOOL Rectangle(LPCRECT lpRect)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Rectangle(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom);
+ }
+
+ BOOL RoundRect(int x1, int y1, int x2, int y2, int x3, int y3)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::RoundRect(m_hDC, x1, y1, x2, y2, x3, y3);
+ }
+
+ BOOL RoundRect(LPCRECT lpRect, POINT point)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::RoundRect(m_hDC, lpRect->left, lpRect->top, lpRect->right, lpRect->bottom, point.x, point.y);
+ }
+
+// Bitmap Functions
+ BOOL PatBlt(int x, int y, int nWidth, int nHeight, DWORD dwRop)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PatBlt(m_hDC, x, y, nWidth, nHeight, dwRop);
+ }
+
+ BOOL BitBlt(int x, int y, int nWidth, int nHeight, HDC hSrcDC,
+ int xSrc, int ySrc, DWORD dwRop)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::BitBlt(m_hDC, x, y, nWidth, nHeight, hSrcDC, xSrc, ySrc, dwRop);
+ }
+
+ BOOL StretchBlt(int x, int y, int nWidth, int nHeight, HDC hSrcDC, int xSrc, int ySrc, int nSrcWidth, int nSrcHeight, DWORD dwRop)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::StretchBlt(m_hDC, x, y, nWidth, nHeight, hSrcDC, xSrc, ySrc, nSrcWidth, nSrcHeight, dwRop);
+ }
+
+ COLORREF GetPixel(int x, int y) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetPixel(m_hDC, x, y);
+ }
+
+ COLORREF GetPixel(POINT point) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetPixel(m_hDC, point.x, point.y);
+ }
+
+ COLORREF SetPixel(int x, int y, COLORREF crColor)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetPixel(m_hDC, x, y, crColor);
+ }
+
+ COLORREF SetPixel(POINT point, COLORREF crColor)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetPixel(m_hDC, point.x, point.y, crColor);
+ }
+
+ BOOL FloodFill(int x, int y, COLORREF crColor)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::FloodFill(m_hDC, x, y, crColor);
+ }
+
+ BOOL ExtFloodFill(int x, int y, COLORREF crColor, UINT nFillType)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ExtFloodFill(m_hDC, x, y, crColor, nFillType);
+ }
+
+ BOOL MaskBlt(int x, int y, int nWidth, int nHeight, HDC hSrcDC, int xSrc, int ySrc, HBITMAP hMaskBitmap, int xMask, int yMask, DWORD dwRop)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::MaskBlt(m_hDC, x, y, nWidth, nHeight, hSrcDC, xSrc, ySrc, hMaskBitmap, xMask, yMask, dwRop);
+ }
+
+ BOOL PlgBlt(LPPOINT lpPoint, HDC hSrcDC, int xSrc, int ySrc, int nWidth, int nHeight, HBITMAP hMaskBitmap, int xMask, int yMask)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PlgBlt(m_hDC, lpPoint, hSrcDC, xSrc, ySrc, nWidth, nHeight, hMaskBitmap, xMask, yMask);
+ }
+
+ BOOL SetPixelV(int x, int y, COLORREF crColor)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetPixelV(m_hDC, x, y, crColor);
+ }
+
+ BOOL SetPixelV(POINT point, COLORREF crColor)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetPixelV(m_hDC, point.x, point.y, crColor);
+ }
+
+ BOOL TransparentBlt(int x, int y, int nWidth, int nHeight, HDC hSrcDC, int xSrc, int ySrc, int nSrcWidth, int nSrcHeight, UINT crTransparent)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::TransparentBlt(m_hDC, x, y, nWidth, nHeight, hSrcDC, xSrc, ySrc, nSrcWidth, nSrcHeight, crTransparent);
+ }
+
+ BOOL GradientFill(const PTRIVERTEX pVertices, DWORD nVertices, void* pMeshElements, DWORD nMeshElements, DWORD dwMode)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GradientFill(m_hDC, pVertices, nVertices, pMeshElements, nMeshElements, dwMode);
+ }
+
+ BOOL GradientFillRect(RECT& rect, COLORREF clr1, COLORREF clr2, bool bHorizontal)
+ {
+ ATLASSERT(m_hDC != NULL);
+
+ TRIVERTEX arrTvx[2] = { { 0 }, { 0 } };
+
+ arrTvx[0].x = rect.left;
+ arrTvx[0].y = rect.top;
+ arrTvx[0].Red = MAKEWORD(0, GetRValue(clr1));
+ arrTvx[0].Green = MAKEWORD(0, GetGValue(clr1));
+ arrTvx[0].Blue = MAKEWORD(0, GetBValue(clr1));
+ arrTvx[0].Alpha = 0;
+
+ arrTvx[1].x = rect.right;
+ arrTvx[1].y = rect.bottom;
+ arrTvx[1].Red = MAKEWORD(0, GetRValue(clr2));
+ arrTvx[1].Green = MAKEWORD(0, GetGValue(clr2));
+ arrTvx[1].Blue = MAKEWORD(0, GetBValue(clr2));
+ arrTvx[1].Alpha = 0;
+
+ GRADIENT_RECT gr = { 0, 1 };
+
+ return ::GradientFill(m_hDC, arrTvx, 2, &gr, 1, bHorizontal ? GRADIENT_FILL_RECT_H : GRADIENT_FILL_RECT_V);
+ }
+
+ BOOL AlphaBlend(int x, int y, int nWidth, int nHeight, HDC hSrcDC, int xSrc, int ySrc, int nSrcWidth, int nSrcHeight, BLENDFUNCTION bf)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::AlphaBlend(m_hDC, x, y, nWidth, nHeight, hSrcDC, xSrc, ySrc, nSrcWidth, nSrcHeight, bf);
+ }
+
+// Extra bitmap functions
+ // Helper function for painting a disabled toolbar or menu bitmap
+ // This function can take either an HBITMAP (for SS) or a DC with
+ // the bitmap already painted (for cmdbar)
+ BOOL DitherBlt(int x, int y, int nWidth, int nHeight, HDC hSrcDC, HBITMAP hBitmap, int xSrc, int ySrc,
+ HBRUSH hBrushBackground = ::GetSysColorBrush(COLOR_3DFACE),
+ HBRUSH hBrush3DEffect = ::GetSysColorBrush(COLOR_3DHILIGHT),
+ HBRUSH hBrushDisabledImage = ::GetSysColorBrush(COLOR_3DSHADOW))
+ {
+ ATLASSERT((m_hDC != NULL) || (hBitmap != NULL));
+ ATLASSERT((nWidth > 0) && (nHeight > 0));
+
+ // Create a generic DC for all BitBlts
+ CDCT<false> dc = (hSrcDC != NULL) ? hSrcDC : ::CreateCompatibleDC(m_hDC);
+ ATLASSERT(dc.m_hDC != NULL);
+ if(dc.m_hDC == NULL)
+ return FALSE;
+
+ // Create a DC for the monochrome DIB section
+ CDCT<true> dcBW = ::CreateCompatibleDC(m_hDC);
+ ATLASSERT(dcBW.m_hDC != NULL);
+ if(dcBW.m_hDC == NULL)
+ {
+ if(hSrcDC == NULL)
+ dc.DeleteDC();
+ return FALSE;
+ }
+
+ // Create the monochrome DIB section with a black and white palette
+ struct RGBBWBITMAPINFO
+ {
+ BITMAPINFOHEADER bmiHeader;
+ RGBQUAD bmiColors[2];
+ };
+
+ RGBBWBITMAPINFO rgbBWBitmapInfo =
+ {
+ { sizeof(BITMAPINFOHEADER), nWidth, nHeight, 1, 1, BI_RGB, 0, 0, 0, 0, 0 },
+ { { 0x00, 0x00, 0x00, 0x00 }, { 0xFF, 0xFF, 0xFF, 0x00 } }
+ };
+
+ VOID* pbitsBW;
+ CBitmap bmpBW = ::CreateDIBSection(dcBW, (LPBITMAPINFO)&rgbBWBitmapInfo, DIB_RGB_COLORS, &pbitsBW, NULL, 0);
+ ATLASSERT(bmpBW.m_hBitmap != NULL);
+ if(bmpBW.m_hBitmap == NULL)
+ {
+ if(hSrcDC == NULL)
+ dc.DeleteDC();
+ return FALSE;
+ }
+
+ // Attach the monochrome DIB section and the bitmap to the DCs
+ HBITMAP hbmOldBW = dcBW.SelectBitmap(bmpBW);
+ HBITMAP hbmOldDC = NULL;
+ if(hBitmap != NULL)
+ hbmOldDC = dc.SelectBitmap(hBitmap);
+
+ // Block: Dark gray removal: we want (128, 128, 128) pixels to become black and not white
+ {
+ CDCT<true> dcTemp1 = ::CreateCompatibleDC(m_hDC);
+ CDCT<true> dcTemp2 = ::CreateCompatibleDC(m_hDC);
+ CBitmap bmpTemp1;
+ bmpTemp1.CreateCompatibleBitmap(dc, nWidth, nHeight);
+ CBitmap bmpTemp2;
+ bmpTemp2.CreateBitmap(nWidth, nHeight, 1, 1, NULL);
+ HBITMAP hOldBmp1 = dcTemp1.SelectBitmap(bmpTemp1);
+ HBITMAP hOldBmp2 = dcTemp2.SelectBitmap(bmpTemp2);
+ // Let's copy our image, it will be altered
+ dcTemp1.BitBlt(0, 0, nWidth, nHeight, dc, xSrc, ySrc, SRCCOPY);
+
+ // All dark gray pixels will become white, the others black
+ dcTemp1.SetBkColor(RGB(128, 128, 128));
+ dcTemp2.BitBlt(0, 0, nWidth, nHeight, dcTemp1, 0, 0, SRCCOPY);
+ // Do an XOR to set to black these white pixels
+ dcTemp1.BitBlt(0, 0, nWidth, nHeight, dcTemp2, 0, 0, SRCINVERT);
+
+ // BitBlt the bitmap into the monochrome DIB section
+ // The DIB section will do a true monochrome conversion
+ // The magenta background being closer to white will become white
+ dcBW.BitBlt(0, 0, nWidth, nHeight, dcTemp1, 0, 0, SRCCOPY);
+
+ // Cleanup
+ dcTemp1.SelectBitmap(hOldBmp1);
+ dcTemp2.SelectBitmap(hOldBmp2);
+ }
+
+ // Paint the destination rectangle using hBrushBackground
+ if(hBrushBackground != NULL)
+ {
+ RECT rc = { x, y, x + nWidth, y + nHeight };
+ FillRect(&rc, hBrushBackground);
+ }
+
+ // BitBlt the black bits in the monochrome bitmap into hBrush3DEffect color in the destination DC
+ // The magic ROP comes from the Charles Petzold's book
+ HBRUSH hOldBrush = SelectBrush(hBrush3DEffect);
+ BitBlt(x + 1, y + 1, nWidth, nHeight, dcBW, 0, 0, 0xB8074A);
+
+ // BitBlt the black bits in the monochrome bitmap into hBrushDisabledImage color in the destination DC
+ SelectBrush(hBrushDisabledImage);
+ BitBlt(x, y, nWidth, nHeight, dcBW, 0, 0, 0xB8074A);
+
+ SelectBrush(hOldBrush);
+ dcBW.SelectBitmap(hbmOldBW);
+ dc.SelectBitmap(hbmOldDC);
+
+ if(hSrcDC == NULL)
+ dc.DeleteDC();
+
+ return TRUE;
+ }
+
+// Text Functions
+ BOOL TextOut(int x, int y, LPCTSTR lpszString, int nCount = -1)
+ {
+ ATLASSERT(m_hDC != NULL);
+ if(nCount == -1)
+ nCount = lstrlen(lpszString);
+ return ::TextOut(m_hDC, x, y, lpszString, nCount);
+ }
+
+ BOOL ExtTextOut(int x, int y, UINT nOptions, LPCRECT lpRect, LPCTSTR lpszString, int nCount = -1, LPINT lpDxWidths = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ if(nCount == -1)
+ nCount = lstrlen(lpszString);
+ ATLASSERT((nCount >= 0) && (nCount <= 8192));
+ return ::ExtTextOut(m_hDC, x, y, nOptions, lpRect, lpszString, (UINT)nCount, lpDxWidths);
+ }
+
+ SIZE TabbedTextOut(int x, int y, LPCTSTR lpszString, int nCount = -1, int nTabPositions = 0, LPINT lpnTabStopPositions = NULL, int nTabOrigin = 0)
+ {
+ ATLASSERT(m_hDC != NULL);
+ if(nCount == -1)
+ nCount = lstrlen(lpszString);
+ LONG lRes = ::TabbedTextOut(m_hDC, x, y, lpszString, nCount, nTabPositions, lpnTabStopPositions, nTabOrigin);
+ SIZE size = { GET_X_LPARAM(lRes), GET_Y_LPARAM(lRes) };
+ return size;
+ }
+
+ int DrawText(LPCTSTR lpstrText, int cchText, LPRECT lpRect, UINT uFormat)
+ {
+ ATLASSERT(m_hDC != NULL);
+ ATLASSERT((uFormat & DT_MODIFYSTRING) == 0);
+ return ::DrawText(m_hDC, lpstrText, cchText, lpRect, uFormat);
+ }
+
+ int DrawText(LPTSTR lpstrText, int cchText, LPRECT lpRect, UINT uFormat)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawText(m_hDC, lpstrText, cchText, lpRect, uFormat);
+ }
+
+ int DrawTextEx(LPTSTR lpstrText, int cchText, LPRECT lpRect, UINT uFormat, LPDRAWTEXTPARAMS lpDTParams = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawTextEx(m_hDC, lpstrText, cchText, lpRect, uFormat, lpDTParams);
+ }
+
+ // Note - ::DrawShadowText() is present only if comctl32.dll version 6 is loaded
+ int DrawShadowText(LPCWSTR lpstrText, int cchText, LPRECT lpRect, DWORD dwFlags, COLORREF clrText, COLORREF clrShadow, int xOffset, int yOffset)
+ {
+ ATLASSERT(m_hDC != NULL);
+ ATLASSERT(lpRect != NULL);
+ return ::DrawShadowText(m_hDC, lpstrText, cchText, lpRect, dwFlags, clrText, clrShadow, xOffset, yOffset);
+ }
+
+ BOOL GetTextExtent(LPCTSTR lpszString, int nCount, LPSIZE lpSize) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ if(nCount == -1)
+ nCount = lstrlen(lpszString);
+ return ::GetTextExtentPoint32(m_hDC, lpszString, nCount, lpSize);
+ }
+
+ BOOL GetTextExtentExPoint(LPCTSTR lpszString, int cchString, LPSIZE lpSize, int nMaxExtent, LPINT lpnFit = NULL, LPINT alpDx = NULL)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetTextExtentExPoint(m_hDC, lpszString, cchString, nMaxExtent, lpnFit, alpDx, lpSize);
+ }
+
+ DWORD GetTabbedTextExtent(LPCTSTR lpszString, int nCount = -1, int nTabPositions = 0, LPINT lpnTabStopPositions = NULL) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ if(nCount == -1)
+ nCount = lstrlen(lpszString);
+ return ::GetTabbedTextExtent(m_hDC, lpszString, nCount, nTabPositions, lpnTabStopPositions);
+ }
+
+ BOOL GrayString(HBRUSH hBrush, BOOL (CALLBACK* lpfnOutput)(HDC, LPARAM, int), LPARAM lpData, int nCount, int x, int y, int nWidth, int nHeight)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GrayString(m_hDC, hBrush, (GRAYSTRINGPROC)lpfnOutput, lpData, nCount, x, y, nWidth, nHeight);
+ }
+
+ UINT GetTextAlign() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetTextAlign(m_hDC);
+ }
+
+ UINT SetTextAlign(UINT nFlags)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetTextAlign(m_hDC, nFlags);
+ }
+
+ int GetTextFace(LPTSTR lpszFacename, int nCount) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetTextFace(m_hDC, nCount, lpszFacename);
+ }
+
+ int GetTextFaceLen() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetTextFace(m_hDC, 0, NULL);
+ }
+
+#ifdef _OLEAUTO_H_
+ BOOL GetTextFace(BSTR& bstrFace) const
+ {
+ USES_CONVERSION;
+ ATLASSERT(m_hDC != NULL);
+ ATLASSERT(bstrFace == NULL);
+
+ int nLen = GetTextFaceLen();
+ if(nLen == 0)
+ return FALSE;
+
+ ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff;
+ LPTSTR lpszText = buff.Allocate(nLen);
+ if(lpszText == NULL)
+ return FALSE;
+
+ if(!GetTextFace(lpszText, nLen))
+ return FALSE;
+
+ bstrFace = ::SysAllocString(T2OLE(lpszText));
+ return (bstrFace != NULL) ? TRUE : FALSE;
+ }
+#endif
+
+#ifdef __ATLSTR_H__
+ int GetTextFace(ATL::CString& strFace) const
+ {
+ ATLASSERT(m_hDC != NULL);
+
+ int nLen = GetTextFaceLen();
+ if(nLen == 0)
+ return 0;
+
+ LPTSTR lpstr = strFace.GetBufferSetLength(nLen);
+ if(lpstr == NULL)
+ return 0;
+ int nRet = GetTextFace(lpstr, nLen);
+ strFace.ReleaseBuffer();
+ return nRet;
+ }
+#endif // __ATLSTR_H__
+
+ BOOL GetTextMetrics(LPTEXTMETRIC lpMetrics) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetTextMetrics(m_hDC, lpMetrics);
+ }
+
+ int SetTextJustification(int nBreakExtra, int nBreakCount)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetTextJustification(m_hDC, nBreakExtra, nBreakCount);
+ }
+
+ int GetTextCharacterExtra() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetTextCharacterExtra(m_hDC);
+ }
+
+ int SetTextCharacterExtra(int nCharExtra)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetTextCharacterExtra(m_hDC, nCharExtra);
+ }
+
+// Advanced Drawing
+ BOOL DrawEdge(LPRECT lpRect, UINT nEdge, UINT nFlags)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawEdge(m_hDC, lpRect, nEdge, nFlags);
+ }
+
+ BOOL DrawFrameControl(LPRECT lpRect, UINT nType, UINT nState)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawFrameControl(m_hDC, lpRect, nType, nState);
+ }
+
+// Scrolling Functions
+ BOOL ScrollDC(int dx, int dy, LPCRECT lpRectScroll, LPCRECT lpRectClip, HRGN hRgnUpdate, LPRECT lpRectUpdate)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ScrollDC(m_hDC, dx, dy, lpRectScroll, lpRectClip, hRgnUpdate, lpRectUpdate);
+ }
+
+// Font Functions
+ BOOL GetCharWidth(UINT nFirstChar, UINT nLastChar, LPINT lpBuffer) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetCharWidth(m_hDC, nFirstChar, nLastChar, lpBuffer);
+ }
+
+ // GetCharWidth32 is not supported under Win9x
+ BOOL GetCharWidth32(UINT nFirstChar, UINT nLastChar, LPINT lpBuffer) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetCharWidth32(m_hDC, nFirstChar, nLastChar, lpBuffer);
+ }
+
+ DWORD SetMapperFlags(DWORD dwFlag)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetMapperFlags(m_hDC, dwFlag);
+ }
+
+ BOOL GetAspectRatioFilter(LPSIZE lpSize) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetAspectRatioFilterEx(m_hDC, lpSize);
+ }
+
+ BOOL GetCharABCWidths(UINT nFirstChar, UINT nLastChar, LPABC lpabc) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetCharABCWidths(m_hDC, nFirstChar, nLastChar, lpabc);
+ }
+
+ DWORD GetFontData(DWORD dwTable, DWORD dwOffset, LPVOID lpData, DWORD cbData) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetFontData(m_hDC, dwTable, dwOffset, lpData, cbData);
+ }
+
+ int GetKerningPairs(int nPairs, LPKERNINGPAIR lpkrnpair) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetKerningPairs(m_hDC, nPairs, lpkrnpair);
+ }
+
+ UINT GetOutlineTextMetrics(UINT cbData, LPOUTLINETEXTMETRIC lpotm) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetOutlineTextMetrics(m_hDC, cbData, lpotm);
+ }
+
+ DWORD GetGlyphOutline(UINT nChar, UINT nFormat, LPGLYPHMETRICS lpgm, DWORD cbBuffer, LPVOID lpBuffer, const MAT2* lpmat2) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetGlyphOutline(m_hDC, nChar, nFormat, lpgm, cbBuffer, lpBuffer, lpmat2);
+ }
+
+ BOOL GetCharABCWidths(UINT nFirstChar, UINT nLastChar, LPABCFLOAT lpABCF) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetCharABCWidthsFloat(m_hDC, nFirstChar, nLastChar, lpABCF);
+ }
+
+ BOOL GetCharWidth(UINT nFirstChar, UINT nLastChar, float* lpFloatBuffer) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetCharWidthFloat(m_hDC, nFirstChar, nLastChar, lpFloatBuffer);
+ }
+
+// Printer/Device Escape Functions
+ int Escape(int nEscape, int nCount, LPCSTR lpszInData, LPVOID lpOutData)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::Escape(m_hDC, nEscape, nCount, lpszInData, lpOutData);
+ }
+
+ int Escape(int nEscape, int nInputSize, LPCSTR lpszInputData,
+ int nOutputSize, LPSTR lpszOutputData)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ExtEscape(m_hDC, nEscape, nInputSize, lpszInputData, nOutputSize, lpszOutputData);
+ }
+
+ int DrawEscape(int nEscape, int nInputSize, LPCSTR lpszInputData)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DrawEscape(m_hDC, nEscape, nInputSize, lpszInputData);
+ }
+
+ // Escape helpers
+ int StartDoc(LPCTSTR lpszDocName) // old Win3.0 version
+ {
+ DOCINFO di = {};
+ di.cbSize = sizeof(DOCINFO);
+ di.lpszDocName = lpszDocName;
+ return StartDoc(&di);
+ }
+
+ int StartDoc(LPDOCINFO lpDocInfo)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::StartDoc(m_hDC, lpDocInfo);
+ }
+
+ int StartPage()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::StartPage(m_hDC);
+ }
+
+ int EndPage()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::EndPage(m_hDC);
+ }
+
+ int SetAbortProc(BOOL (CALLBACK* lpfn)(HDC, int))
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetAbortProc(m_hDC, (ABORTPROC)lpfn);
+ }
+
+ int AbortDoc()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::AbortDoc(m_hDC);
+ }
+
+ int EndDoc()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::EndDoc(m_hDC);
+ }
+
+// MetaFile Functions
+ BOOL PlayMetaFile(HMETAFILE hMF)
+ {
+ ATLASSERT(m_hDC != NULL);
+ if(::GetDeviceCaps(m_hDC, TECHNOLOGY) == DT_METAFILE)
+ {
+ // playing metafile in metafile, just use core windows API
+ return ::PlayMetaFile(m_hDC, hMF);
+ }
+
+ // for special playback, lParam == pDC
+ return ::EnumMetaFile(m_hDC, hMF, EnumMetaFileProc, (LPARAM)this);
+ }
+
+ BOOL PlayMetaFile(HENHMETAFILE hEnhMetaFile, LPCRECT lpBounds)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::PlayEnhMetaFile(m_hDC, hEnhMetaFile, lpBounds);
+ }
+
+ BOOL AddMetaFileComment(UINT nDataSize, const BYTE* pCommentData) // can be used for enhanced metafiles only
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GdiComment(m_hDC, nDataSize, pCommentData);
+ }
+
+ // Special handling for metafile playback
+ static int CALLBACK EnumMetaFileProc(HDC hDC, HANDLETABLE* pHandleTable, METARECORD* pMetaRec, int nHandles, LPARAM lParam)
+ {
+ CDCT<false>* pDC = (CDCT<false>*)lParam;
+
+ switch (pMetaRec->rdFunction)
+ {
+ case META_SETMAPMODE:
+ pDC->SetMapMode((int)(short)pMetaRec->rdParm[0]);
+ break;
+ case META_SETWINDOWEXT:
+ pDC->SetWindowExt((int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]);
+ break;
+ case META_SETWINDOWORG:
+ pDC->SetWindowOrg((int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]);
+ break;
+ case META_SETVIEWPORTEXT:
+ pDC->SetViewportExt((int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]);
+ break;
+ case META_SETVIEWPORTORG:
+ pDC->SetViewportOrg((int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]);
+ break;
+ case META_SCALEWINDOWEXT:
+ pDC->ScaleWindowExt((int)(short)pMetaRec->rdParm[3], (int)(short)pMetaRec->rdParm[2],
+ (int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]);
+ break;
+ case META_SCALEVIEWPORTEXT:
+ pDC->ScaleViewportExt((int)(short)pMetaRec->rdParm[3], (int)(short)pMetaRec->rdParm[2],
+ (int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]);
+ break;
+ case META_OFFSETVIEWPORTORG:
+ pDC->OffsetViewportOrg((int)(short)pMetaRec->rdParm[1], (int)(short)pMetaRec->rdParm[0]);
+ break;
+ case META_SAVEDC:
+ pDC->SaveDC();
+ break;
+ case META_RESTOREDC:
+ pDC->RestoreDC((int)(short)pMetaRec->rdParm[0]);
+ break;
+ case META_SETBKCOLOR:
+ pDC->SetBkColor(*(UNALIGNED COLORREF*)&pMetaRec->rdParm[0]);
+ break;
+ case META_SETTEXTCOLOR:
+ pDC->SetTextColor(*(UNALIGNED COLORREF*)&pMetaRec->rdParm[0]);
+ break;
+
+ // need to watch out for SelectObject(HFONT), for custom font mapping
+ case META_SELECTOBJECT:
+ {
+ HGDIOBJ hObject = pHandleTable->objectHandle[pMetaRec->rdParm[0]];
+ UINT nObjType = ::GetObjectType(hObject);
+ if(nObjType == 0)
+ {
+ // object type is unknown, determine if it is a font
+ HFONT hStockFont = (HFONT)::GetStockObject(SYSTEM_FONT);
+ HFONT hFontOld = (HFONT)::SelectObject(pDC->m_hDC, hStockFont);
+ HGDIOBJ hObjOld = ::SelectObject(pDC->m_hDC, hObject);
+ if(hObjOld == hStockFont)
+ {
+ // got the stock object back, so must be selecting a font
+ pDC->SelectFont((HFONT)hObject);
+ break; // don't play the default record
+ }
+ else
+ {
+ // didn't get the stock object back, so restore everything
+ ::SelectObject(pDC->m_hDC, hFontOld);
+ ::SelectObject(pDC->m_hDC, hObjOld);
+ }
+ // and fall through to PlayMetaFileRecord...
+ }
+ else if(nObjType == OBJ_FONT)
+ {
+ // play back as CDCHandle::SelectFont(HFONT)
+ pDC->SelectFont((HFONT)hObject);
+ break; // don't play the default record
+ }
+ }
+ // fall through...
+
+ default:
+ ::PlayMetaFileRecord(hDC, pHandleTable, pMetaRec, nHandles);
+ break;
+ }
+
+ return 1;
+ }
+
+// Path Functions
+ BOOL AbortPath()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::AbortPath(m_hDC);
+ }
+
+ BOOL BeginPath()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::BeginPath(m_hDC);
+ }
+
+ BOOL CloseFigure()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::CloseFigure(m_hDC);
+ }
+
+ BOOL EndPath()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::EndPath(m_hDC);
+ }
+
+ BOOL FillPath()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::FillPath(m_hDC);
+ }
+
+ BOOL FlattenPath()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::FlattenPath(m_hDC);
+ }
+
+ BOOL StrokeAndFillPath()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::StrokeAndFillPath(m_hDC);
+ }
+
+ BOOL StrokePath()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::StrokePath(m_hDC);
+ }
+
+ BOOL WidenPath()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::WidenPath(m_hDC);
+ }
+
+ BOOL GetMiterLimit(PFLOAT pfMiterLimit) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetMiterLimit(m_hDC, pfMiterLimit);
+ }
+
+ BOOL SetMiterLimit(float fMiterLimit)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetMiterLimit(m_hDC, fMiterLimit, NULL);
+ }
+
+ int GetPath(LPPOINT lpPoints, LPBYTE lpTypes, int nCount) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetPath(m_hDC, lpPoints, lpTypes, nCount);
+ }
+
+ BOOL SelectClipPath(int nMode)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SelectClipPath(m_hDC, nMode);
+ }
+
+// Misc Helper Functions
+ static CBrushHandle PASCAL GetHalftoneBrush()
+ {
+ HBRUSH halftoneBrush = NULL;
+ WORD grayPattern[8] = {};
+ for(int i = 0; i < 8; i++)
+ grayPattern[i] = (WORD)(0x5555 << (i & 1));
+ HBITMAP grayBitmap = CreateBitmap(8, 8, 1, 1, &grayPattern);
+ if(grayBitmap != NULL)
+ {
+ halftoneBrush = ::CreatePatternBrush(grayBitmap);
+ DeleteObject(grayBitmap);
+ }
+ return CBrushHandle(halftoneBrush);
+ }
+
+ void DrawDragRect(LPCRECT lpRect, SIZE size, LPCRECT lpRectLast, SIZE sizeLast, HBRUSH hBrush = NULL, HBRUSH hBrushLast = NULL)
+ {
+ // first, determine the update region and select it
+ CRgn rgnOutside;
+ rgnOutside.CreateRectRgnIndirect(lpRect);
+ RECT rect = *lpRect;
+ ::InflateRect(&rect, -size.cx, -size.cy);
+ ::IntersectRect(&rect, &rect, lpRect);
+ CRgn rgnInside;
+ rgnInside.CreateRectRgnIndirect(&rect);
+ CRgn rgnNew;
+ rgnNew.CreateRectRgn(0, 0, 0, 0);
+ rgnNew.CombineRgn(rgnOutside, rgnInside, RGN_XOR);
+
+ HBRUSH hBrushOld = NULL;
+ CBrush brushHalftone;
+ if(hBrush == NULL)
+ brushHalftone = hBrush = CDCHandle::GetHalftoneBrush();
+ if(hBrushLast == NULL)
+ hBrushLast = hBrush;
+
+ CRgn rgnLast;
+ CRgn rgnUpdate;
+ if(lpRectLast != NULL)
+ {
+ // find difference between new region and old region
+ rgnLast.CreateRectRgn(0, 0, 0, 0);
+ rgnOutside.SetRectRgn(lpRectLast->left, lpRectLast->top, lpRectLast->right, lpRectLast->bottom);
+ rect = *lpRectLast;
+ ::InflateRect(&rect, -sizeLast.cx, -sizeLast.cy);
+ ::IntersectRect(&rect, &rect, lpRectLast);
+ rgnInside.SetRectRgn(rect.left, rect.top, rect.right, rect.bottom);
+ rgnLast.CombineRgn(rgnOutside, rgnInside, RGN_XOR);
+
+ // only diff them if brushes are the same
+ if(hBrush == hBrushLast)
+ {
+ rgnUpdate.CreateRectRgn(0, 0, 0, 0);
+ rgnUpdate.CombineRgn(rgnLast, rgnNew, RGN_XOR);
+ }
+ }
+ if((hBrush != hBrushLast) && (lpRectLast != NULL))
+ {
+ // brushes are different -- erase old region first
+ SelectClipRgn(rgnLast);
+ GetClipBox(&rect);
+ hBrushOld = SelectBrush(hBrushLast);
+ PatBlt(rect.left, rect.top, rect.right - rect.left, rect.bottom - rect.top, PATINVERT);
+ SelectBrush(hBrushOld);
+ hBrushOld = NULL;
+ }
+
+ // draw into the update/new region
+ SelectClipRgn(rgnUpdate.IsNull() ? rgnNew : rgnUpdate);
+ GetClipBox(&rect);
+ hBrushOld = SelectBrush(hBrush);
+ PatBlt(rect.left, rect.top, rect.right - rect.left, rect.bottom - rect.top, PATINVERT);
+
+ // cleanup DC
+ if(hBrushOld != NULL)
+ SelectBrush(hBrushOld);
+ SelectClipRgn(NULL);
+ }
+
+ void FillSolidRect(LPCRECT lpRect, COLORREF clr)
+ {
+ ATLASSERT(m_hDC != NULL);
+
+ COLORREF clrOld = ::SetBkColor(m_hDC, clr);
+ ATLASSERT(clrOld != CLR_INVALID);
+ if(clrOld != CLR_INVALID)
+ {
+ ::ExtTextOut(m_hDC, 0, 0, ETO_OPAQUE, lpRect, NULL, 0, NULL);
+ ::SetBkColor(m_hDC, clrOld);
+ }
+ }
+
+ void FillSolidRect(int x, int y, int cx, int cy, COLORREF clr)
+ {
+ ATLASSERT(m_hDC != NULL);
+
+ RECT rect = { x, y, x + cx, y + cy };
+ FillSolidRect(&rect, clr);
+ }
+
+ void Draw3dRect(LPCRECT lpRect, COLORREF clrTopLeft, COLORREF clrBottomRight)
+ {
+ Draw3dRect(lpRect->left, lpRect->top, lpRect->right - lpRect->left,
+ lpRect->bottom - lpRect->top, clrTopLeft, clrBottomRight);
+ }
+
+ void Draw3dRect(int x, int y, int cx, int cy, COLORREF clrTopLeft, COLORREF clrBottomRight)
+ {
+ FillSolidRect(x, y, cx - 1, 1, clrTopLeft);
+ FillSolidRect(x, y, 1, cy - 1, clrTopLeft);
+ FillSolidRect(x + cx, y, -1, cy, clrBottomRight);
+ FillSolidRect(x, y + cy, cx, -1, clrBottomRight);
+ }
+
+// DIB support
+ int SetDIBitsToDevice(int x, int y, DWORD dwWidth, DWORD dwHeight, int xSrc, int ySrc, UINT uStartScan, UINT cScanLines, CONST VOID* lpvBits, CONST BITMAPINFO* lpbmi, UINT uColorUse)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetDIBitsToDevice(m_hDC, x, y, dwWidth, dwHeight, xSrc, ySrc, uStartScan, cScanLines, lpvBits, lpbmi, uColorUse);
+ }
+
+ int StretchDIBits(int x, int y, int nWidth, int nHeight, int xSrc, int ySrc, int nSrcWidth, int nSrcHeight, CONST VOID* lpvBits, CONST BITMAPINFO* lpbmi, UINT uColorUse, DWORD dwRop)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::StretchDIBits(m_hDC, x, y, nWidth, nHeight, xSrc, ySrc, nSrcWidth, nSrcHeight, lpvBits, lpbmi, uColorUse, dwRop);
+ }
+
+ UINT GetDIBColorTable(UINT uStartIndex, UINT cEntries, RGBQUAD* pColors) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetDIBColorTable(m_hDC, uStartIndex, cEntries, pColors);
+ }
+
+ UINT SetDIBColorTable(UINT uStartIndex, UINT cEntries, CONST RGBQUAD* pColors)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetDIBColorTable(m_hDC, uStartIndex, cEntries, pColors);
+ }
+
+// OpenGL support
+#if !defined(_ATL_NO_OPENGL)
+ int ChoosePixelFormat(CONST PIXELFORMATDESCRIPTOR* ppfd)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ChoosePixelFormat(m_hDC, ppfd);
+ }
+
+ int DescribePixelFormat(int iPixelFormat, UINT nBytes, LPPIXELFORMATDESCRIPTOR ppfd)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::DescribePixelFormat(m_hDC, iPixelFormat, nBytes, ppfd);
+ }
+
+ int GetPixelFormat() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetPixelFormat(m_hDC);
+ }
+
+ BOOL SetPixelFormat(int iPixelFormat, CONST PIXELFORMATDESCRIPTOR* ppfd)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetPixelFormat(m_hDC, iPixelFormat, ppfd);
+ }
+
+ BOOL SwapBuffers()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SwapBuffers(m_hDC);
+ }
+
+ HGLRC wglCreateContext()
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::wglCreateContext(m_hDC);
+ }
+
+ HGLRC wglCreateLayerContext(int iLayerPlane)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::wglCreateLayerContext(m_hDC, iLayerPlane);
+ }
+
+ BOOL wglMakeCurrent(HGLRC hglrc)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::wglMakeCurrent(m_hDC, hglrc);
+ }
+
+ BOOL wglUseFontBitmaps(DWORD dwFirst, DWORD dwCount, DWORD listBase)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::wglUseFontBitmaps(m_hDC, dwFirst, dwCount, listBase);
+ }
+
+ BOOL wglUseFontOutlines(DWORD dwFirst, DWORD dwCount, DWORD listBase, FLOAT deviation, FLOAT extrusion, int format, LPGLYPHMETRICSFLOAT lpgmf)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::wglUseFontOutlines(m_hDC, dwFirst, dwCount, listBase, deviation, extrusion, format, lpgmf);
+ }
+
+ BOOL wglDescribeLayerPlane(int iPixelFormat, int iLayerPlane, UINT nBytes, LPLAYERPLANEDESCRIPTOR plpd)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::wglDescribeLayerPlane(m_hDC, iPixelFormat, iLayerPlane, nBytes, plpd);
+ }
+
+ int wglSetLayerPaletteEntries(int iLayerPlane, int iStart, int cEntries, CONST COLORREF* pclr)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::wglSetLayerPaletteEntries(m_hDC, iLayerPlane, iStart, cEntries, pclr);
+ }
+
+ int wglGetLayerPaletteEntries(int iLayerPlane, int iStart, int cEntries, COLORREF* pclr)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::wglGetLayerPaletteEntries(m_hDC, iLayerPlane, iStart, cEntries, pclr);
+ }
+
+ BOOL wglRealizeLayerPalette(int iLayerPlane, BOOL bRealize)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::wglRealizeLayerPalette(m_hDC, iLayerPlane, bRealize);
+ }
+
+ BOOL wglSwapLayerBuffers(UINT uPlanes)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::wglSwapLayerBuffers(m_hDC, uPlanes);
+ }
+#endif // !defined(_ATL_NO_OPENGL)
+
+ COLORREF GetDCPenColor() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetDCPenColor(m_hDC);
+ }
+
+ COLORREF SetDCPenColor(COLORREF clr)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetDCPenColor(m_hDC, clr);
+ }
+
+ COLORREF GetDCBrushColor() const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetDCBrushColor(m_hDC);
+ }
+
+ COLORREF SetDCBrushColor(COLORREF clr)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::SetDCBrushColor(m_hDC, clr);
+ }
+
+ DWORD GetFontUnicodeRanges(LPGLYPHSET lpgs) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetFontUnicodeRanges(m_hDC, lpgs);
+ }
+
+ DWORD GetGlyphIndices(LPCTSTR lpstr, int cch, LPWORD pgi, DWORD dwFlags) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetGlyphIndices(m_hDC, lpstr, cch, pgi, dwFlags);
+ }
+
+ BOOL GetTextExtentPointI(LPWORD pgiIn, int cgi, LPSIZE lpSize) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetTextExtentPointI(m_hDC, pgiIn, cgi, lpSize);
+ }
+
+ BOOL GetTextExtentExPointI(LPWORD pgiIn, int cgi, int nMaxExtent, LPINT lpnFit, LPINT alpDx, LPSIZE lpSize) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetTextExtentExPointI(m_hDC, pgiIn, cgi, nMaxExtent, lpnFit, alpDx, lpSize);
+ }
+
+ BOOL GetCharWidthI(UINT giFirst, UINT cgi, LPWORD pgi, LPINT lpBuffer) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetCharWidthI(m_hDC, giFirst, cgi, pgi, lpBuffer);
+ }
+
+ BOOL GetCharABCWidthsI(UINT giFirst, UINT cgi, LPWORD pgi, LPABC lpabc) const
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::GetCharABCWidthsI(m_hDC, giFirst, cgi, pgi, lpabc);
+ }
+
+ BOOL ColorCorrectPalette(HPALETTE hPalette, DWORD dwFirstEntry, DWORD dwNumOfEntries)
+ {
+ ATLASSERT(m_hDC != NULL);
+ return ::ColorCorrectPalette(m_hDC, hPalette, dwFirstEntry, dwNumOfEntries);
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CDC Helpers
+
+class CPaintDC : public CDC
+{
+public:
+// Data members
+ HWND m_hWnd;
+ PAINTSTRUCT m_ps;
+
+// Constructor/destructor
+ CPaintDC(HWND hWnd)
+ {
+ ATLASSERT(::IsWindow(hWnd));
+ m_hWnd = hWnd;
+ m_hDC = ::BeginPaint(hWnd, &m_ps);
+ }
+
+ ~CPaintDC()
+ {
+ ATLASSERT(m_hDC != NULL);
+ ATLASSERT(::IsWindow(m_hWnd));
+ ::EndPaint(m_hWnd, &m_ps);
+ Detach();
+ }
+};
+
+class CClientDC : public CDC
+{
+public:
+// Data members
+ HWND m_hWnd;
+
+// Constructor/destructor
+ CClientDC(HWND hWnd)
+ {
+ ATLASSERT((hWnd == NULL) || ::IsWindow(hWnd));
+ m_hWnd = hWnd;
+ m_hDC = ::GetDC(hWnd);
+ }
+
+ ~CClientDC()
+ {
+ ATLASSERT(m_hDC != NULL);
+ ::ReleaseDC(m_hWnd, Detach());
+ }
+};
+
+class CWindowDC : public CDC
+{
+public:
+// Data members
+ HWND m_hWnd;
+
+// Constructor/destructor
+ CWindowDC(HWND hWnd)
+ {
+ ATLASSERT((hWnd == NULL) || ::IsWindow(hWnd));
+ m_hWnd = hWnd;
+ m_hDC = ::GetWindowDC(hWnd);
+ }
+
+ ~CWindowDC()
+ {
+ ATLASSERT(m_hDC != NULL);
+ ::ReleaseDC(m_hWnd, Detach());
+ }
+};
+
+class CMemoryDC : public CDC
+{
+public:
+// Data members
+ HDC m_hDCOriginal;
+ RECT m_rcPaint;
+ CBitmap m_bmp;
+ HBITMAP m_hBmpOld;
+
+// Constructor/destructor
+ CMemoryDC(HDC hDC, const RECT& rcPaint) : m_hDCOriginal(hDC), m_hBmpOld(NULL)
+ {
+ m_rcPaint = rcPaint;
+ CreateCompatibleDC(m_hDCOriginal);
+ ATLASSERT(m_hDC != NULL);
+ m_bmp.CreateCompatibleBitmap(m_hDCOriginal, m_rcPaint.right - m_rcPaint.left, m_rcPaint.bottom - m_rcPaint.top);
+ ATLASSERT(m_bmp.m_hBitmap != NULL);
+ m_hBmpOld = SelectBitmap(m_bmp);
+ SetViewportOrg(-m_rcPaint.left, -m_rcPaint.top);
+ }
+
+ ~CMemoryDC()
+ {
+ ::BitBlt(m_hDCOriginal, m_rcPaint.left, m_rcPaint.top, m_rcPaint.right - m_rcPaint.left, m_rcPaint.bottom - m_rcPaint.top, m_hDC, m_rcPaint.left, m_rcPaint.top, SRCCOPY);
+ SelectBitmap(m_hBmpOld);
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Enhanced metafile support
+
+class CEnhMetaFileInfo
+{
+public:
+// Data members
+ HENHMETAFILE m_hEMF;
+ BYTE* m_pBits;
+ TCHAR* m_pDesc;
+ ENHMETAHEADER m_header;
+ PIXELFORMATDESCRIPTOR m_pfd;
+
+// Constructor/destructor
+ CEnhMetaFileInfo(HENHMETAFILE hEMF) : m_hEMF(hEMF), m_pBits(NULL), m_pDesc(NULL)
+ {
+ memset(&m_header, 0, sizeof(m_header));
+ memset(&m_pfd, 0, sizeof(m_pfd));
+ }
+
+ ~CEnhMetaFileInfo()
+ {
+ delete [] m_pBits;
+ delete [] m_pDesc;
+ }
+
+// Operations
+ BYTE* GetEnhMetaFileBits()
+ {
+ ATLASSERT(m_hEMF != NULL);
+ UINT nBytes = ::GetEnhMetaFileBits(m_hEMF, 0, NULL);
+ delete [] m_pBits;
+ m_pBits = NULL;
+ ATLTRY(m_pBits = new BYTE[nBytes]);
+ if (m_pBits != NULL)
+ ::GetEnhMetaFileBits(m_hEMF, nBytes, m_pBits);
+ return m_pBits;
+ }
+
+ LPTSTR GetEnhMetaFileDescription()
+ {
+ ATLASSERT(m_hEMF != NULL);
+ UINT nLen = ::GetEnhMetaFileDescription(m_hEMF, 0, NULL);
+ delete [] m_pDesc;
+ m_pDesc = NULL;
+ ATLTRY(m_pDesc = new TCHAR[nLen]);
+ if (m_pDesc != NULL)
+ nLen = ::GetEnhMetaFileDescription(m_hEMF, nLen, m_pDesc);
+ return m_pDesc;
+ }
+
+ ENHMETAHEADER* GetEnhMetaFileHeader()
+ {
+ ATLASSERT(m_hEMF != NULL);
+ memset(&m_header, 0, sizeof(m_header));
+ m_header.iType = EMR_HEADER;
+ m_header.nSize = sizeof(ENHMETAHEADER);
+ UINT n = ::GetEnhMetaFileHeader(m_hEMF, sizeof(ENHMETAHEADER), &m_header);
+ return (n != 0) ? &m_header : NULL;
+ }
+
+ PIXELFORMATDESCRIPTOR* GetEnhMetaFilePixelFormat()
+ {
+ ATLASSERT(m_hEMF != NULL);
+ memset(&m_pfd, 0, sizeof(m_pfd));
+ UINT n = ::GetEnhMetaFilePixelFormat(m_hEMF, sizeof(m_pfd), &m_pfd);
+ return (n != 0) ? &m_pfd : NULL;
+ }
+};
+
+
+template <bool t_bManaged>
+class CEnhMetaFileT
+{
+public:
+// Data members
+ HENHMETAFILE m_hEMF;
+
+// Constructor/destructor
+ CEnhMetaFileT(HENHMETAFILE hEMF = NULL) : m_hEMF(hEMF)
+ {
+ }
+
+ ~CEnhMetaFileT()
+ {
+ if(t_bManaged && (m_hEMF != NULL))
+ DeleteObject();
+ }
+
+// Operations
+ CEnhMetaFileT<t_bManaged>& operator =(HENHMETAFILE hEMF)
+ {
+ Attach(hEMF);
+ return *this;
+ }
+
+ void Attach(HENHMETAFILE hEMF)
+ {
+ if(t_bManaged && (m_hEMF != NULL) && (m_hEMF != hEMF))
+ DeleteObject();
+ m_hEMF = hEMF;
+ }
+
+ HENHMETAFILE Detach()
+ {
+ HENHMETAFILE hEMF = m_hEMF;
+ m_hEMF = NULL;
+ return hEMF;
+ }
+
+ operator HENHMETAFILE() const { return m_hEMF; }
+
+ bool IsNull() const { return (m_hEMF == NULL); }
+
+ BOOL DeleteObject()
+ {
+ ATLASSERT(m_hEMF != NULL);
+ BOOL bRet = ::DeleteEnhMetaFile(m_hEMF);
+ m_hEMF = NULL;
+ return bRet;
+ }
+
+ UINT GetEnhMetaFileBits(UINT cbBuffer, LPBYTE lpbBuffer) const
+ {
+ ATLASSERT(m_hEMF != NULL);
+ return ::GetEnhMetaFileBits(m_hEMF, cbBuffer, lpbBuffer);
+ }
+
+ UINT GetEnhMetaFileDescription(UINT cchBuffer, LPTSTR lpszDescription) const
+ {
+ ATLASSERT(m_hEMF != NULL);
+ return ::GetEnhMetaFileDescription(m_hEMF, cchBuffer, lpszDescription);
+ }
+
+ UINT GetEnhMetaFileHeader(LPENHMETAHEADER lpemh) const
+ {
+ ATLASSERT(m_hEMF != NULL);
+ lpemh->iType = EMR_HEADER;
+ lpemh->nSize = sizeof(ENHMETAHEADER);
+ return ::GetEnhMetaFileHeader(m_hEMF, sizeof(ENHMETAHEADER), lpemh);
+ }
+
+ UINT GetEnhMetaFilePaletteEntries(UINT cEntries, LPPALETTEENTRY lppe) const
+ {
+ ATLASSERT(m_hEMF != NULL);
+ return ::GetEnhMetaFilePaletteEntries(m_hEMF, cEntries, lppe);
+ }
+
+ UINT GetEnhMetaFilePixelFormat(DWORD cbBuffer, PIXELFORMATDESCRIPTOR* ppfd) const
+ {
+ ATLASSERT(m_hEMF != NULL);
+ return ::GetEnhMetaFilePixelFormat(m_hEMF, cbBuffer, ppfd);
+ }
+};
+
+typedef CEnhMetaFileT<false> CEnhMetaFileHandle;
+typedef CEnhMetaFileT<true> CEnhMetaFile;
+
+
+class CEnhMetaFileDC : public CDC
+{
+public:
+// Constructor/destructor
+ CEnhMetaFileDC()
+ {
+ }
+
+ CEnhMetaFileDC(HDC hdc, LPCRECT lpRect)
+ {
+ Create(hdc, NULL, lpRect, NULL);
+ ATLASSERT(m_hDC != NULL);
+ }
+
+ CEnhMetaFileDC(HDC hdcRef, LPCTSTR lpFilename, LPCRECT lpRect, LPCTSTR lpDescription)
+ {
+ Create(hdcRef, lpFilename, lpRect, lpDescription);
+ ATLASSERT(m_hDC != NULL);
+ }
+
+ ~CEnhMetaFileDC()
+ {
+ HENHMETAFILE hEMF = Close();
+ if (hEMF != NULL)
+ ::DeleteEnhMetaFile(hEMF);
+ }
+
+// Operations
+ void Create(HDC hdcRef, LPCTSTR lpFilename, LPCRECT lpRect, LPCTSTR lpDescription)
+ {
+ ATLASSERT(m_hDC == NULL);
+ m_hDC = ::CreateEnhMetaFile(hdcRef, lpFilename, lpRect, lpDescription);
+ }
+
+ HENHMETAFILE Close()
+ {
+ HENHMETAFILE hEMF = NULL;
+ if (m_hDC != NULL)
+ {
+ hEMF = ::CloseEnhMetaFile(m_hDC);
+ m_hDC = NULL;
+ }
+ return hEMF;
+ }
+};
+
+} // namespace WTL
+
+#endif // __ATLGDI_H__
diff --git a/Examples/WhisperDesktop/Utils/WTL/atlres.h b/Examples/WhisperDesktop/Utils/WTL/atlres.h
new file mode 100644
index 0000000..aed58f1
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/WTL/atlres.h
@@ -0,0 +1,259 @@
+// Windows Template Library - WTL version 10.0
+// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved.
+//
+// This file is a part of the Windows Template Library.
+// The use and distribution terms for this software are covered by the
+// Microsoft Public License (http://opensource.org/licenses/MS-PL)
+// which can be found in the file MS-PL.txt at the root folder.
+
+#ifndef __ATLRES_H__
+#define __ATLRES_H__
+
+#pragma once
+
+
+#ifdef RC_INVOKED
+#ifndef _INC_WINDOWS
+
+ #define _INC_WINDOWS
+
+ #define VS_VERSION_INFO 1
+
+ #ifdef APSTUDIO_INVOKED
+ #define APSTUDIO_HIDDEN_SYMBOLS // Ignore following symbols
+ #endif // APSTUDIO_INVOKED
+
+ #ifndef WINVER
+ #define WINVER 0x0500
+ #endif // !WINVER
+
+ #include <winresrc.h>
+
+ // operation messages sent to DLGINIT
+ #define LB_ADDSTRING (WM_USER+1)
+ #define CB_ADDSTRING (WM_USER+3)
+
+ #ifdef APSTUDIO_INVOKED
+ #undef APSTUDIO_HIDDEN_SYMBOLS
+ #endif // APSTUDIO_INVOKED
+
+ #ifdef IDC_STATIC
+ #undef IDC_STATIC
+ #endif // IDC_STATIC
+ #define IDC_STATIC (-1)
+
+#endif // !_INC_WINDOWS
+#endif // RC_INVOKED
+
+#ifdef APSTUDIO_INVOKED
+ #define APSTUDIO_HIDDEN_SYMBOLS
+#endif // APSTUDIO_INVOKED
+
+///////////////////////////////////////////////////////////////////////////////
+// ATL resource types
+
+#ifndef RC_INVOKED
+ #define RT_DLGINIT MAKEINTRESOURCE(240)
+ #define RT_TOOLBAR MAKEINTRESOURCE(241)
+#endif // RC_INVOKED
+
+///////////////////////////////////////////////////////////////////////////////
+
+#ifdef APSTUDIO_INVOKED
+ #undef APSTUDIO_HIDDEN_SYMBOLS
+#endif // APSTUDIO_INVOKED
+
+///////////////////////////////////////////////////////////////////////////////
+// Standard window components
+
+#define ID_SEPARATOR 0 // special separator value
+#define ID_DEFAULT_PANE 0 // default status bar pane
+
+#ifndef RC_INVOKED // code only
+// standard control bars (IDW = window ID)
+ #define ATL_IDW_TOOLBAR 0xE800 // main Toolbar for window
+ #define ATL_IDW_STATUS_BAR 0xE801 // Status bar window
+ #define ATL_IDW_COMMAND_BAR 0xE802 // Command bar window
+
+// parts of a frame window
+ #define ATL_IDW_CLIENT 0xE900
+ #define ATL_IDW_PANE_FIRST 0xE900 // first pane (256 max)
+ #define ATL_IDW_PANE_LAST 0xE9FF
+ #define ATL_IDW_HSCROLL_FIRST 0xEA00 // first Horz scrollbar (16 max)
+ #define ATL_IDW_VSCROLL_FIRST 0xEA10 // first Vert scrollbar (16 max)
+
+ #define ATL_IDW_SIZE_BOX 0xEA20 // size box for splitters
+ #define ATL_IDW_PANE_SAVE 0xEA21 // to shift ATL_IDW_PANE_FIRST
+
+// bands for a rebar
+ #define ATL_IDW_BAND_FIRST 0xEB00
+ #define ATL_IDW_BAND_LAST 0xEBFF
+#endif // !RC_INVOKED
+
+///////////////////////////////////////////////////////////////////////////////
+// Standard Commands
+
+// File commands
+#define ID_FILE_NEW 0xE100
+#define ID_FILE_OPEN 0xE101
+#define ID_FILE_CLOSE 0xE102
+#define ID_FILE_SAVE 0xE103
+#define ID_FILE_SAVE_AS 0xE104
+#define ID_FILE_PAGE_SETUP 0xE105
+#define ID_FILE_PRINT_SETUP 0xE106
+#define ID_FILE_PRINT 0xE107
+#define ID_FILE_PRINT_DIRECT 0xE108
+#define ID_FILE_PRINT_PREVIEW 0xE109
+#define ID_FILE_UPDATE 0xE10A
+#define ID_FILE_SAVE_COPY_AS 0xE10B
+#define ID_FILE_SEND_MAIL 0xE10C
+
+#define ID_FILE_MRU_FIRST 0xE110
+#define ID_FILE_MRU_FILE1 0xE110 // range - 16 max
+#define ID_FILE_MRU_FILE2 0xE111
+#define ID_FILE_MRU_FILE3 0xE112
+#define ID_FILE_MRU_FILE4 0xE113
+#define ID_FILE_MRU_FILE5 0xE114
+#define ID_FILE_MRU_FILE6 0xE115
+#define ID_FILE_MRU_FILE7 0xE116
+#define ID_FILE_MRU_FILE8 0xE117
+#define ID_FILE_MRU_FILE9 0xE118
+#define ID_FILE_MRU_FILE10 0xE119
+#define ID_FILE_MRU_FILE11 0xE11A
+#define ID_FILE_MRU_FILE12 0xE11B
+#define ID_FILE_MRU_FILE13 0xE11C
+#define ID_FILE_MRU_FILE14 0xE11D
+#define ID_FILE_MRU_FILE15 0xE11E
+#define ID_FILE_MRU_FILE16 0xE11F
+#define ID_FILE_MRU_LAST 0xE11F
+
+// Edit commands
+#define ID_EDIT_CLEAR 0xE120
+#define ID_EDIT_CLEAR_ALL 0xE121
+#define ID_EDIT_COPY 0xE122
+#define ID_EDIT_CUT 0xE123
+#define ID_EDIT_FIND 0xE124
+#define ID_EDIT_PASTE 0xE125
+#define ID_EDIT_PASTE_LINK 0xE126
+#define ID_EDIT_PASTE_SPECIAL 0xE127
+#define ID_EDIT_REPEAT 0xE128
+#define ID_EDIT_REPLACE 0xE129
+#define ID_EDIT_SELECT_ALL 0xE12A
+#define ID_EDIT_UNDO 0xE12B
+#define ID_EDIT_REDO 0xE12C
+#define ID_EDIT_DELETE ID_EDIT_CLEAR
+#define ID_EDIT_FIND_NEXT ID_EDIT_REPEAT
+#define ID_EDIT_FIND_PREVIOUS 0xE12D
+
+// Window commands
+#define ID_WINDOW_NEW 0xE130
+#define ID_WINDOW_ARRANGE 0xE131
+#define ID_WINDOW_CASCADE 0xE132
+#define ID_WINDOW_TILE_HORZ 0xE133
+#define ID_WINDOW_TILE_VERT 0xE134
+#define ID_WINDOW_SPLIT 0xE135
+#ifndef RC_INVOKED // code only
+ #define ATL_IDM_WINDOW_FIRST 0xE130
+ #define ATL_IDM_WINDOW_LAST 0xE13F
+ #define ATL_IDM_FIRST_MDICHILD 0xFF00 // window list starts here
+ #define ATL_IDM_LAST_MDICHILD 0xFFFD
+#endif // !RC_INVOKED
+// TabView
+#define ID_WINDOW_TABFIRST 0xFF00 // = ATL_IDM_FIRST_MDICHILD
+#define ID_WINDOW_TABLAST 0xFFFD
+#define ID_WINDOW_SHOWTABLIST 0xFFFE
+
+// Help and App commands
+#define ID_APP_ABOUT 0xE140
+#define ID_APP_EXIT 0xE141
+#define ID_HELP_INDEX 0xE142
+#define ID_HELP_FINDER 0xE143
+#define ID_HELP_USING 0xE144
+#define ID_CONTEXT_HELP 0xE145 // shift-F1
+// special commands for processing help
+#define ID_HELP 0xE146 // first attempt for F1
+#define ID_DEFAULT_HELP 0xE147 // last attempt
+
+// Misc
+#define ID_NEXT_PANE 0xE150
+#define ID_PREV_PANE 0xE151
+#define ID_PANE_CLOSE 0xE152
+#define ID_PANE_NEXT ID_NEXT_PANE
+#define ID_PANE_PREVIOUS ID_PREV_PANE
+
+// Format
+#define ID_FORMAT_FONT 0xE160
+
+// Scroll
+#define ID_SCROLL_UP 0xE170
+#define ID_SCROLL_DOWN 0xE171
+#define ID_SCROLL_PAGE_UP 0xE172
+#define ID_SCROLL_PAGE_DOWN 0xE173
+#define ID_SCROLL_TOP 0xE174
+#define ID_SCROLL_BOTTOM 0xE175
+#define ID_SCROLL_LEFT 0xE176
+#define ID_SCROLL_RIGHT 0xE177
+#define ID_SCROLL_PAGE_LEFT 0xE178
+#define ID_SCROLL_PAGE_RIGHT 0xE179
+#define ID_SCROLL_ALL_LEFT 0xE17A
+#define ID_SCROLL_ALL_RIGHT 0xE17B
+
+// OLE commands
+#define ID_OLE_INSERT_NEW 0xE200
+#define ID_OLE_EDIT_LINKS 0xE201
+#define ID_OLE_EDIT_CONVERT 0xE202
+#define ID_OLE_EDIT_CHANGE_ICON 0xE203
+#define ID_OLE_EDIT_PROPERTIES 0xE204
+#define ID_OLE_VERB_FIRST 0xE210 // range - 16 max
+#ifndef RC_INVOKED // code only
+ #define ID_OLE_VERB_LAST 0xE21F
+#endif // !RC_INVOKED
+
+// View commands (same number used as IDW used for toolbar and status bar)
+#define ID_VIEW_TOOLBAR 0xE800
+#define ID_VIEW_STATUS_BAR 0xE801
+#define ID_VIEW_REFRESH 0xE803
+#define ID_VIEW_RIBBON 0xE804
+
+///////////////////////////////////////////////////////////////////////////////
+// Standard control IDs
+
+#ifdef IDC_STATIC
+ #undef IDC_STATIC
+#endif // IDC_STATIC
+#define IDC_STATIC (-1) // all static controls
+
+///////////////////////////////////////////////////////////////////////////////
+// Standard string error/warnings
+
+// idle status bar message
+#define ATL_IDS_IDLEMESSAGE 0xE001
+
+#ifndef RC_INVOKED // code only
+ #define ATL_IDS_SCFIRST 0xEF00
+#endif // !RC_INVOKED
+
+#define ATL_IDS_SCSIZE 0xEF00
+#define ATL_IDS_SCMOVE 0xEF01
+#define ATL_IDS_SCMINIMIZE 0xEF02
+#define ATL_IDS_SCMAXIMIZE 0xEF03
+#define ATL_IDS_SCNEXTWINDOW 0xEF04
+#define ATL_IDS_SCPREVWINDOW 0xEF05
+#define ATL_IDS_SCCLOSE 0xEF06
+#define ATL_IDS_SCRESTORE 0xEF12
+#define ATL_IDS_SCTASKLIST 0xEF13
+
+#define ATL_IDS_MDICHILD 0xEF1F
+#define ATL_IDS_MRU_FILE 0xEFDA
+
+///////////////////////////////////////////////////////////////////////////////
+// Misc. control IDs
+
+// Property Sheet control id's (determined with Spy++)
+#define ID_APPLY_NOW 0x3021
+#define ID_WIZBACK 0x3023
+#define ID_WIZNEXT 0x3024
+#define ID_WIZFINISH 0x3025
+#define ATL_IDC_TAB_CONTROL 0x3020
+
+#endif // __ATLRES_H__
diff --git a/Examples/WhisperDesktop/Utils/WTL/atluser.h b/Examples/WhisperDesktop/Utils/WTL/atluser.h
new file mode 100644
index 0000000..9f8e0f4
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/WTL/atluser.h
@@ -0,0 +1,1231 @@
+// Windows Template Library - WTL version 10.0
+// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved.
+//
+// This file is a part of the Windows Template Library.
+// The use and distribution terms for this software are covered by the
+// Microsoft Public License (http://opensource.org/licenses/MS-PL)
+// which can be found in the file MS-PL.txt at the root folder.
+
+#ifndef __ATLUSER_H__
+#define __ATLUSER_H__
+
+#pragma once
+
+#ifndef __ATLAPP_H__
+ #error atluser.h requires atlapp.h to be included first
+#endif
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Classes in this file:
+//
+// CMenuItemInfo
+// CMenuT<t_bManaged>
+// CAcceleratorT<t_bManaged>
+// CIconT<t_bManaged>
+// CCursorT<t_bManaged>
+// CResource
+//
+// Global functions:
+// AtlMessageBox()
+//
+// AtlLoadAccelerators()
+// AtlLoadMenu()
+// AtlLoadBitmap()
+// AtlLoadSysBitmap()
+// AtlLoadCursor()
+// AtlLoadSysCursor()
+// AtlLoadIcon()
+// AtlLoadSysIcon()
+// AtlLoadBitmapImage()
+// AtlLoadCursorImage()
+// AtlLoadIconImage()
+// AtlLoadSysBitmapImage()
+// AtlLoadSysCursorImage()
+// AtlLoadSysIconImage()
+// AtlLoadString()
+
+
+namespace WTL
+{
+
+///////////////////////////////////////////////////////////////////////////////
+// AtlMessageBox - accepts both memory and resource based strings
+
+inline int AtlMessageBox(HWND hWndOwner, ATL::_U_STRINGorID message, ATL::_U_STRINGorID title = (LPCTSTR)NULL, UINT uType = MB_OK | MB_ICONINFORMATION)
+{
+ ATLASSERT((hWndOwner == NULL) || ::IsWindow(hWndOwner));
+
+ LPTSTR lpstrMessage = NULL;
+ if(IS_INTRESOURCE(message.m_lpstr))
+ {
+ for(int nLen = 256; ; nLen *= 2)
+ {
+ ATLTRY(lpstrMessage = new TCHAR[nLen]);
+ if(lpstrMessage == NULL)
+ {
+ ATLASSERT(FALSE);
+ return 0;
+ }
+ int nRes = ::LoadString(ModuleHelper::GetResourceInstance(), LOWORD(message.m_lpstr), lpstrMessage, nLen);
+ if(nRes < nLen - 1)
+ break;
+ delete [] lpstrMessage;
+ lpstrMessage = NULL;
+ }
+
+ message.m_lpstr = lpstrMessage;
+ }
+
+ LPTSTR lpstrTitle = NULL;
+ if(IS_INTRESOURCE(title.m_lpstr) && (LOWORD(title.m_lpstr) != 0))
+ {
+ for(int nLen = 256; ; nLen *= 2)
+ {
+ ATLTRY(lpstrTitle = new TCHAR[nLen]);
+ if(lpstrTitle == NULL)
+ {
+ ATLASSERT(FALSE);
+ return 0;
+ }
+ int nRes = ::LoadString(ModuleHelper::GetResourceInstance(), LOWORD(title.m_lpstr), lpstrTitle, nLen);
+ if(nRes < nLen - 1)
+ break;
+ delete [] lpstrTitle;
+ lpstrTitle = NULL;
+ }
+
+ title.m_lpstr = lpstrTitle;
+ }
+
+ int nRet = ::MessageBox(hWndOwner, message.m_lpstr, title.m_lpstr, uType);
+
+ delete [] lpstrMessage;
+ delete [] lpstrTitle;
+
+ return nRet;
+}
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CMenu
+
+class CMenuItemInfo : public MENUITEMINFO
+{
+public:
+ CMenuItemInfo()
+ {
+ memset(this, 0, sizeof(MENUITEMINFO));
+ cbSize = sizeof(MENUITEMINFO);
+ }
+};
+
+
+// forward declarations
+template <bool t_bManaged> class CMenuT;
+typedef CMenuT<false> CMenuHandle;
+typedef CMenuT<true> CMenu;
+
+
+template <bool t_bManaged>
+class CMenuT
+{
+public:
+// Data members
+ HMENU m_hMenu;
+
+// Constructor/destructor/operators
+ CMenuT(HMENU hMenu = NULL) : m_hMenu(hMenu)
+ { }
+
+ ~CMenuT()
+ {
+ if(t_bManaged && (m_hMenu != NULL))
+ DestroyMenu();
+ }
+
+ CMenuT<t_bManaged>& operator =(HMENU hMenu)
+ {
+ Attach(hMenu);
+ return *this;
+ }
+
+ void Attach(HMENU hMenuNew)
+ {
+ ATLASSERT(::IsMenu(hMenuNew));
+ if(t_bManaged && (m_hMenu != NULL) && (m_hMenu != hMenuNew))
+ ::DestroyMenu(m_hMenu);
+ m_hMenu = hMenuNew;
+ }
+
+ HMENU Detach()
+ {
+ HMENU hMenu = m_hMenu;
+ m_hMenu = NULL;
+ return hMenu;
+ }
+
+ operator HMENU() const { return m_hMenu; }
+
+ bool IsNull() const { return (m_hMenu == NULL); }
+
+ BOOL IsMenu() const
+ {
+ return ::IsMenu(m_hMenu);
+ }
+
+// Create/destroy methods
+ BOOL CreateMenu()
+ {
+ ATLASSERT(m_hMenu == NULL);
+ m_hMenu = ::CreateMenu();
+ return (m_hMenu != NULL) ? TRUE : FALSE;
+ }
+
+ BOOL CreatePopupMenu()
+ {
+ ATLASSERT(m_hMenu == NULL);
+ m_hMenu = ::CreatePopupMenu();
+ return (m_hMenu != NULL) ? TRUE : FALSE;
+ }
+
+ BOOL LoadMenu(ATL::_U_STRINGorID menu)
+ {
+ ATLASSERT(m_hMenu == NULL);
+ m_hMenu = ::LoadMenu(ModuleHelper::GetResourceInstance(), menu.m_lpstr);
+ return (m_hMenu != NULL) ? TRUE : FALSE;
+ }
+
+ BOOL LoadMenuIndirect(const void* lpMenuTemplate)
+ {
+ ATLASSERT(m_hMenu == NULL);
+ m_hMenu = ::LoadMenuIndirect(lpMenuTemplate);
+ return (m_hMenu != NULL) ? TRUE : FALSE;
+ }
+
+ BOOL DestroyMenu()
+ {
+ if (m_hMenu == NULL)
+ return FALSE;
+ BOOL bRet = ::DestroyMenu(m_hMenu);
+ if(bRet)
+ m_hMenu = NULL;
+ return bRet;
+ }
+
+// Menu Operations
+ BOOL DeleteMenu(UINT nPosition, UINT nFlags)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::DeleteMenu(m_hMenu, nPosition, nFlags);
+ }
+
+ BOOL TrackPopupMenu(UINT nFlags, int x, int y, HWND hWnd, LPCRECT lpRect = NULL)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ x = _FixTrackMenuPopupX(x, y);
+ return ::TrackPopupMenu(m_hMenu, nFlags, x, y, 0, hWnd, lpRect);
+ }
+
+ BOOL TrackPopupMenuEx(UINT uFlags, int x, int y, HWND hWnd, LPTPMPARAMS lptpm = NULL)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ x = _FixTrackMenuPopupX(x, y);
+ return ::TrackPopupMenuEx(m_hMenu, uFlags, x, y, hWnd, lptpm);
+ }
+
+ // helper that fixes popup menu X position when it's off-screen
+ static int _FixTrackMenuPopupX(int x, int y)
+ {
+ POINT pt = { x, y };
+ HMONITOR hMonitor = ::MonitorFromPoint(pt, MONITOR_DEFAULTTONULL);
+ if(hMonitor == NULL)
+ {
+ HMONITOR hMonitorNear = ::MonitorFromPoint(pt, MONITOR_DEFAULTTONEAREST);
+ if(hMonitorNear != NULL)
+ {
+ MONITORINFO mi = { sizeof(MONITORINFO) };
+ if(::GetMonitorInfo(hMonitorNear, &mi) != FALSE)
+ {
+ if(x < mi.rcWork.left)
+ x = mi.rcWork.left;
+ else if(x > mi.rcWork.right)
+ x = mi.rcWork.right;
+ }
+ }
+ }
+
+ return x;
+ }
+
+ BOOL GetMenuInfo(LPMENUINFO lpMenuInfo) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::GetMenuInfo(m_hMenu, lpMenuInfo);
+ }
+
+ BOOL SetMenuInfo(LPCMENUINFO lpMenuInfo)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::SetMenuInfo(m_hMenu, lpMenuInfo);
+ }
+
+// Menu Item Operations
+ BOOL AppendMenu(UINT nFlags, UINT_PTR nIDNewItem = 0, LPCTSTR lpszNewItem = NULL)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::AppendMenu(m_hMenu, nFlags, nIDNewItem, lpszNewItem);
+ }
+
+ BOOL AppendMenu(UINT nFlags, HMENU hSubMenu, LPCTSTR lpszNewItem)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ ATLASSERT(::IsMenu(hSubMenu));
+ return ::AppendMenu(m_hMenu, nFlags | MF_POPUP, (UINT_PTR)hSubMenu, lpszNewItem);
+ }
+
+ BOOL AppendMenu(UINT nFlags, UINT_PTR nIDNewItem, HBITMAP hBmp)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::AppendMenu(m_hMenu, nFlags | MF_BITMAP, nIDNewItem, (LPCTSTR)hBmp);
+ }
+
+ BOOL AppendMenu(UINT nFlags, HMENU hSubMenu, HBITMAP hBmp)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ ATLASSERT(::IsMenu(hSubMenu));
+ return ::AppendMenu(m_hMenu, nFlags | (MF_BITMAP | MF_POPUP), (UINT_PTR)hSubMenu, (LPCTSTR)hBmp);
+ }
+
+ UINT CheckMenuItem(UINT nIDCheckItem, UINT nCheck)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return (UINT)::CheckMenuItem(m_hMenu, nIDCheckItem, nCheck);
+ }
+
+ UINT EnableMenuItem(UINT nIDEnableItem, UINT nEnable)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::EnableMenuItem(m_hMenu, nIDEnableItem, nEnable);
+ }
+
+ BOOL HiliteMenuItem(HWND hWnd, UINT uIDHiliteItem, UINT uHilite)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::HiliteMenuItem(hWnd, m_hMenu, uIDHiliteItem, uHilite);
+ }
+
+ int GetMenuItemCount() const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::GetMenuItemCount(m_hMenu);
+ }
+
+ UINT GetMenuItemID(int nPos) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::GetMenuItemID(m_hMenu, nPos);
+ }
+
+ UINT GetMenuState(UINT nID, UINT nFlags) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::GetMenuState(m_hMenu, nID, nFlags);
+ }
+
+ int GetMenuString(UINT nIDItem, LPTSTR lpString, int nMaxCount, UINT nFlags) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::GetMenuString(m_hMenu, nIDItem, lpString, nMaxCount, nFlags);
+ }
+
+ int GetMenuStringLen(UINT nIDItem, UINT nFlags) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::GetMenuString(m_hMenu, nIDItem, NULL, 0, nFlags);
+ }
+
+ BOOL GetMenuString(UINT nIDItem, BSTR& bstrText, UINT nFlags) const
+ {
+ USES_CONVERSION;
+ ATLASSERT(::IsMenu(m_hMenu));
+ ATLASSERT(bstrText == NULL);
+
+ int nLen = GetMenuStringLen(nIDItem, nFlags);
+ if(nLen == 0)
+ {
+ bstrText = ::SysAllocString(OLESTR(""));
+ return (bstrText != NULL) ? TRUE : FALSE;
+ }
+
+ nLen++; // increment to include terminating NULL char
+ ATL::CTempBuffer<TCHAR, _WTL_STACK_ALLOC_THRESHOLD> buff;
+ LPTSTR lpszText = buff.Allocate(nLen);
+ if(lpszText == NULL)
+ return FALSE;
+
+ if(!GetMenuString(nIDItem, lpszText, nLen, nFlags))
+ return FALSE;
+
+ bstrText = ::SysAllocString(T2OLE(lpszText));
+ return (bstrText != NULL) ? TRUE : FALSE;
+ }
+
+#ifdef __ATLSTR_H__
+ int GetMenuString(UINT nIDItem, ATL::CString& strText, UINT nFlags) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+
+ int nLen = GetMenuStringLen(nIDItem, nFlags);
+ if(nLen == 0)
+ return 0;
+
+ nLen++; // increment to include terminating NULL char
+ LPTSTR lpstr = strText.GetBufferSetLength(nLen);
+ if(lpstr == NULL)
+ return 0;
+ int nRet = GetMenuString(nIDItem, lpstr, nLen, nFlags);
+ strText.ReleaseBuffer();
+ return nRet;
+ }
+#endif // __ATLSTR_H__
+
+ CMenuHandle GetSubMenu(int nPos) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return CMenuHandle(::GetSubMenu(m_hMenu, nPos));
+ }
+
+ BOOL InsertMenu(UINT nPosition, UINT nFlags, UINT_PTR nIDNewItem = 0, LPCTSTR lpszNewItem = NULL)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::InsertMenu(m_hMenu, nPosition, nFlags, nIDNewItem, lpszNewItem);
+ }
+
+ BOOL InsertMenu(UINT nPosition, UINT nFlags, HMENU hSubMenu, LPCTSTR lpszNewItem)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ ATLASSERT(::IsMenu(hSubMenu));
+ return ::InsertMenu(m_hMenu, nPosition, nFlags | MF_POPUP, (UINT_PTR)hSubMenu, lpszNewItem);
+ }
+
+ BOOL InsertMenu(UINT nPosition, UINT nFlags, UINT_PTR nIDNewItem, HBITMAP hBmp)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::InsertMenu(m_hMenu, nPosition, nFlags | MF_BITMAP, nIDNewItem, (LPCTSTR)hBmp);
+ }
+
+ BOOL InsertMenu(UINT nPosition, UINT nFlags, HMENU hSubMenu, HBITMAP hBmp)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ ATLASSERT(::IsMenu(hSubMenu));
+ return ::InsertMenu(m_hMenu, nPosition, nFlags | (MF_BITMAP | MF_POPUP), (UINT_PTR)hSubMenu, (LPCTSTR)hBmp);
+ }
+
+ BOOL ModifyMenu(UINT nPosition, UINT nFlags, UINT_PTR nIDNewItem = 0, LPCTSTR lpszNewItem = NULL)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::ModifyMenu(m_hMenu, nPosition, nFlags, nIDNewItem, lpszNewItem);
+ }
+
+ BOOL ModifyMenu(UINT nPosition, UINT nFlags, HMENU hSubMenu, LPCTSTR lpszNewItem)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ ATLASSERT(::IsMenu(hSubMenu));
+ return ::ModifyMenu(m_hMenu, nPosition, nFlags | MF_POPUP, (UINT_PTR)hSubMenu, lpszNewItem);
+ }
+
+ BOOL ModifyMenu(UINT nPosition, UINT nFlags, UINT_PTR nIDNewItem, HBITMAP hBmp)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::ModifyMenu(m_hMenu, nPosition, nFlags | MF_BITMAP, nIDNewItem, (LPCTSTR)hBmp);
+ }
+
+ BOOL ModifyMenu(UINT nPosition, UINT nFlags, HMENU hSubMenu, HBITMAP hBmp)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ ATLASSERT(::IsMenu(hSubMenu));
+ return ::ModifyMenu(m_hMenu, nPosition, nFlags | (MF_BITMAP | MF_POPUP), (UINT_PTR)hSubMenu, (LPCTSTR)hBmp);
+ }
+
+ BOOL RemoveMenu(UINT nPosition, UINT nFlags)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::RemoveMenu(m_hMenu, nPosition, nFlags);
+ }
+
+ BOOL SetMenuItemBitmaps(UINT nPosition, UINT nFlags, HBITMAP hBmpUnchecked, HBITMAP hBmpChecked)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::SetMenuItemBitmaps(m_hMenu, nPosition, nFlags, hBmpUnchecked, hBmpChecked);
+ }
+
+ BOOL CheckMenuRadioItem(UINT nIDFirst, UINT nIDLast, UINT nIDItem, UINT nFlags)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::CheckMenuRadioItem(m_hMenu, nIDFirst, nIDLast, nIDItem, nFlags);
+ }
+
+ BOOL GetMenuItemInfo(UINT uItem, BOOL bByPosition, LPMENUITEMINFO lpmii) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return (BOOL)::GetMenuItemInfo(m_hMenu, uItem, bByPosition, lpmii);
+ }
+
+ BOOL SetMenuItemInfo(UINT uItem, BOOL bByPosition, LPMENUITEMINFO lpmii)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return (BOOL)::SetMenuItemInfo(m_hMenu, uItem, bByPosition, lpmii);
+ }
+
+ BOOL InsertMenuItem(UINT uItem, BOOL bByPosition, LPMENUITEMINFO lpmii)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return (BOOL)::InsertMenuItem(m_hMenu, uItem, bByPosition, lpmii);
+ }
+
+ UINT GetMenuDefaultItem(BOOL bByPosition = FALSE, UINT uFlags = 0U) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::GetMenuDefaultItem(m_hMenu, (UINT)bByPosition, uFlags);
+ }
+
+ BOOL SetMenuDefaultItem(UINT uItem = (UINT)-1, BOOL bByPosition = FALSE)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::SetMenuDefaultItem(m_hMenu, uItem, (UINT)bByPosition);
+ }
+
+ BOOL GetMenuItemRect(HWND hWnd, UINT uItem, LPRECT lprcItem) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::GetMenuItemRect(hWnd, m_hMenu, uItem, lprcItem);
+ }
+
+ int MenuItemFromPoint(HWND hWnd, POINT point) const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::MenuItemFromPoint(hWnd, m_hMenu, point);
+ }
+
+// Context Help Functions
+ BOOL SetMenuContextHelpId(DWORD dwContextHelpId)
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::SetMenuContextHelpId(m_hMenu, dwContextHelpId);
+ }
+
+ DWORD GetMenuContextHelpId() const
+ {
+ ATLASSERT(::IsMenu(m_hMenu));
+ return ::GetMenuContextHelpId(m_hMenu);
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CAccelerator
+
+template <bool t_bManaged>
+class CAcceleratorT
+{
+public:
+ HACCEL m_hAccel;
+
+// Constructor/destructor/operators
+ CAcceleratorT(HACCEL hAccel = NULL) : m_hAccel(hAccel)
+ { }
+
+ ~CAcceleratorT()
+ {
+ if(t_bManaged && (m_hAccel != NULL))
+ ::DestroyAcceleratorTable(m_hAccel);
+ }
+
+ CAcceleratorT<t_bManaged>& operator =(HACCEL hAccel)
+ {
+ Attach(hAccel);
+ return *this;
+ }
+
+ void Attach(HACCEL hAccel)
+ {
+ if(t_bManaged && (m_hAccel != NULL))
+ ::DestroyAcceleratorTable(m_hAccel);
+ m_hAccel = hAccel;
+ }
+
+ HACCEL Detach()
+ {
+ HACCEL hAccel = m_hAccel;
+ m_hAccel = NULL;
+ return hAccel;
+ }
+
+ operator HACCEL() const { return m_hAccel; }
+
+ bool IsNull() const { return m_hAccel == NULL; }
+
+// Create/destroy methods
+ HACCEL LoadAccelerators(ATL::_U_STRINGorID accel)
+ {
+ ATLASSERT(m_hAccel == NULL);
+ m_hAccel = ::LoadAccelerators(ModuleHelper::GetResourceInstance(), accel.m_lpstr);
+ return m_hAccel;
+ }
+
+ HACCEL CreateAcceleratorTable(LPACCEL pAccel, int cEntries)
+ {
+ ATLASSERT(m_hAccel == NULL);
+ ATLASSERT(pAccel != NULL);
+ m_hAccel = ::CreateAcceleratorTable(pAccel, cEntries);
+ return m_hAccel;
+ }
+
+ void DestroyObject()
+ {
+ if(m_hAccel != NULL)
+ {
+ ::DestroyAcceleratorTable(m_hAccel);
+ m_hAccel = NULL;
+ }
+ }
+
+// Operations
+ int CopyAcceleratorTable(LPACCEL lpAccelDst, int cEntries)
+ {
+ ATLASSERT(m_hAccel != NULL);
+ ATLASSERT(lpAccelDst != NULL);
+ return ::CopyAcceleratorTable(m_hAccel, lpAccelDst, cEntries);
+ }
+
+ int GetEntriesCount() const
+ {
+ ATLASSERT(m_hAccel != NULL);
+ return ::CopyAcceleratorTable(m_hAccel, NULL, 0);
+ }
+
+ BOOL TranslateAccelerator(HWND hWnd, LPMSG pMsg)
+ {
+ ATLASSERT(m_hAccel != NULL);
+ ATLASSERT(::IsWindow(hWnd));
+ ATLASSERT(pMsg != NULL);
+ return ::TranslateAccelerator(hWnd, m_hAccel, pMsg);
+ }
+};
+
+typedef CAcceleratorT<false> CAcceleratorHandle;
+typedef CAcceleratorT<true> CAccelerator;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CIcon
+
+template <bool t_bManaged>
+class CIconT
+{
+public:
+ HICON m_hIcon;
+
+// Constructor/destructor/operators
+ CIconT(HICON hIcon = NULL) : m_hIcon(hIcon)
+ { }
+
+ ~CIconT()
+ {
+ if(t_bManaged && (m_hIcon != NULL))
+ ::DestroyIcon(m_hIcon);
+ }
+
+ CIconT<t_bManaged>& operator =(HICON hIcon)
+ {
+ Attach(hIcon);
+ return *this;
+ }
+
+ void Attach(HICON hIcon)
+ {
+ if(t_bManaged && (m_hIcon != NULL))
+ ::DestroyIcon(m_hIcon);
+ m_hIcon = hIcon;
+ }
+
+ HICON Detach()
+ {
+ HICON hIcon = m_hIcon;
+ m_hIcon = NULL;
+ return hIcon;
+ }
+
+ operator HICON() const { return m_hIcon; }
+
+ bool IsNull() const { return m_hIcon == NULL; }
+
+// Create/destroy methods
+ HICON LoadIcon(ATL::_U_STRINGorID icon)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ m_hIcon = ::LoadIcon(ModuleHelper::GetResourceInstance(), icon.m_lpstr);
+ return m_hIcon;
+ }
+
+ HICON LoadIcon(ATL::_U_STRINGorID icon, int cxDesired, int cyDesired, UINT fuLoad = 0)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ m_hIcon = (HICON) ::LoadImage(ModuleHelper::GetResourceInstance(), icon.m_lpstr, IMAGE_ICON, cxDesired, cyDesired, fuLoad);
+ return m_hIcon;
+ }
+
+ HICON LoadOEMIcon(LPCTSTR lpstrIconName)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ ATLASSERT(IsOEMIcon(lpstrIconName));
+ m_hIcon = ::LoadIcon(NULL, lpstrIconName);
+ return m_hIcon;
+ }
+
+ HICON CreateIcon(int nWidth, int nHeight, BYTE cPlanes, BYTE cBitsPixel, CONST BYTE* lpbANDbits, CONST BYTE *lpbXORbits)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ ATLASSERT(lpbANDbits != NULL);
+ ATLASSERT(lpbXORbits != NULL);
+ m_hIcon = ::CreateIcon(ModuleHelper::GetResourceInstance(), nWidth, nHeight, cPlanes, cBitsPixel, lpbANDbits, lpbXORbits);
+ return m_hIcon;
+ }
+
+ HICON CreateIconFromResource(PBYTE pBits, DWORD dwResSize, DWORD dwVersion = 0x00030000)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ ATLASSERT(pBits != NULL);
+ m_hIcon = ::CreateIconFromResource(pBits, dwResSize, TRUE, dwVersion);
+ return m_hIcon;
+ }
+
+ HICON CreateIconFromResourceEx(PBYTE pbBits, DWORD cbBits, DWORD dwVersion = 0x00030000, int cxDesired = 0, int cyDesired = 0, UINT uFlags = LR_DEFAULTCOLOR)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ ATLASSERT(pbBits != NULL);
+ ATLASSERT(cbBits > 0);
+ m_hIcon = ::CreateIconFromResourceEx(pbBits, cbBits, TRUE, dwVersion, cxDesired, cyDesired, uFlags);
+ return m_hIcon;
+ }
+
+ HICON CreateIconIndirect(PICONINFO pIconInfo)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ ATLASSERT(pIconInfo != NULL);
+ m_hIcon = ::CreateIconIndirect(pIconInfo);
+ return m_hIcon;
+ }
+
+ HICON ExtractIcon(LPCTSTR lpszExeFileName, UINT nIconIndex)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ ATLASSERT(lpszExeFileName != NULL);
+ m_hIcon = ::ExtractIcon(ModuleHelper::GetModuleInstance(), lpszExeFileName, nIconIndex);
+ return m_hIcon;
+ }
+
+ HICON ExtractAssociatedIcon(HINSTANCE hInst, LPTSTR lpIconPath, LPWORD lpiIcon)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ ATLASSERT(lpIconPath != NULL);
+ ATLASSERT(lpiIcon != NULL);
+ m_hIcon = ::ExtractAssociatedIcon(hInst, lpIconPath, lpiIcon);
+ return m_hIcon;
+ }
+
+ BOOL DestroyIcon()
+ {
+ ATLASSERT(m_hIcon != NULL);
+ BOOL bRet = ::DestroyIcon(m_hIcon);
+ if(bRet != FALSE)
+ m_hIcon = NULL;
+ return bRet;
+ }
+
+// Operations
+ HICON CopyIcon()
+ {
+ ATLASSERT(m_hIcon != NULL);
+ return ::CopyIcon(m_hIcon);
+ }
+
+ HICON DuplicateIcon()
+ {
+ ATLASSERT(m_hIcon != NULL);
+ return ::DuplicateIcon(NULL, m_hIcon);
+ }
+
+ BOOL DrawIcon(HDC hDC, int x, int y)
+ {
+ ATLASSERT(m_hIcon != NULL);
+ return ::DrawIcon(hDC, x, y, m_hIcon);
+ }
+
+ BOOL DrawIcon(HDC hDC, POINT pt)
+ {
+ ATLASSERT(m_hIcon != NULL);
+ return ::DrawIcon(hDC, pt.x, pt.y, m_hIcon);
+ }
+
+ BOOL DrawIconEx(HDC hDC, int x, int y, int cxWidth, int cyWidth, UINT uStepIfAniCur = 0, HBRUSH hbrFlickerFreeDraw = NULL, UINT uFlags = DI_NORMAL)
+ {
+ ATLASSERT(m_hIcon != NULL);
+ return ::DrawIconEx(hDC, x, y, m_hIcon, cxWidth, cyWidth, uStepIfAniCur, hbrFlickerFreeDraw, uFlags);
+ }
+
+ BOOL DrawIconEx(HDC hDC, POINT pt, SIZE size, UINT uStepIfAniCur = 0, HBRUSH hbrFlickerFreeDraw = NULL, UINT uFlags = DI_NORMAL)
+ {
+ ATLASSERT(m_hIcon != NULL);
+ return ::DrawIconEx(hDC, pt.x, pt.y, m_hIcon, size.cx, size.cy, uStepIfAniCur, hbrFlickerFreeDraw, uFlags);
+ }
+
+ BOOL GetIconInfo(PICONINFO pIconInfo) const
+ {
+ ATLASSERT(m_hIcon != NULL);
+ ATLASSERT(pIconInfo != NULL);
+ return ::GetIconInfo(m_hIcon, pIconInfo);
+ }
+
+#if (_WIN32_WINNT >= 0x0600)
+ BOOL GetIconInfoEx(PICONINFOEX pIconInfo) const
+ {
+ ATLASSERT(m_hIcon != NULL);
+ ATLASSERT(pIconInfo != NULL);
+ return ::GetIconInfoEx(m_hIcon, pIconInfo);
+ }
+#endif // (_WIN32_WINNT >= 0x0600)
+
+#if defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN)
+ HRESULT LoadIconMetric(ATL::_U_STRINGorID icon, int lims)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ USES_CONVERSION;
+ return ::LoadIconMetric(ModuleHelper::GetResourceInstance(), T2CW(icon.m_lpstr), lims, &m_hIcon);
+ }
+
+ HRESULT LoadIconWithScaleDown(ATL::_U_STRINGorID icon, int cx, int cy)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ USES_CONVERSION;
+ return ::LoadIconWithScaleDown(ModuleHelper::GetResourceInstance(), T2CW(icon.m_lpstr), cx, cy, &m_hIcon);
+ }
+
+ HRESULT LoadOEMIconMetric(LPCTSTR lpstrIconName, int lims)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ ATLASSERT(IsOEMIcon(lpstrIconName));
+ return ::LoadIconMetric(NULL, (LPCWSTR)lpstrIconName, lims, &m_hIcon);
+ }
+
+ HRESULT LoadOEMIconWithScaleDown(LPCTSTR lpstrIconName, int cx, int cy)
+ {
+ ATLASSERT(m_hIcon == NULL);
+ ATLASSERT(IsOEMIcon(lpstrIconName));
+ USES_CONVERSION;
+ return ::LoadIconWithScaleDown(NULL, (LPCWSTR)lpstrIconName, cx, cy, &m_hIcon);
+ }
+#endif // defined(NTDDI_VERSION) && (NTDDI_VERSION >= NTDDI_LONGHORN)
+
+ // Helper
+ static bool IsOEMIcon(LPCTSTR lpstrIconName)
+ {
+#if (WINVER >= 0x0600)
+ return ((lpstrIconName == IDI_APPLICATION) || (lpstrIconName == IDI_ASTERISK) || (lpstrIconName == IDI_EXCLAMATION) ||
+ (lpstrIconName == IDI_HAND) || (lpstrIconName == IDI_QUESTION) || (lpstrIconName == IDI_WINLOGO) ||
+ (lpstrIconName == IDI_SHIELD));
+#else // !(WINVER >= 0x0600)
+ return ((lpstrIconName == IDI_APPLICATION) || (lpstrIconName == IDI_ASTERISK) || (lpstrIconName == IDI_EXCLAMATION) ||
+ (lpstrIconName == IDI_HAND) || (lpstrIconName == IDI_QUESTION) || (lpstrIconName == IDI_WINLOGO));
+#endif // !(WINVER >= 0x0600)
+ }
+};
+
+typedef CIconT<false> CIconHandle;
+typedef CIconT<true> CIcon;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CCursor
+
+// protect template member from a winuser.h macro
+#ifdef CopyCursor
+ #undef CopyCursor
+#endif
+
+template <bool t_bManaged>
+class CCursorT
+{
+public:
+ HCURSOR m_hCursor;
+
+// Constructor/destructor/operators
+ CCursorT(HCURSOR hCursor = NULL) : m_hCursor(hCursor)
+ { }
+
+ ~CCursorT()
+ {
+ if(t_bManaged && (m_hCursor != NULL))
+ DestroyCursor();
+ }
+
+ CCursorT<t_bManaged>& operator =(HCURSOR hCursor)
+ {
+ Attach(hCursor);
+ return *this;
+ }
+
+ void Attach(HCURSOR hCursor)
+ {
+ if(t_bManaged && (m_hCursor != NULL))
+ DestroyCursor();
+ m_hCursor = hCursor;
+ }
+
+ HCURSOR Detach()
+ {
+ HCURSOR hCursor = m_hCursor;
+ m_hCursor = NULL;
+ return hCursor;
+ }
+
+ operator HCURSOR() const { return m_hCursor; }
+
+ bool IsNull() const { return m_hCursor == NULL; }
+
+// Create/destroy methods
+ HCURSOR LoadCursor(ATL::_U_STRINGorID cursor)
+ {
+ ATLASSERT(m_hCursor == NULL);
+ m_hCursor = ::LoadCursor(ModuleHelper::GetResourceInstance(), cursor.m_lpstr);
+ return m_hCursor;
+ }
+
+ HCURSOR LoadSysCursor(LPCTSTR lpstrCursorName)
+ {
+ ATLASSERT(m_hCursor == NULL);
+ ATLASSERT((lpstrCursorName == IDC_ARROW) || (lpstrCursorName == IDC_IBEAM) || (lpstrCursorName == IDC_WAIT) ||
+ (lpstrCursorName == IDC_CROSS) || (lpstrCursorName == IDC_UPARROW) || (lpstrCursorName == IDC_SIZE) ||
+ (lpstrCursorName == IDC_ICON) || (lpstrCursorName == IDC_SIZENWSE) || (lpstrCursorName == IDC_SIZENESW) ||
+ (lpstrCursorName == IDC_SIZEWE) || (lpstrCursorName == IDC_SIZENS) || (lpstrCursorName == IDC_SIZEALL) ||
+ (lpstrCursorName == IDC_NO) || (lpstrCursorName == IDC_APPSTARTING) || (lpstrCursorName == IDC_HELP) ||
+ (lpstrCursorName == IDC_HAND));
+ m_hCursor = ::LoadCursor(NULL, lpstrCursorName);
+ return m_hCursor;
+ }
+
+ // deprecated
+ HCURSOR LoadOEMCursor(LPCTSTR lpstrCursorName)
+ {
+ return LoadSysCursor(lpstrCursorName);
+ }
+
+ HCURSOR LoadCursor(ATL::_U_STRINGorID cursor, int cxDesired, int cyDesired, UINT fuLoad = 0)
+ {
+ ATLASSERT(m_hCursor == NULL);
+ m_hCursor = (HCURSOR) ::LoadImage(ModuleHelper::GetResourceInstance(), cursor.m_lpstr, IMAGE_CURSOR, cxDesired, cyDesired, fuLoad);
+ return m_hCursor;
+ }
+
+ HCURSOR LoadCursorFromFile(LPCTSTR pstrFilename)
+ {
+ ATLASSERT(m_hCursor == NULL);
+ ATLASSERT(pstrFilename != NULL);
+ m_hCursor = ::LoadCursorFromFile(pstrFilename);
+ return m_hCursor;
+ }
+
+ HCURSOR CreateCursor(int xHotSpot, int yHotSpot, int nWidth, int nHeight, CONST VOID *pvANDPlane, CONST VOID *pvXORPlane)
+ {
+ ATLASSERT(m_hCursor == NULL);
+ m_hCursor = ::CreateCursor(ModuleHelper::GetResourceInstance(), xHotSpot, yHotSpot, nWidth, nHeight, pvANDPlane, pvXORPlane);
+ return m_hCursor;
+ }
+
+ HCURSOR CreateCursorFromResource(PBYTE pBits, DWORD dwResSize, DWORD dwVersion = 0x00030000)
+ {
+ ATLASSERT(m_hCursor == NULL);
+ ATLASSERT(pBits != NULL);
+ m_hCursor = (HCURSOR)::CreateIconFromResource(pBits, dwResSize, FALSE, dwVersion);
+ return m_hCursor;
+ }
+
+ HCURSOR CreateCursorFromResourceEx(PBYTE pbBits, DWORD cbBits, DWORD dwVersion = 0x00030000, int cxDesired = 0, int cyDesired = 0, UINT uFlags = LR_DEFAULTCOLOR)
+ {
+ ATLASSERT(m_hCursor == NULL);
+ ATLASSERT(pbBits != NULL);
+ ATLASSERT(cbBits > 0);
+ m_hCursor = (HCURSOR)::CreateIconFromResourceEx(pbBits, cbBits, FALSE, dwVersion, cxDesired, cyDesired, uFlags);
+ return m_hCursor;
+ }
+
+ BOOL DestroyCursor()
+ {
+ ATLASSERT(m_hCursor != NULL);
+ BOOL bRet = ::DestroyCursor(m_hCursor);
+ if(bRet != FALSE)
+ m_hCursor = NULL;
+ return bRet;
+ }
+
+// Operations
+ HCURSOR CopyCursor()
+ {
+ ATLASSERT(m_hCursor != NULL);
+ return (HCURSOR)::CopyIcon((HICON)m_hCursor);
+ }
+
+ BOOL GetCursorInfo(LPCURSORINFO pCursorInfo)
+ {
+ ATLASSERT(m_hCursor != NULL);
+ ATLASSERT(pCursorInfo != NULL);
+ return ::GetCursorInfo(pCursorInfo);
+ }
+};
+
+typedef CCursorT<false> CCursorHandle;
+typedef CCursorT<true> CCursor;
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CResource - Wraps a generic Windows resource.
+// Use it with custom resource types other than the
+// standard RT_CURSOR, RT_BITMAP, etc.
+
+class CResource
+{
+public:
+ HGLOBAL m_hGlobal;
+ HRSRC m_hResource;
+
+// Constructor/destructor
+ CResource() : m_hGlobal(NULL), m_hResource(NULL)
+ { }
+
+ ~CResource()
+ {
+ Release();
+ }
+
+// Load methods
+ bool Load(ATL::_U_STRINGorID Type, ATL::_U_STRINGorID ID)
+ {
+ ATLASSERT(m_hResource == NULL);
+ ATLASSERT(m_hGlobal == NULL);
+
+ m_hResource = ::FindResource(ModuleHelper::GetResourceInstance(), ID.m_lpstr, Type.m_lpstr);
+ if(m_hResource == NULL)
+ return false;
+
+ m_hGlobal = ::LoadResource(ModuleHelper::GetResourceInstance(), m_hResource);
+ if(m_hGlobal == NULL)
+ {
+ m_hResource = NULL;
+ return false;
+ }
+
+ return true;
+ }
+
+ bool LoadEx(ATL::_U_STRINGorID ID, ATL::_U_STRINGorID Type, WORD wLanguage)
+ {
+ ATLASSERT(m_hResource == NULL);
+ ATLASSERT(m_hGlobal == NULL);
+
+ m_hResource = ::FindResourceEx(ModuleHelper::GetResourceInstance(), Type.m_lpstr, ID.m_lpstr, wLanguage);
+ if(m_hResource == NULL)
+ return false;
+
+ m_hGlobal = ::LoadResource(ModuleHelper::GetResourceInstance(), m_hResource);
+ if(m_hGlobal == NULL)
+ {
+ m_hResource = NULL;
+ return false;
+ }
+
+ return true;
+ }
+
+// Misc. operations
+ DWORD GetSize() const
+ {
+ ATLASSERT(m_hResource != NULL);
+ return ::SizeofResource(ModuleHelper::GetResourceInstance(), m_hResource);
+ }
+
+ LPVOID Lock()
+ {
+ ATLASSERT(m_hResource != NULL);
+ ATLASSERT(m_hGlobal != NULL);
+ LPVOID pVoid = ::LockResource(m_hGlobal);
+ ATLASSERT(pVoid != NULL);
+ return pVoid;
+ }
+
+ void Release()
+ {
+ if(m_hGlobal != NULL)
+ {
+ FreeResource(m_hGlobal);
+ m_hGlobal = NULL;
+ m_hResource = NULL;
+ }
+ }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Toolbar resource descriptor
+
+struct _AtlToolBarData
+{
+ WORD wVersion;
+ WORD wWidth;
+ WORD wHeight;
+ WORD wItemCount;
+
+ WORD* items()
+ { return (WORD*)(this+1); }
+};
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Global functions for loading resources
+
+inline HACCEL AtlLoadAccelerators(ATL::_U_STRINGorID table)
+{
+ return ::LoadAccelerators(ModuleHelper::GetResourceInstance(), table.m_lpstr);
+}
+
+inline HMENU AtlLoadMenu(ATL::_U_STRINGorID menu)
+{
+ return ::LoadMenu(ModuleHelper::GetResourceInstance(), menu.m_lpstr);
+}
+
+inline HBITMAP AtlLoadBitmap(ATL::_U_STRINGorID bitmap)
+{
+ return ::LoadBitmap(ModuleHelper::GetResourceInstance(), bitmap.m_lpstr);
+}
+
+#ifdef OEMRESOURCE
+inline HBITMAP AtlLoadSysBitmap(ATL::_U_STRINGorID bitmap)
+{
+#ifdef _DEBUG
+ WORD wID = LOWORD(bitmap.m_lpstr);
+ ATLASSERT((wID >= 32734) && (wID <= 32767));
+#endif // _DEBUG
+ return ::LoadBitmap(NULL, bitmap.m_lpstr);
+}
+#endif // OEMRESOURCE
+
+inline HCURSOR AtlLoadCursor(ATL::_U_STRINGorID cursor)
+{
+ return ::LoadCursor(ModuleHelper::GetResourceInstance(), cursor.m_lpstr);
+}
+
+inline HCURSOR AtlLoadSysCursor(LPCTSTR lpCursorName)
+{
+ ATLASSERT((lpCursorName == IDC_ARROW) || (lpCursorName == IDC_IBEAM) || (lpCursorName == IDC_WAIT) ||
+ (lpCursorName == IDC_CROSS) || (lpCursorName == IDC_UPARROW) || (lpCursorName == IDC_SIZE) ||
+ (lpCursorName == IDC_ICON) || (lpCursorName == IDC_SIZENWSE) || (lpCursorName == IDC_SIZENESW) ||
+ (lpCursorName == IDC_SIZEWE) || (lpCursorName == IDC_SIZENS) || (lpCursorName == IDC_SIZEALL) ||
+ (lpCursorName == IDC_NO) || (lpCursorName == IDC_APPSTARTING) || (lpCursorName == IDC_HELP) ||
+ (lpCursorName == IDC_HAND));
+ return ::LoadCursor(NULL, lpCursorName);
+}
+
+inline HICON AtlLoadIcon(ATL::_U_STRINGorID icon)
+{
+ return ::LoadIcon(ModuleHelper::GetResourceInstance(), icon.m_lpstr);
+}
+
+inline HICON AtlLoadSysIcon(LPCTSTR lpIconName)
+{
+#if (WINVER >= 0x0600)
+ ATLASSERT((lpIconName == IDI_APPLICATION) || (lpIconName == IDI_ASTERISK) || (lpIconName == IDI_EXCLAMATION) ||
+ (lpIconName == IDI_HAND) || (lpIconName == IDI_QUESTION) || (lpIconName == IDI_WINLOGO) ||
+ (lpIconName == IDI_SHIELD));
+#else // !(WINVER >= 0x0600)
+ ATLASSERT((lpIconName == IDI_APPLICATION) || (lpIconName == IDI_ASTERISK) || (lpIconName == IDI_EXCLAMATION) ||
+ (lpIconName == IDI_HAND) || (lpIconName == IDI_QUESTION) || (lpIconName == IDI_WINLOGO));
+#endif // !(WINVER >= 0x0600)
+ return ::LoadIcon(NULL, lpIconName);
+}
+
+inline HBITMAP AtlLoadBitmapImage(ATL::_U_STRINGorID bitmap, UINT fuLoad = LR_DEFAULTCOLOR)
+{
+ return (HBITMAP)::LoadImage(ModuleHelper::GetResourceInstance(), bitmap.m_lpstr, IMAGE_BITMAP, 0, 0, fuLoad);
+}
+
+inline HCURSOR AtlLoadCursorImage(ATL::_U_STRINGorID cursor, UINT fuLoad = LR_DEFAULTCOLOR | LR_DEFAULTSIZE, int cxDesired = 0, int cyDesired = 0)
+{
+ return (HCURSOR)::LoadImage(ModuleHelper::GetResourceInstance(), cursor.m_lpstr, IMAGE_CURSOR, cxDesired, cyDesired, fuLoad);
+}
+
+inline HICON AtlLoadIconImage(ATL::_U_STRINGorID icon, UINT fuLoad = LR_DEFAULTCOLOR | LR_DEFAULTSIZE, int cxDesired = 0, int cyDesired = 0)
+{
+ return (HICON)::LoadImage(ModuleHelper::GetResourceInstance(), icon.m_lpstr, IMAGE_ICON, cxDesired, cyDesired, fuLoad);
+}
+
+#ifdef OEMRESOURCE
+inline HBITMAP AtlLoadSysBitmapImage(WORD wBitmapID, UINT fuLoad = LR_DEFAULTCOLOR)
+{
+ ATLASSERT((wBitmapID >= 32734) && (wBitmapID <= 32767));
+ ATLASSERT((fuLoad & LR_LOADFROMFILE) == 0); // this one doesn't load from a file
+ return (HBITMAP)::LoadImage(NULL, MAKEINTRESOURCE(wBitmapID), IMAGE_BITMAP, 0, 0, fuLoad);
+}
+#endif // OEMRESOURCE
+
+inline HCURSOR AtlLoadSysCursorImage(ATL::_U_STRINGorID cursor, UINT fuLoad = LR_DEFAULTCOLOR | LR_DEFAULTSIZE, int cxDesired = 0, int cyDesired = 0)
+{
+#ifdef _DEBUG
+ WORD wID = LOWORD(cursor.m_lpstr);
+ ATLASSERT(((wID >= 32512) && (wID <= 32516)) || ((wID >= 32640) && (wID <= 32648)) || (wID == 32650) || (wID == 32651));
+ ATLASSERT((fuLoad & LR_LOADFROMFILE) == 0); // this one doesn't load from a file
+#endif // _DEBUG
+ return (HCURSOR)::LoadImage(NULL, cursor.m_lpstr, IMAGE_CURSOR, cxDesired, cyDesired, fuLoad);
+}
+
+inline HICON AtlLoadSysIconImage(ATL::_U_STRINGorID icon, UINT fuLoad = LR_DEFAULTCOLOR | LR_DEFAULTSIZE, int cxDesired = 0, int cyDesired = 0)
+{
+#ifdef _DEBUG
+ WORD wID = LOWORD(icon.m_lpstr);
+ ATLASSERT((wID >= 32512) && (wID <= 32517));
+ ATLASSERT((fuLoad & LR_LOADFROMFILE) == 0); // this one doesn't load from a file
+#endif // _DEBUG
+ return (HICON)::LoadImage(NULL, icon.m_lpstr, IMAGE_ICON, cxDesired, cyDesired, fuLoad);
+}
+
+inline bool AtlLoadString(UINT uID, BSTR& bstrText)
+{
+ USES_CONVERSION;
+ ATLASSERT(bstrText == NULL);
+
+ LPTSTR lpstrText = NULL;
+ int nRes = 0;
+ for(int nLen = 256; ; nLen *= 2)
+ {
+ ATLTRY(lpstrText = new TCHAR[nLen]);
+ if(lpstrText == NULL)
+ break;
+ nRes = ::LoadString(ModuleHelper::GetResourceInstance(), uID, lpstrText, nLen);
+ if(nRes < nLen - 1)
+ break;
+ delete [] lpstrText;
+ lpstrText = NULL;
+ }
+
+ if(lpstrText != NULL)
+ {
+ if(nRes != 0)
+ bstrText = ::SysAllocString(T2OLE(lpstrText));
+ delete [] lpstrText;
+ }
+
+ return (bstrText != NULL) ? true : false;
+}
+
+} // namespace WTL
+
+#endif // __ATLUSER_H__
diff --git a/Examples/WhisperDesktop/Utils/WTL/atlwinx.h b/Examples/WhisperDesktop/Utils/WTL/atlwinx.h
new file mode 100644
index 0000000..b89c513
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/WTL/atlwinx.h
@@ -0,0 +1,623 @@
+// Windows Template Library - WTL version 10.0
+// Copyright (C) Microsoft Corporation, WTL Team. All rights reserved.
+//
+// This file is a part of the Windows Template Library.
+// The use and distribution terms for this software are covered by the
+// Microsoft Public License (http://opensource.org/licenses/MS-PL)
+// which can be found in the file MS-PL.txt at the root folder.
+
+#ifndef __ATLWINX_H__
+#define __ATLWINX_H__
+
+#pragma once
+
+#ifndef __ATLAPP_H__
+ #error atlwinx.h requires atlapp.h to be included first
+#endif
+
+#include <atlwin.h>
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Classes in this file:
+//
+// CWindowEx
+
+
+/////////////////////////////////////////////////////////////////////////////
+// Additional macros needed for template classes
+
+#ifndef DECLARE_WND_CLASS_EX2
+ #define DECLARE_WND_CLASS_EX2(WndClassName, EnclosingClass, style, bkgnd) \
+ static ATL::CWndClassInfo& GetWndClassInfo() \
+ { \
+ static ATL::CWndClassInfo wc = \
+ { \
+ { sizeof(WNDCLASSEX), style, EnclosingClass::StartWindowProc, \
+ 0, 0, NULL, NULL, NULL, (HBRUSH)(bkgnd + 1), NULL, WndClassName, NULL }, \
+ NULL, NULL, IDC_ARROW, TRUE, 0, _T("") \
+ }; \
+ return wc; \
+ }
+#endif // DECLARE_WND_CLASS_EX2
+
+#ifndef DECLARE_WND_SUPERCLASS2
+ #define DECLARE_WND_SUPERCLASS2(WndClassName, EnclosingClass, OrigWndClassName) \
+ static ATL::CWndClassInfo& GetWndClassInfo() \
+ { \
+ static ATL::CWndClassInfo wc = \
+ { \
+ { sizeof(WNDCLASSEX), 0, EnclosingClass::StartWindowProc, \
+ 0, 0, NULL, NULL, NULL, NULL, NULL, WndClassName, NULL }, \
+ OrigWndClassName, NULL, NULL, TRUE, 0, _T("") \
+ }; \
+ return wc; \
+ }
+#endif // DECLARE_WND_SUPERCLASS2
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Command Chaining Macros
+
+#define CHAIN_COMMANDS(theChainClass) \
+ if(uMsg == WM_COMMAND) \
+ CHAIN_MSG_MAP(theChainClass)
+
+#define CHAIN_COMMANDS_ALT(theChainClass, msgMapID) \
+ if(uMsg == WM_COMMAND) \
+ CHAIN_MSG_MAP_ALT(theChainClass, msgMapID)
+
+#define CHAIN_COMMANDS_MEMBER(theChainMember) \
+ if(uMsg == WM_COMMAND) \
+ CHAIN_MSG_MAP_MEMBER(theChainMember)
+
+#define CHAIN_COMMANDS_ALT_MEMBER(theChainMember, msgMapID) \
+ if(uMsg == WM_COMMAND) \
+ CHAIN_MSG_MAP_ALT_MEMBER(theChainMember, msgMapID)
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Macros for parent message map to selectively reflect control messages
+
+// NOTE: ReflectNotifications is a member of ATL's CWindowImplRoot
+// (and overridden in 2 cases - CContainedWindowT and CAxHostWindow)
+// Since we can't modify ATL, we'll provide the needed additions
+// in a separate function (that is not a member of CWindowImplRoot)
+
+namespace WTL
+{
+
+inline LRESULT WtlReflectNotificationsFiltered(HWND hWndParent, UINT uMsg, WPARAM wParam, LPARAM lParam, BOOL& bHandled,
+ UINT uMsgFilter = WM_NULL, UINT_PTR idFromFilter = 0, HWND hWndChildFilter = NULL)
+{
+ if((uMsgFilter != WM_NULL) && (uMsgFilter != uMsg))
+ {
+ // The notification message doesn't match the filter.
+ bHandled = FALSE;
+ return 1;
+ }
+
+ HWND hWndChild = NULL;
+ UINT_PTR idFrom = 0;
+
+ switch(uMsg)
+ {
+ case WM_COMMAND:
+ if(lParam != NULL) // not from a menu
+ {
+ hWndChild = (HWND)lParam;
+ idFrom = (UINT_PTR)LOWORD(wParam);
+ }
+ break;
+ case WM_NOTIFY:
+ hWndChild = ((LPNMHDR)lParam)->hwndFrom;
+ idFrom = ((LPNMHDR)lParam)->idFrom;
+ break;
+ case WM_PARENTNOTIFY:
+ switch(LOWORD(wParam))
+ {
+ case WM_CREATE:
+ case WM_DESTROY:
+ hWndChild = (HWND)lParam;
+ idFrom = (UINT_PTR)HIWORD(wParam);
+ break;
+ default:
+ hWndChild = ::GetDlgItem(hWndParent, HIWORD(wParam));
+ idFrom = (UINT_PTR)::GetDlgCtrlID(hWndChild);
+ break;
+ }
+ break;
+ case WM_DRAWITEM:
+ if(wParam) // not from a menu
+ {
+ hWndChild = ((LPDRAWITEMSTRUCT)lParam)->hwndItem;
+ idFrom = (UINT_PTR)wParam;
+ }
+ break;
+ case WM_MEASUREITEM:
+ if(wParam) // not from a menu
+ {
+ hWndChild = ::GetDlgItem(hWndParent, ((LPMEASUREITEMSTRUCT)lParam)->CtlID);
+ idFrom = (UINT_PTR)wParam;
+ }
+ break;
+ case WM_COMPAREITEM:
+ if(wParam) // not from a menu
+ {
+ hWndChild = ((LPCOMPAREITEMSTRUCT)lParam)->hwndItem;
+ idFrom = (UINT_PTR)wParam;
+ }
+ break;
+ case WM_DELETEITEM:
+ if(wParam) // not from a menu
+ {
+ hWndChild = ((LPDELETEITEMSTRUCT)lParam)->hwndItem;
+ idFrom = (UINT_PTR)wParam;
+ }
+ break;
+ case WM_VKEYTOITEM:
+ case WM_CHARTOITEM:
+ case WM_HSCROLL:
+ case WM_VSCROLL:
+ case WM_CTLCOLORBTN:
+ case WM_CTLCOLORDLG:
+ case WM_CTLCOLOREDIT:
+ case WM_CTLCOLORLISTBOX:
+ case WM_CTLCOLORMSGBOX:
+ case WM_CTLCOLORSCROLLBAR:
+ case WM_CTLCOLORSTATIC:
+ hWndChild = (HWND)lParam;
+ idFrom = (UINT_PTR)::GetDlgCtrlID(hWndChild);
+ break;
+ default:
+ break;
+ }
+
+ if((hWndChild == NULL) ||
+ ((hWndChildFilter != NULL) && (hWndChildFilter != hWndChild)))
+ {
+ // Either hWndChild isn't valid, or
+ // hWndChild doesn't match the filter.
+ bHandled = FALSE;
+ return 1;
+ }
+
+ if((idFromFilter != 0) && (idFromFilter != idFrom))
+ {
+ // The dialog control id doesn't match the filter.
+ bHandled = FALSE;
+ return 1;
+ }
+
+ ATLASSERT(::IsWindow(hWndChild));
+ LRESULT lResult = ::SendMessage(hWndChild, OCM__BASE + uMsg, wParam, lParam);
+ if((lResult == 0) && (uMsg >= WM_CTLCOLORMSGBOX) && (uMsg <= WM_CTLCOLORSTATIC))
+ {
+ // Try to prevent problems with WM_CTLCOLOR* messages when
+ // the message wasn't really handled
+ bHandled = FALSE;
+ }
+
+ return lResult;
+}
+
+} // namespace WTL
+
+// Try to prevent problems with WM_CTLCOLOR* messages when
+// the message wasn't really handled
+#define REFLECT_NOTIFICATIONS_EX() \
+{ \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if((lResult == 0) && (uMsg >= WM_CTLCOLORMSGBOX) && (uMsg <= WM_CTLCOLORSTATIC)) \
+ bHandled = FALSE; \
+ if(bHandled) \
+ return TRUE; \
+}
+
+#define REFLECT_NOTIFICATIONS_MSG_FILTERED(uMsgFilter) \
+ { \
+ bHandled = TRUE; \
+ lResult = WTL::WtlReflectNotificationsFiltered(this->m_hWnd, uMsg, wParam, lParam, bHandled, uMsgFilter, 0, NULL); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_NOTIFICATIONS_ID_FILTERED(idFromFilter) \
+ { \
+ bHandled = TRUE; \
+ lResult = WTL::WtlReflectNotificationsFiltered(this->m_hWnd, uMsg, wParam, lParam, bHandled, WM_NULL, idFromFilter, NULL); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_NOTIFICATIONS_HWND_FILTERED(hWndChildFilter) \
+ { \
+ bHandled = TRUE; \
+ lResult = WTL::WtlReflectNotificationsFiltered(this->m_hWnd, uMsg, wParam, lParam, bHandled, WM_NULL, 0, hWndChildFilter); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_NOTIFICATIONS_MSG_ID_FILTERED(uMsgFilter, idFromFilter) \
+ { \
+ bHandled = TRUE; \
+ lResult = WTL::WtlReflectNotificationsFiltered(this->m_hWnd, uMsg, wParam, lParam, bHandled, uMsgFilter, idFromFilter, NULL); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_NOTIFICATIONS_MSG_HWND_FILTERED(uMsgFilter, hWndChildFilter) \
+ { \
+ bHandled = TRUE; \
+ lResult = WTL::WtlReflectNotificationsFiltered(this->m_hWnd, uMsg, wParam, lParam, bHandled, uMsgFilter, 0, hWndChildFilter); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_COMMAND(id, code) \
+ if((uMsg == WM_COMMAND) && (id == LOWORD(wParam)) && (code == HIWORD(wParam))) \
+ { \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_COMMAND_ID(id) \
+ if((uMsg == WM_COMMAND) && (id == LOWORD(wParam))) \
+ { \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_COMMAND_CODE(code) \
+ if((uMsg == WM_COMMAND) && (code == HIWORD(wParam))) \
+ { \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_COMMAND_RANGE(idFirst, idLast) \
+ if((uMsg == WM_COMMAND) && (LOWORD(wParam) >= idFirst) && (LOWORD(wParam) <= idLast)) \
+ { \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_COMMAND_RANGE_CODE(idFirst, idLast, code) \
+ if((uMsg == WM_COMMAND) && (code == HIWORD(wParam)) && (LOWORD(wParam) >= idFirst) && (LOWORD(wParam) <= idLast)) \
+ { \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_NOTIFY(id, cd) \
+ if((uMsg == WM_NOTIFY) && (id == ((LPNMHDR)lParam)->idFrom) && (cd == ((LPNMHDR)lParam)->code)) \
+ { \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_NOTIFY_ID(id) \
+ if((uMsg == WM_NOTIFY) && (id == ((LPNMHDR)lParam)->idFrom)) \
+ { \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_NOTIFY_CODE(cd) \
+ if((uMsg == WM_NOTIFY) && (cd == ((LPNMHDR)lParam)->code)) \
+ { \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_NOTIFY_RANGE(idFirst, idLast) \
+ if((uMsg == WM_NOTIFY) && (((LPNMHDR)lParam)->idFrom >= idFirst) && (((LPNMHDR)lParam)->idFrom <= idLast)) \
+ { \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+#define REFLECT_NOTIFY_RANGE_CODE(idFirst, idLast, cd) \
+ if((uMsg == WM_NOTIFY) && (cd == ((LPNMHDR)lParam)->code) && (((LPNMHDR)lParam)->idFrom >= idFirst) && (((LPNMHDR)lParam)->idFrom <= idLast)) \
+ { \
+ bHandled = TRUE; \
+ lResult = this->ReflectNotifications(uMsg, wParam, lParam, bHandled); \
+ if(bHandled) \
+ return TRUE; \
+ }
+
+
+///////////////////////////////////////////////////////////////////////////////
+// GetClassLong/SetClassLong redefinition to avoid problems with class members
+
+#ifdef SetClassLongPtrA
+ #undef SetClassLongPtrA
+ inline LONG_PTR SetClassLongPtrA(HWND hWnd, int nIndex, LONG_PTR dwNewLong)
+ {
+ return ::SetClassLongA(hWnd, nIndex, LONG(dwNewLong));
+ }
+#endif
+
+#ifdef SetClassLongPtrW
+ #undef SetClassLongPtrW
+ inline LONG_PTR SetClassLongPtrW(HWND hWnd, int nIndex, LONG_PTR dwNewLong)
+ {
+ return ::SetClassLongW(hWnd, nIndex, LONG(dwNewLong));
+ }
+#endif
+
+#ifdef GetClassLongPtrA
+ #undef GetClassLongPtrA
+ inline LONG_PTR GetClassLongPtrA(HWND hWnd, int nIndex)
+ {
+ return ::GetClassLongA(hWnd, nIndex);
+ }
+#endif
+
+#ifdef GetClassLongPtrW
+ #undef GetClassLongPtrW
+ inline LONG_PTR GetClassLongPtrW(HWND hWnd, int nIndex)
+ {
+ return ::GetClassLongW(hWnd, nIndex);
+ }
+#endif
+
+
+///////////////////////////////////////////////////////////////////////////////
+// CWindowEx - extension of ATL::CWindow
+
+namespace WTL
+{
+
+class CWindowEx : public ATL::CWindow
+{
+public:
+ CWindowEx(HWND hWnd = NULL) : ATL::CWindow(hWnd)
+ { }
+
+ CWindowEx& operator =(HWND hWnd)
+ {
+ m_hWnd = hWnd;
+ return *this;
+ }
+
+ operator HWND() const
+ {
+ return m_hWnd;
+ }
+
+// Methods
+ BOOL PrintWindow(HDC hDC, UINT uFlags = 0)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ return ::PrintWindow(m_hWnd, hDC, uFlags);
+ }
+
+ BOOL DragDetect(POINT pt)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ return ::DragDetect(m_hWnd, pt);
+ }
+
+ BOOL DragDetect()
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+
+ POINT pt = {};
+ ::GetCursorPos(&pt);
+ return ::DragDetect(m_hWnd, pt);
+ }
+
+ CWindowEx GetAncestor(UINT uFlags) const
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ return CWindowEx(::GetAncestor(m_hWnd, uFlags));
+ }
+
+ // Note: Does not work properly on Vista Aero and above
+ BOOL AnimateWindow(DWORD dwFlags, DWORD dwTime = 200)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ return ::AnimateWindow(m_hWnd, dwTime, dwFlags);
+ }
+
+ BOOL FlashWindowEx(DWORD dwFlags, UINT uCount, DWORD dwTimeout = 0)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+
+ FLASHWINFO fi = { sizeof(FLASHWINFO) };
+ fi.hwnd = m_hWnd;
+ fi.dwFlags = dwFlags;
+ fi.uCount = uCount;
+ fi.dwTimeout = dwTimeout;
+ return ::FlashWindowEx(&fi);
+ }
+
+ BOOL StopFlashWindowEx()
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+
+ FLASHWINFO fi = { sizeof(FLASHWINFO) };
+ fi.hwnd = m_hWnd;
+ fi.dwFlags = FLASHW_STOP;
+ return ::FlashWindowEx(&fi);
+ }
+
+// Class long properties
+ DWORD GetClassLong(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ return ::GetClassLong(m_hWnd, nIndex);
+ }
+
+ DWORD SetClassLong(int nIndex, LONG dwNewLong)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ return ::SetClassLong(m_hWnd, nIndex, dwNewLong);
+ }
+
+ ULONG_PTR GetClassLongPtr(int nIndex) const
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ return ::GetClassLongPtr(m_hWnd, nIndex);
+ }
+
+ ULONG_PTR SetClassLongPtr(int nIndex, LONG_PTR dwNewLong)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ return ::SetClassLongPtr(m_hWnd, nIndex, dwNewLong);
+ }
+
+// Layered windows
+ BOOL SetLayeredWindowAttributes(COLORREF crlKey, BYTE byteAlpha, DWORD dwFlags)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ ATLASSERT((GetExStyle() & WS_EX_LAYERED) != 0);
+
+ return ::SetLayeredWindowAttributes(m_hWnd, crlKey, byteAlpha, dwFlags);
+ }
+
+ BOOL UpdateLayeredWindow(HDC hdcDst, LPPOINT pptDst, LPSIZE psize, HDC hdcSrc, LPPOINT pptSrc, COLORREF crlKey, BLENDFUNCTION* pblend, DWORD dwFlags)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ ATLASSERT((GetExStyle() & WS_EX_LAYERED) != 0);
+
+ return ::UpdateLayeredWindow(m_hWnd, hdcDst, pptDst, psize, hdcSrc, pptSrc, crlKey, pblend, dwFlags);
+ }
+
+ BOOL UpdateLayeredWindow(LPPOINT pptDst = NULL)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ ATLASSERT((GetExStyle() & WS_EX_LAYERED) != 0);
+
+ return ::UpdateLayeredWindow(m_hWnd, NULL, pptDst, NULL, NULL, NULL, CLR_NONE, NULL, 0);
+ }
+
+ BOOL GetLayeredWindowAttributes(COLORREF* pcrlKey, BYTE* pbyteAlpha, DWORD* pdwFlags) const
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+ ATLASSERT((GetExStyle() & WS_EX_LAYERED) != 0);
+
+ return ::GetLayeredWindowAttributes(m_hWnd, pcrlKey, pbyteAlpha, pdwFlags);
+ }
+
+// Mouse tracking
+ BOOL StartTrackMouseLeave()
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+
+ TRACKMOUSEEVENT tme = {};
+ tme.cbSize = sizeof(TRACKMOUSEEVENT);
+ tme.dwFlags = TME_LEAVE;
+ tme.hwndTrack = m_hWnd;
+ return ::TrackMouseEvent(&tme);
+ }
+
+ BOOL StartTrackMouse(DWORD dwFlags, DWORD dwHoverTime = HOVER_DEFAULT)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+
+ TRACKMOUSEEVENT tme = {};
+ tme.cbSize = sizeof(TRACKMOUSEEVENT);
+ tme.dwFlags = dwFlags;
+ tme.hwndTrack = m_hWnd;
+ tme.dwHoverTime = dwHoverTime;
+ return ::TrackMouseEvent(&tme);
+ }
+
+ BOOL CancelTrackMouse(DWORD dwType)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+
+ TRACKMOUSEEVENT tme = {};
+ tme.cbSize = sizeof(TRACKMOUSEEVENT);
+ tme.dwFlags = TME_CANCEL | dwType;
+ tme.hwndTrack = m_hWnd;
+ return ::TrackMouseEvent(&tme);
+ }
+
+// CString support
+#ifdef __ATLSTR_H__
+ int GetWindowText(ATL::CString& strText) const
+ {
+ int nLength = GetWindowTextLength();
+ LPTSTR pszText = strText.GetBuffer(nLength + 1);
+ nLength = ::GetWindowText(m_hWnd, pszText, nLength + 1);
+ strText.ReleaseBuffer(nLength);
+
+ return nLength;
+ }
+
+ UINT GetDlgItemText(int nID, ATL::CString& strText) const
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+
+ HWND hItem = GetDlgItem(nID);
+ if(hItem != NULL)
+ {
+ int nLength = ::GetWindowTextLength(hItem);
+ LPTSTR pszText = strText.GetBuffer(nLength + 1);
+ nLength = ::GetWindowText(hItem, pszText, nLength + 1);
+ strText.ReleaseBuffer(nLength);
+
+ return nLength;
+ }
+ else
+ {
+ strText.Empty();
+
+ return 0;
+ }
+ }
+#endif // __ATLSTR_H__
+
+// Dialog window only
+ UINT DlgGetDefID() const
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+
+ LRESULT lRet = ::SendMessage(m_hWnd, DM_GETDEFID, 0, 0L);
+ UINT uID = 0U;
+ if(HIWORD(lRet) == DC_HASDEFID)
+ uID = LOWORD(lRet);
+
+ return uID;
+ }
+
+ void DlgSetDefID(UINT uID)
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+
+ ::SendMessage(m_hWnd, DM_SETDEFID, uID, 0L);
+ }
+
+ void DlgReposition()
+ {
+ ATLASSERT(::IsWindow(m_hWnd));
+
+ ::SendMessage(m_hWnd, DM_REPOSITION, 0, 0L);
+ }
+};
+
+} // namespace WTL
+
+#endif // __ATLWINX_H__
diff --git a/Examples/WhisperDesktop/Utils/logger.cpp b/Examples/WhisperDesktop/Utils/logger.cpp
new file mode 100644
index 0000000..5c7c257
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/logger.cpp
@@ -0,0 +1,71 @@
+#include "stdafx.h"
+#include "logger.h"
+#include "miscUtils.h"
+
+namespace
+{
+ using namespace Whisper;
+
+ // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
+ // Lowest is red, middle is yellow, highest is green.
+ static const std::array<const char*, 10> k_colors =
+ {
+ "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
+ "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
+ };
+
+ static int colorIndex( const sToken& tok )
+ {
+ const float p = tok.probability;
+ const float p3 = p * p * p;
+ int col = (int)( p3 * float( k_colors.size() ) );
+ col = std::max( 0, std::min( (int)k_colors.size() - 1, col ) );
+ return col;
+ }
+}
+
+void printTimeStamp( CStringA& rdi, Whisper::sTimeSpan ts )
+{
+ sTimeSpanFields fields = ts;
+ uint32_t msec = fields.ticks / 10'000;
+ uint32_t hr = fields.days * 24 + fields.hours;
+ uint32_t min = fields.minutes;
+ uint32_t sec = fields.seconds;
+ rdi.AppendFormat( "%02d:%02d:%02d.%03d", hr, min, sec, msec );
+}
+
+HRESULT logNewSegments( const iTranscribeResult* results, size_t newSegments, bool printSpecial )
+{
+ sTranscribeLength length;
+ CHECK( results->getSize( length ) );
+
+ const size_t len = length.countSegments;
+ size_t i = len - newSegments;
+
+ const sSegment* const segments = results->getSegments();
+ const sToken* const tokens = results->getTokens();
+
+ CStringA str;
+ for( ; i < len; i++ )
+ {
+ const sSegment& seg = segments[ i ];
+ str = "[";
+ printTimeStamp( str, seg.time.begin );
+ str += " --> ";
+ printTimeStamp( str, seg.time.end );
+ str += "] ";
+
+ for( uint32_t j = 0; j < seg.countTokens; j++ )
+ {
+ const sToken& tok = tokens[ seg.firstToken + j ];
+ if( !printSpecial && ( tok.flags & eTokenFlags::Special ) )
+ continue;
+ str += k_colors[ colorIndex( tok ) ];
+ str += tok.text;
+ str += "\033[0m";
+ }
+ logInfo( u8"%s", cstr( str ) );
+ }
+
+ return S_OK;
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/logger.h b/Examples/WhisperDesktop/Utils/logger.h
new file mode 100644
index 0000000..07ec012
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/logger.h
@@ -0,0 +1,36 @@
+#pragma once
+#include <whisperWindows.h>
+#include <cstdarg>
+
+void logMessage( Whisper::eLogLevel lvl, const char8_t* pszFormat, va_list args );
+
+#define LOG_MESSAGE_IMPL( lvl ) \
+ std::va_list args; \
+ va_start( args, pszFormat ); \
+ logMessage( lvl, pszFormat, args ); \
+ va_end( args )
+
+inline void logError( const char8_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( Whisper::eLogLevel::Error );
+}
+inline void logWarning( const char8_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( Whisper::eLogLevel::Warning );
+}
+inline void logInfo( const char8_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( Whisper::eLogLevel::Info );
+}
+inline void logDebug( const char8_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( Whisper::eLogLevel::Debug );
+}
+
+#undef LOG_MESSAGE_IMPL
+
+HRESULT logNewSegments( const Whisper::iTranscribeResult* results, size_t newSegments, bool printSpecial = false );
+
+void clearLastError();
+bool getLastError( CString& rdi );
+void printTimeStamp( CStringA& rdi, Whisper::sTimeSpan ts ); \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/miscUtils.cpp b/Examples/WhisperDesktop/Utils/miscUtils.cpp
new file mode 100644
index 0000000..485cf00
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/miscUtils.cpp
@@ -0,0 +1,254 @@
+#include "stdafx.h"
+#include "miscUtils.h"
+
+namespace
+{
+ wchar_t* formatMessage( HRESULT hr )
+ {
+ wchar_t* err;
+ if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM,
+ NULL,
+ hr,
+ MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ),
+ (LPTSTR)&err,
+ 0,
+ NULL ) )
+ return err;
+ return nullptr;
+ }
+}
+
+CString formatErrorMessage( HRESULT hr )
+{
+ CString message;
+ const wchar_t* err = formatMessage( hr );
+ if( nullptr != err )
+ {
+ message = err;
+ LocalFree( (HLOCAL)err );
+ message.TrimRight();
+ }
+ else
+ message.Format( L"Error code %i (0x%08X)", hr, hr );
+
+ return message;
+}
+
+void reportFatalError( const char* what, HRESULT hr )
+{
+ CString message;
+ message.Format( L"%S\n%S\n", "Unable to start the application.", what );
+ message += formatErrorMessage( hr );
+ ::MessageBox( nullptr, message, L"Whisper Desktop Startup", MB_OK | MB_ICONERROR );
+}
+
+namespace
+{
+ using Whisper::eModelImplementation;
+
+ struct sImplString
+ {
+ eModelImplementation val;
+ LPCTSTR str;
+ };
+ static const std::array<sImplString, 3> s_implStrings =
+ {
+ sImplString{ eModelImplementation::GPU, L"GPU" },
+ sImplString{ eModelImplementation::Hybrid, L"Hybrid" },
+ sImplString{ eModelImplementation::Reference, L"Reference" },
+ };
+}
+
+HRESULT implParse( const CString& s, eModelImplementation& rdi )
+{
+ for( const auto& is : s_implStrings )
+ {
+ if( 0 != s.CompareNoCase( is.str ) )
+ continue;
+ rdi = is.val;;
+ return S_OK;
+ }
+ return E_INVALIDARG;
+}
+
+LPCTSTR implString( eModelImplementation i )
+{
+ for( const auto& is : s_implStrings )
+ if( is.val == i )
+ return is.str;
+ return nullptr;
+}
+
+void implPopulateCombobox( CComboBox& cb, Whisper::eModelImplementation impl )
+{
+ int curSel = 0;
+ int idx = 0;
+ for( const auto& is : s_implStrings )
+ {
+ cb.AddString( is.str );
+ if( is.val == impl )
+ curSel = idx;
+ idx++;
+ }
+ cb.SetCurSel( curSel );
+}
+
+Whisper::eModelImplementation implGetValue( CComboBox& cb )
+{
+ int curSel = cb.GetCurSel();
+ if( curSel < 0 )
+ return (Whisper::eModelImplementation)0;
+ return s_implStrings[ curSel ].val;
+}
+
+ThreadPoolWork::~ThreadPoolWork()
+{
+ if( nullptr != work )
+ {
+ CloseThreadpoolWork( work );
+ work = nullptr;
+ }
+}
+
+void __stdcall ThreadPoolWork::callback( PTP_CALLBACK_INSTANCE Instance, PVOID Context, PTP_WORK Work )
+{
+ iThreadPoolCallback* cb = (iThreadPoolCallback*)Context;
+ cb->poolCallback();
+}
+
+HRESULT ThreadPoolWork::create( iThreadPoolCallback* cb )
+{
+ if( nullptr == cb )
+ return E_POINTER;
+ if( nullptr != work )
+ return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
+
+ work = CreateThreadpoolWork( &callback, cb, nullptr );
+ if( nullptr != work )
+ return S_OK;
+
+ return HRESULT_FROM_WIN32( GetLastError() );
+}
+
+HRESULT ThreadPoolWork::post()
+{
+ if( nullptr == work )
+ return OLE_E_BLANK;
+ SubmitThreadpoolWork( work );
+ return S_OK;
+}
+
+void makeUtf16( CString& rdi, const char* utf8 )
+{
+ const size_t length = strlen( utf8 );
+ int count = MultiByteToWideChar( CP_UTF8, 0, utf8, (int)length, nullptr, 0 );
+ wchar_t* p = rdi.GetBufferSetLength( count );
+ MultiByteToWideChar( CP_UTF8, 0, utf8, (int)length, p, count );
+ rdi.ReleaseBuffer();
+}
+
+void makeUtf8( CStringA& rdi, const CString& utf16 )
+{
+ int count = WideCharToMultiByte( CP_UTF8, 0, utf16, utf16.GetLength(), nullptr, 0, nullptr, nullptr );
+ char* s = rdi.GetBufferSetLength( count + 1 );
+ count = WideCharToMultiByte( CP_UTF8, 0, utf16, utf16.GetLength(), s, count, nullptr, nullptr );
+ rdi.ReleaseBufferSetLength( count );
+}
+
+constexpr int ofnBufferLength = 2048;
+
+bool getOpenFileName( HWND owner, LPCTSTR title, LPCTSTR filter, CString& path )
+{
+ wchar_t buffer[ ofnBufferLength ];
+ buffer[ 0 ] = 0;
+ OPENFILENAME ofn;
+ memset( &ofn, 0, sizeof( ofn ) );
+ ofn.lStructSize = sizeof( OPENFILENAME );
+ ofn.hwndOwner = owner;
+ ofn.lpstrFilter = filter;
+ ofn.lpstrTitle = title;
+ ofn.Flags = OFN_EXPLORER | OFN_FILEMUSTEXIST | OFN_PATHMUSTEXIST;
+ ofn.lpstrFile = buffer;
+ ofn.nMaxFile = ofnBufferLength - 1;
+
+ CString dir;
+ if( path.GetLength() > 0 && path.GetLength() < ofnBufferLength )
+ wcsncpy_s( buffer, path, path.GetLength() );
+
+ if( !GetOpenFileName( &ofn ) )
+ {
+ path = L"";
+ return false;
+ }
+ else
+ {
+ path = ofn.lpstrFile;
+ return true;
+ }
+}
+
+bool getSaveFileName( HWND owner, LPCTSTR title, LPCTSTR filter, CString& path, DWORD* filterIndex )
+{
+ wchar_t buffer[ ofnBufferLength ];
+ buffer[ 0 ] = 0;
+
+ OPENFILENAME ofn;
+ memset( &ofn, 0, sizeof( ofn ) );
+ ofn.lStructSize = sizeof( OPENFILENAME );
+ ofn.hwndOwner = owner;
+ ofn.lpstrFilter = filter;
+ ofn.lpstrTitle = title;
+ ofn.Flags = OFN_EXPLORER | OFN_PATHMUSTEXIST;
+ ofn.lpstrFile = buffer;
+ ofn.nMaxFile = ofnBufferLength - 1;
+ if( nullptr != filterIndex )
+ ofn.nFilterIndex = *filterIndex + 1;
+
+ if( path.GetLength() > 0 && path.GetLength() < ofnBufferLength )
+ wcsncpy_s( buffer, path, path.GetLength() );
+
+ if( !GetSaveFileName( &ofn ) )
+ return false;
+
+ path = ofn.lpstrFile;
+
+ if( nullptr != filterIndex )
+ *filterIndex = ofn.nFilterIndex - 1;
+
+ return true;
+}
+
+void reportError( HWND owner, LPCTSTR text, LPCTSTR title, HRESULT hr )
+{
+ if( nullptr == title )
+ title = L"Operation Failed";
+
+ CString message = text;
+ message.TrimRight();
+ if( FAILED( hr ) )
+ {
+ message += L"\n";
+ message += formatErrorMessage( hr );
+ }
+
+ ::MessageBox( owner, message, title, MB_OK | MB_ICONWARNING );
+}
+
+HRESULT writeUtf8Bom( CAtlFile& file )
+{
+ const std::array<uint8_t, 3> bom = { 0xEF, 0xBB, 0xBF };
+ return file.Write( bom.data(), 3 );
+}
+
+bool isInvalidTranslate( HWND owner, uint32_t lang, bool translate )
+{
+ if( !translate )
+ return false;
+ constexpr uint32_t english = 0x6E65;
+ if( lang != english )
+ return false;
+
+ LPCTSTR message = L"The translate feature translates speech to English.\nIt’s not available when the audio language is already English.";
+ MessageBox( owner, message, L"Incompatible parameters", MB_OK | MB_ICONINFORMATION );
+ return true;
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/Utils/miscUtils.h b/Examples/WhisperDesktop/Utils/miscUtils.h
new file mode 100644
index 0000000..8cf8151
--- /dev/null
+++ b/Examples/WhisperDesktop/Utils/miscUtils.h
@@ -0,0 +1,72 @@
+#pragma once
+#include <iContext.h>
+#include "logger.h"
+
+CString formatErrorMessage( HRESULT hr );
+
+void reportFatalError( const char* what, HRESULT hr );
+
+#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; }
+
+HRESULT implParse( const CString& s, Whisper::eModelImplementation& rdi );
+
+LPCTSTR implString( Whisper::eModelImplementation i );
+
+void implPopulateCombobox( CComboBox& cb, Whisper::eModelImplementation impl );
+
+Whisper::eModelImplementation implGetValue( CComboBox& cb );
+
+__interface iThreadPoolCallback
+{
+ void __stdcall poolCallback() noexcept;
+};
+
+class ThreadPoolWork
+{
+ PTP_WORK work = nullptr;
+ static void __stdcall callback( PTP_CALLBACK_INSTANCE Instance, PVOID Context, PTP_WORK Work );
+
+public:
+
+ ~ThreadPoolWork();
+ HRESULT create( iThreadPoolCallback* cb );
+ HRESULT post();
+};
+
+void makeUtf16( CString& rdi, const char* utf8 );
+void makeUtf8( CStringA& rdi, const CString& utf16 );
+
+bool getOpenFileName( HWND owner, LPCTSTR title, LPCTSTR filter, CString& path );
+
+bool getSaveFileName( HWND owner, LPCTSTR title, LPCTSTR filter, CString& path, DWORD* filterIndex = nullptr );
+
+#define ON_BUTTON_CLICK( id, func ) \
+ if( uMsg == WM_COMMAND && \
+ id == LOWORD( wParam ) ) \
+ { \
+ bHandled = TRUE; \
+ func(); \
+ lResult = 0; \
+ return TRUE; \
+ }
+
+void reportError( HWND owner, LPCTSTR text, LPCTSTR title, HRESULT hr = S_FALSE );
+
+inline const wchar_t* cstr( const CString& s ) { return s; }
+inline const char* cstr( const CStringA& s ) { return s; }
+
+inline HRESULT getLastHr()
+{
+ return HRESULT_FROM_WIN32( GetLastError() );
+}
+
+HRESULT writeUtf8Bom( CAtlFile& file );
+
+// Flip order of bytes from RGB to BGR, or vice versa
+inline uint32_t flipRgb( uint32_t val )
+{
+ val = _byteswap_ulong( val );
+ return val >> 8;
+}
+
+bool isInvalidTranslate( HWND owner, uint32_t lang, bool translate ); \ No newline at end of file
diff --git a/Examples/WhisperDesktop/WhisperDesktop.cpp b/Examples/WhisperDesktop/WhisperDesktop.cpp
new file mode 100644
index 0000000..ddef5b6
--- /dev/null
+++ b/Examples/WhisperDesktop/WhisperDesktop.cpp
@@ -0,0 +1,63 @@
+#include "stdafx.h"
+#include "AppState.h"
+#include "Utils/miscUtils.h"
+#include "LoadModelDlg.h"
+#include "TranscribeDlg.h"
+#include "CaptureDlg.h"
+
+HRESULT dialogLoadModel( AppState& appState )
+{
+ LoadModelDlg loadDialog{ appState };
+ HRESULT hr = loadDialog.show();
+ if( FAILED( hr ) )
+ {
+ reportFatalError( "Error loading the model", hr );
+ return hr;
+ }
+ appState.automaticallyLoadModel = false;
+ return hr;
+}
+
+HRESULT dialogTranscribe( AppState& appState )
+{
+ TranscribeDlg dialog{ appState };
+ return dialog.show();
+}
+
+HRESULT dialogCapture( AppState& appState )
+{
+ CaptureDlg dialog{ appState };
+ return dialog.show();
+}
+
+using pfnDialog = HRESULT( * )( AppState& appState );
+static const std::array<pfnDialog, 4> s_dialogs =
+{
+ nullptr, // S_OK
+ &dialogLoadModel, // SCREEN_MODEL
+ &dialogTranscribe, // SCREEN_TRANSCRIBE
+ &dialogCapture, // SCREEN_CAPTURE
+};
+
+int __stdcall wWinMain( HINSTANCE hInstance, HINSTANCE hPrevInstance, LPWSTR lpCmdLine, int nCmdShow )
+{
+ AppState appState;
+ HRESULT hr = appState.startup();
+ if( FAILED( hr ) )
+ return hr;
+
+ appState.findModelSource();
+
+ hr = SCREEN_MODEL;
+ while( true )
+ {
+ pfnDialog pfn = s_dialogs[ hr ];
+ if( nullptr == pfn )
+ return S_OK;
+ hr = pfn( appState );
+ if( FAILED( hr ) )
+ return hr;
+ if( hr == SCREEN_MODEL )
+ appState.model = nullptr;
+ }
+} \ No newline at end of file
diff --git a/Examples/WhisperDesktop/WhisperDesktop.manifest b/Examples/WhisperDesktop/WhisperDesktop.manifest
new file mode 100644
index 0000000..af2b796
--- /dev/null
+++ b/Examples/WhisperDesktop/WhisperDesktop.manifest
@@ -0,0 +1,16 @@
+<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
+<assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0" xmlns:asmv3="urn:schemas-microsoft-com:asm.v3">
+ <assemblyIdentity version="1.0.0.0" processorArchitecture="amd64" name="CompanyName.ProductName.YourApplication" type="win32" />
+ <description>Your application description here.</description>
+ <dependency>
+ <dependentAssembly>
+ <assemblyIdentity type="win32" name="Microsoft.Windows.Common-Controls" version="6.0.0.0" processorArchitecture="amd64" publicKeyToken="6595b64144ccf1df" language="*" />
+ </dependentAssembly>
+ </dependency>
+ <asmv3:application>
+ <asmv3:windowsSettings>
+ <dpiAware xmlns="http://schemas.microsoft.com/SMI/2005/WindowsSettings">true</dpiAware>
+ <dpiAwareness xmlns="http://schemas.microsoft.com/SMI/2016/WindowsSettings">PerMonitorV2</dpiAwareness>
+ </asmv3:windowsSettings>
+ </asmv3:application>
+</assembly> \ No newline at end of file
diff --git a/Examples/WhisperDesktop/WhisperDesktop.rc b/Examples/WhisperDesktop/WhisperDesktop.rc
new file mode 100644
index 0000000..67461d7
--- /dev/null
+++ b/Examples/WhisperDesktop/WhisperDesktop.rc
Binary files differ
diff --git a/Examples/WhisperDesktop/WhisperDesktop.vcxproj b/Examples/WhisperDesktop/WhisperDesktop.vcxproj
new file mode 100644
index 0000000..4956410
--- /dev/null
+++ b/Examples/WhisperDesktop/WhisperDesktop.vcxproj
@@ -0,0 +1,151 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <VCProjectVersion>16.0</VCProjectVersion>
+ <Keyword>Win32Proj</Keyword>
+ <ProjectGuid>{cd9e49f0-75a3-4f91-ac71-336109ee39c6}</ProjectGuid>
+ <RootNamespace>WhisperDesktop</RootNamespace>
+ <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="Shared">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <IncludePath>$(ProjectDir);$(SolutionDir)Whisper\API\;$(IncludePath)</IncludePath>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <IncludePath>$(ProjectDir);$(SolutionDir)Whisper\API\;$(IncludePath)</IncludePath>
+ </PropertyGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>_DEBUG;_WINDOWS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <PrecompiledHeader>Use</PrecompiledHeader>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ <MultiProcessorCompilation>true</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Windows</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ <ManifestFile>WhisperDesktop.manifest</ManifestFile>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NDEBUG;_WINDOWS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <PrecompiledHeader>Use</PrecompiledHeader>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ <MultiProcessorCompilation>true</MultiProcessorCompilation>
+ <FavorSizeOrSpeed>Size</FavorSizeOrSpeed>
+ <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
+ </ClCompile>
+ <Link>
+ <SubSystem>Windows</SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ <ManifestFile>WhisperDesktop.manifest</ManifestFile>
+ <LinkTimeCodeGeneration>UseLinkTimeCodeGeneration</LinkTimeCodeGeneration>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemGroup>
+ <ClInclude Include="CircleIndicator.h" />
+ <ClInclude Include="AppState.h" />
+ <ClInclude Include="CaptureDlg.h" />
+ <ClInclude Include="Utils\TranslateCheckbox.h" />
+ <ClInclude Include="Utils\DebugConsole.h" />
+ <ClInclude Include="framework.h" />
+ <ClInclude Include="Utils\logger.h" />
+ <ClInclude Include="Utils\PendingState.h" />
+ <ClInclude Include="Utils\LanguageDropdown.h" />
+ <ClInclude Include="LoadModelDlg.h" />
+ <ClInclude Include="TranscribeDlg.h" />
+ <ClInclude Include="Utils\miscUtils.h" />
+ <ClInclude Include="stdafx.h" />
+ <ClInclude Include="Resource.h" />
+ <ClInclude Include="targetver.h" />
+ <ClInclude Include="Utils\WTL\atlapp.h" />
+ <ClInclude Include="Utils\WTL\atlcrack.h" />
+ <ClInclude Include="Utils\WTL\atlctrls.h" />
+ <ClInclude Include="Utils\WTL\atlddx.h" />
+ <ClInclude Include="Utils\WTL\atlgdi.h" />
+ <ClInclude Include="Utils\WTL\atlres.h" />
+ <ClInclude Include="Utils\WTL\atluser.h" />
+ <ClInclude Include="Utils\WTL\atlwinx.h" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClCompile Include="CircleIndicator.cpp" />
+ <ClCompile Include="AppState.cpp" />
+ <ClCompile Include="CaptureDlg.cpp" />
+ <ClCompile Include="Utils\TranslateCheckbox.cpp" />
+ <ClCompile Include="Utils\DebugConsole.cpp" />
+ <ClCompile Include="Utils\logger.cpp" />
+ <ClCompile Include="Utils\PendingState.cpp" />
+ <ClCompile Include="Utils\LanguageDropdown.cpp" />
+ <ClCompile Include="LoadModelDlg.cpp" />
+ <ClCompile Include="TranscribeDlg.cpp" />
+ <ClCompile Include="Utils\miscUtils.cpp" />
+ <ClCompile Include="stdafx.cpp">
+ <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
+ <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
+ </ClCompile>
+ <ClCompile Include="WhisperDesktop.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ResourceCompile Include="WhisperDesktop.rc" />
+ </ItemGroup>
+ <ItemGroup>
+ <Image Include="sunflower.ico" />
+ </ItemGroup>
+ <ItemGroup>
+ <Manifest Include="WhisperDesktop.manifest" />
+ </ItemGroup>
+ <ItemGroup>
+ <ProjectReference Include="..\..\Whisper\Whisper.vcxproj">
+ <Project>{701df8c8-e4a5-43ec-9c6b-747bbf4d8e71}</Project>
+ </ProjectReference>
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/Examples/WhisperDesktop/WhisperDesktop.vcxproj.filters b/Examples/WhisperDesktop/WhisperDesktop.vcxproj.filters
new file mode 100644
index 0000000..7a36c86
--- /dev/null
+++ b/Examples/WhisperDesktop/WhisperDesktop.vcxproj.filters
@@ -0,0 +1,142 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <Filter Include="Source Files">
+ <UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
+ <Extensions>cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
+ </Filter>
+ <Filter Include="Header Files">
+ <UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
+ <Extensions>h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd</Extensions>
+ </Filter>
+ <Filter Include="Resource Files">
+ <UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
+ <Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
+ </Filter>
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="framework.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="targetver.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Resource.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="stdafx.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="LoadModelDlg.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="AppState.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\miscUtils.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\WTL\atlapp.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\WTL\atlddx.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\WTL\atlctrls.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\WTL\atlgdi.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\WTL\atlres.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\WTL\atluser.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\WTL\atlwinx.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="TranscribeDlg.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\LanguageDropdown.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\PendingState.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\DebugConsole.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\logger.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="CaptureDlg.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="CircleIndicator.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\WTL\atlcrack.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="Utils\TranslateCheckbox.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ </ItemGroup>
+ <ItemGroup>
+ <ClCompile Include="WhisperDesktop.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="stdafx.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="LoadModelDlg.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="AppState.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="Utils\miscUtils.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="TranscribeDlg.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="Utils\LanguageDropdown.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="Utils\PendingState.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="Utils\DebugConsole.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="Utils\logger.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="CaptureDlg.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="CircleIndicator.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="Utils\TranslateCheckbox.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ </ItemGroup>
+ <ItemGroup>
+ <ResourceCompile Include="WhisperDesktop.rc">
+ <Filter>Resource Files</Filter>
+ </ResourceCompile>
+ </ItemGroup>
+ <ItemGroup>
+ <Image Include="sunflower.ico">
+ <Filter>Resource Files</Filter>
+ </Image>
+ </ItemGroup>
+ <ItemGroup>
+ <Manifest Include="WhisperDesktop.manifest" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/Examples/WhisperDesktop/framework.h b/Examples/WhisperDesktop/framework.h
new file mode 100644
index 0000000..b52a644
--- /dev/null
+++ b/Examples/WhisperDesktop/framework.h
@@ -0,0 +1,22 @@
+#pragma once
+#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
+#define NOMINMAX
+// Windows Header Files
+#include "targetver.h"
+#include <windows.h>
+// ATL header files
+#include <atlstr.h>
+#include <atlfile.h>
+#include <atlbase.h>
+#include <atlwin.h>
+
+// C RunTime Header Files
+#include <stdlib.h>
+#include <malloc.h>
+#include <memory.h>
+#include <tchar.h>
+#include <assert.h>
+// C++ headers
+#include <vector>
+#include <array>
+#include <algorithm> \ No newline at end of file
diff --git a/Examples/WhisperDesktop/stdafx.cpp b/Examples/WhisperDesktop/stdafx.cpp
new file mode 100644
index 0000000..1577c4e
--- /dev/null
+++ b/Examples/WhisperDesktop/stdafx.cpp
@@ -0,0 +1 @@
+#include "stdafx.h" \ No newline at end of file
diff --git a/Examples/WhisperDesktop/stdafx.h b/Examples/WhisperDesktop/stdafx.h
new file mode 100644
index 0000000..6f46ad0
--- /dev/null
+++ b/Examples/WhisperDesktop/stdafx.h
@@ -0,0 +1,8 @@
+#pragma once
+#include "framework.h"
+
+#include <whisperWindows.h>
+
+#include "resource.h"
+#include "Utils/WTL/atlapp.h"
+#include "Utils/WTL/atlctrls.h" \ No newline at end of file
diff --git a/Examples/WhisperDesktop/sunflower.ico b/Examples/WhisperDesktop/sunflower.ico
new file mode 100644
index 0000000..b6404e1
--- /dev/null
+++ b/Examples/WhisperDesktop/sunflower.ico
Binary files differ
diff --git a/Examples/WhisperDesktop/targetver.h b/Examples/WhisperDesktop/targetver.h
new file mode 100644
index 0000000..17d4075
--- /dev/null
+++ b/Examples/WhisperDesktop/targetver.h
@@ -0,0 +1,6 @@
+#pragma once
+// Setup Windows SDK to only enable features available since Windows 8.0
+#include <WinSDKVer.h>
+#define _WIN32_WINNT _WIN32_WINNT_WIN8
+#define NTDDI_VERSION NTDDI_WIN8
+#include <SDKDDKVer.h>
diff --git a/Examples/main/main.cpp b/Examples/main/main.cpp
new file mode 100644
index 0000000..c9eacf2
--- /dev/null
+++ b/Examples/main/main.cpp
@@ -0,0 +1,315 @@
+#include "params.h"
+#include "../../Whisper/API/iContext.cl.h"
+#include "../../Whisper/API/iMediaFoundation.cl.h"
+#include "../../ComLightLib/comLightClient.h"
+#include "miscUtils.h"
+#include <array>
+#include <atomic>
+using namespace Whisper;
+
+#define STREAM_AUDIO 1
+
+static HRESULT loadWhisperModel( const wchar_t* path, iModel** pp )
+{
+ using namespace Whisper;
+ constexpr eModelImplementation impl = eModelImplementation::GPU;
+ // constexpr eModelImplementation impl = eModelImplementation::Reference;
+ return Whisper::loadModel( path, impl, nullptr, pp );
+}
+
+namespace
+{
+ struct sPrintUserData
+ {
+ const whisper_params* params;
+ // const std::vector<std::vector<float>>* pcmf32s;
+ };
+
+ // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
+ // Lowest is red, middle is yellow, highest is green.
+ static const std::array<const char*, 10> k_colors =
+ {
+ "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
+ "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
+ };
+
+ std::string to_timestamp( sTimeSpan ts, bool comma = false )
+ {
+ sTimeSpanFields fields = ts;
+ uint32_t msec = fields.ticks / 10'000;
+ uint32_t hr = fields.days * 24 + fields.hours;
+ uint32_t min = fields.minutes;
+ uint32_t sec = fields.seconds;
+
+ char buf[ 32 ];
+ snprintf( buf, sizeof( buf ), "%02d:%02d:%02d%s%03d", hr, min, sec, comma ? "," : ".", msec );
+ return std::string( buf );
+ }
+
+ static int colorIndex( const sToken& tok )
+ {
+ const float p = tok.probability;
+ const float p3 = p * p * p;
+ int col = (int)( p3 * float( k_colors.size() ) );
+ col = std::max( 0, std::min( (int)k_colors.size() - 1, col ) );
+ return col;
+ }
+
+ HRESULT __cdecl newSegmentCallback( iContext* context, uint32_t n_new, void* user_data ) noexcept
+ {
+ ComLight::CComPtr<iTranscribeResult> results;
+ CHECK( context->getResults( eResultFlags::Timestamps | eResultFlags::Tokens, &results ) );
+
+ sTranscribeLength length;
+ CHECK( results->getSize( length ) );
+
+ const whisper_params& params = *( (sPrintUserData*)user_data )->params;
+ // const std::vector<std::vector<float>>& pcmf32s = *( (sPrintUserData*)user_data )->pcmf32s;
+
+ // print the last n_new segments
+ const uint32_t s0 = length.countSegments - n_new;
+ if( s0 == 0 )
+ printf( "\n" );
+
+ const sSegment* const segments = results->getSegments();
+ const sToken* const tokens = results->getTokens();
+
+ for( uint32_t i = s0; i < length.countSegments; i++ )
+ {
+ const sSegment& seg = segments[ i ];
+
+ if( params.no_timestamps )
+ {
+ if( params.print_colors )
+ {
+ for( uint32_t j = 0; j < seg.countTokens; j++ )
+ {
+ const sToken& tok = tokens[ seg.firstToken + j ];
+ if( !params.print_special && ( tok.flags & eTokenFlags::Special ) )
+ continue;
+ wprintf( L"%S%s%S", k_colors[ colorIndex( tok ) ], utf16( tok.text ).c_str(), "\033[0m" );
+ }
+ }
+ else
+ wprintf( L"%s", utf16( seg.text ).c_str() );
+ fflush( stdout );
+ continue;
+ }
+
+ std::string speaker = "";
+#if 0
+ if( params.diarize && pcmf32s.size() == 2 )
+ {
+ const size_t n_samples = pcmf32s[ 0 ].size();
+ const int64_t is0 = SourceAudio::sampleFromTimestamp( seg.time.begin, n_samples );
+ const int64_t is1 = SourceAudio::sampleFromTimestamp( seg.time.end, n_samples );
+
+ double energy0 = 0.0f;
+ double energy1 = 0.0f;
+
+ for( int64_t j = is0; j < is1; j++ )
+ {
+ energy0 += fabs( pcmf32s[ 0 ][ j ] );
+ energy1 += fabs( pcmf32s[ 1 ][ j ] );
+ }
+
+ if( energy0 > 1.1 * energy1 )
+ speaker = "(speaker 0)";
+ else if( energy1 > 1.1 * energy0 )
+ speaker = "(speaker 1)";
+ else
+ speaker = "(speaker ?)";
+
+ //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+ }
+#endif
+
+ if( params.print_colors )
+ {
+ printf( "[%s --> %s] ", to_timestamp( seg.time.begin ).c_str(), to_timestamp( seg.time.end ).c_str() );
+
+ for( uint32_t j = 0; j < seg.countTokens; j++ )
+ {
+ const sToken& tok = tokens[ seg.firstToken + j ];
+ if( !params.print_special && ( tok.flags & eTokenFlags::Special ) )
+ continue;
+ wprintf( L"%S%S%s%S", speaker.c_str(), k_colors[ colorIndex( tok ) ], utf16( tok.text ).c_str(), "\033[0m" );
+ }
+ printf( "\n" );
+ }
+ else
+ wprintf( L"[%S --> %S] %S%s\n", to_timestamp( seg.time.begin ).c_str(), to_timestamp( seg.time.end ).c_str(), speaker.c_str(), utf16( seg.text ).c_str() );
+ }
+ return S_OK;
+ }
+
+ HRESULT __cdecl beginSegmentCallback( iContext* context, void* user_data ) noexcept
+ {
+ std::atomic_bool* flag = (std::atomic_bool*)user_data;
+ bool aborted = flag->load();
+ return aborted ? S_FALSE : S_OK;
+ }
+
+ HRESULT setupConsoleColors()
+ {
+ HANDLE h = GetStdHandle( STD_OUTPUT_HANDLE );
+ if( h == INVALID_HANDLE_VALUE )
+ return HRESULT_FROM_WIN32( GetLastError() );
+
+ DWORD mode = 0;
+ if( !GetConsoleMode( h, &mode ) )
+ return HRESULT_FROM_WIN32( GetLastError() );
+ if( 0 != ( mode & ENABLE_VIRTUAL_TERMINAL_PROCESSING ) )
+ return S_FALSE;
+
+ mode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING;
+ if( !SetConsoleMode( h, mode ) )
+ return HRESULT_FROM_WIN32( GetLastError() );
+ return S_OK;
+ }
+}
+
+int wmain( int argc, wchar_t* argv[] )
+{
+ // Whisper::dbgCompareTraces( LR"(C:\Temp\2remove\Whisper\ref.bin)", LR"(C:\Temp\2remove\Whisper\gpu.bin )" ); return 0;
+
+ // Tell logger to use the standard output stream for the messages
+ {
+ Whisper::sLoggerSetup logSetup;
+ logSetup.flags = eLoggerFlags::UseStandardError;
+ logSetup.level = eLogLevel::Debug;
+ Whisper::setupLogger( logSetup );
+ }
+
+ whisper_params params;
+ if( !params.parse( argc, argv ) )
+ return 1;
+
+ if( params.print_colors )
+ {
+ if( FAILED( setupConsoleColors() ) )
+ params.print_colors = false;
+ }
+
+ if( params.fname_inp.empty() )
+ {
+ fprintf( stderr, "error: no input files specified\n" );
+ whisper_print_usage( argc, argv, params );
+ return 2;
+ }
+
+ if( Whisper::findLanguageKeyA( params.language.c_str() ) == UINT_MAX )
+ {
+ fprintf( stderr, "error: unknown language '%s'\n", params.language.c_str() );
+ whisper_print_usage( argc, argv, params );
+ return 3;
+ }
+
+ ComLight::CComPtr<iModel> model;
+ HRESULT hr = loadWhisperModel( params.model.c_str(), &model );
+ if( FAILED( hr ) )
+ {
+ printError( "failed to load the model", hr );
+ return 4;
+ }
+
+ ComLight::CComPtr<iContext> context;
+ hr = model->createContext( &context );
+ if( FAILED( hr ) )
+ {
+ printError( "failed to initialize whisper context", hr );
+ return 5;
+ }
+
+ ComLight::CComPtr<iMediaFoundation> mf;
+ hr = initMediaFoundation( &mf );
+ if( FAILED( hr ) )
+ {
+ printError( "failed to initialize Media Foundation runtime", hr );
+ return 5;
+ }
+
+ for( const std::wstring& fname : params.fname_inp )
+ {
+ // print some info about the processing
+ {
+ if( model->isMultilingual() == S_FALSE )
+ {
+ if( params.language != "en" || params.translate )
+ {
+ params.language = "en";
+ params.translate = false;
+ fprintf( stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__ );
+ }
+ }
+
+ /*
+ fwprintf( stderr, L"%S: processing '%s' (%zu samples, %.1f sec), %d threads, %d processors, lang = %S, task = %S, timestamps = %d ...\n",
+ __func__, fname.c_str(), audio.pcmf32.size(), audio.seconds(),
+ params.n_threads, params.n_processors,
+ params.language.c_str(),
+ params.translate ? "translate" : "transcribe",
+ params.no_timestamps ? 0 : 1 );
+ */
+ }
+
+ // run the inference
+ Whisper::sFullParams wparams;
+ context->fullDefaultParams( eSamplingStrategy::Greedy, &wparams );
+
+ wparams.resetFlag( eFullParamsFlags::PrintRealtime | eFullParamsFlags::PrintProgress );
+ wparams.setFlag( eFullParamsFlags::PrintTimestamps, !params.no_timestamps );
+ wparams.setFlag( eFullParamsFlags::PrintSpecial, params.print_special );
+ wparams.setFlag( eFullParamsFlags::Translate, params.translate );
+ wparams.language = Whisper::makeLanguageKey( params.language.c_str() );
+ wparams.cpuThreads = params.n_threads;
+ if( params.max_context != UINT_MAX )
+ wparams.n_max_text_ctx = params.max_context;
+ wparams.offset_ms = params.offset_t_ms;
+ wparams.duration_ms = params.duration_ms;
+
+ wparams.setFlag( eFullParamsFlags::TokenTimestamps, params.output_wts || params.max_len > 0 );
+ wparams.thold_pt = params.word_thold;
+ wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+
+ wparams.setFlag( eFullParamsFlags::SpeedupAudio, params.speed_up );
+ // sPrintUserData user_data = { &params, &audio.pcmf32s };
+ sPrintUserData user_data = { &params };
+
+ // this callback is called on each new segment
+ if( !wparams.flag( eFullParamsFlags::PrintRealtime ) )
+ {
+ wparams.new_segment_callback = &newSegmentCallback;
+ wparams.new_segment_callback_user_data = &user_data;
+ }
+
+ // example for abort mechanism
+ // in this example, we do not abort the processing, but we could if the flag is set to true
+ // the callback is called before every encoder run - if it returns false, the processing is aborted
+ std::atomic_bool is_aborted = false;
+ {
+ wparams.encoder_begin_callback = &beginSegmentCallback;
+ wparams.encoder_begin_callback_user_data = &is_aborted;
+ }
+
+#if STREAM_AUDIO
+ ComLight::CComPtr<iAudioReader> reader;
+ CHECK( mf->openAudioFile( fname.c_str(), params.diarize, &reader ) );
+ sProgressSink progressSink{ nullptr, nullptr };
+ hr = context->runStreamed( wparams, progressSink, reader );
+#else
+ ComLight::CComPtr<iAudioBuffer> buffer;
+ CHECK( mf->loadAudioFile( fname.c_str(), params.diarize, &buffer ) );
+ hr = context->runFull( wparams, buffer );
+#endif
+ if( FAILED( hr ) )
+ {
+ fwprintf( stderr, L"%s: failed to process audio\n", argv[ 0 ] );
+ return 10;
+ }
+ }
+
+ context->timingsPrint();
+ context = nullptr;
+ return 0;
+} \ No newline at end of file
diff --git a/Examples/main/main.vcxproj b/Examples/main/main.vcxproj
new file mode 100644
index 0000000..4945b88
--- /dev/null
+++ b/Examples/main/main.vcxproj
@@ -0,0 +1,93 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <VCProjectVersion>16.0</VCProjectVersion>
+ <Keyword>Win32Proj</Keyword>
+ <ProjectGuid>{4cca7042-eb15-4f7a-b77b-5cafd2df47b2}</ProjectGuid>
+ <RootNamespace>main</RootNamespace>
+ <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="Shared">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NOMINMAX;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NOMINMAX;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemGroup>
+ <ClCompile Include="main.cpp" />
+ <ClCompile Include="miscUtils.cpp" />
+ <ClCompile Include="params.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="miscUtils.h" />
+ <ClInclude Include="params.h" />
+ </ItemGroup>
+ <ItemGroup>
+ <ProjectReference Include="..\..\Whisper\Whisper.vcxproj">
+ <Project>{701df8c8-e4a5-43ec-9c6b-747bbf4d8e71}</Project>
+ </ProjectReference>
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/Examples/main/main.vcxproj.filters b/Examples/main/main.vcxproj.filters
new file mode 100644
index 0000000..94cd8a1
--- /dev/null
+++ b/Examples/main/main.vcxproj.filters
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <ClCompile Include="main.cpp" />
+ <ClCompile Include="params.cpp" />
+ <ClCompile Include="miscUtils.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="params.h" />
+ <ClInclude Include="miscUtils.h" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/Examples/main/miscUtils.cpp b/Examples/main/miscUtils.cpp
new file mode 100644
index 0000000..3ebda20
--- /dev/null
+++ b/Examples/main/miscUtils.cpp
@@ -0,0 +1,48 @@
+#include "miscUtils.h"
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+
+std::string utf8( const std::wstring& utf16 )
+{
+ int count = WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), (int)utf16.length(), nullptr, 0, nullptr, nullptr );
+ std::string str( count, 0 );
+ WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), -1, &str[ 0 ], count, nullptr, nullptr );
+ return str;
+}
+
+std::wstring utf16( const std::string& u8 )
+{
+ int count = MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), nullptr, 0 );
+ std::wstring str( count, 0 );
+ MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), &str[ 0 ], count );
+ return str;
+}
+
+namespace
+{
+ wchar_t* formatMessage( HRESULT hr )
+ {
+ wchar_t* err;
+ if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM,
+ NULL,
+ hr,
+ MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ),
+ (LPTSTR)&err,
+ 0,
+ NULL ) )
+ return err;
+ return nullptr;
+ }
+}
+
+void printError( const char* what, HRESULT hr )
+{
+ const wchar_t* err = formatMessage( hr );
+ if( nullptr != err )
+ {
+ fwprintf( stderr, L"%S: %s\n", what, err );
+ LocalFree( (HLOCAL)err );
+ }
+ else
+ fprintf( stderr, "%s: error code %i (0x%08X)\n", what, hr, hr );
+} \ No newline at end of file
diff --git a/Examples/main/miscUtils.h b/Examples/main/miscUtils.h
new file mode 100644
index 0000000..52770a6
--- /dev/null
+++ b/Examples/main/miscUtils.h
@@ -0,0 +1,9 @@
+#pragma once
+#include <string>
+
+std::string utf8( const std::wstring& utf16 );
+
+std::wstring utf16( const std::string& u8 );
+
+using HRESULT = long;
+void printError( const char* what, HRESULT hr ); \ No newline at end of file
diff --git a/Examples/main/params.cpp b/Examples/main/params.cpp
new file mode 100644
index 0000000..ff1cfdd
--- /dev/null
+++ b/Examples/main/params.cpp
@@ -0,0 +1,101 @@
+#include "params.h"
+#include <algorithm>
+#include <thread>
+#include "miscUtils.h"
+
+whisper_params::whisper_params()
+{
+#ifdef _DEBUG
+ n_threads = 2;
+#else
+ n_threads = std::min( 4u, std::thread::hardware_concurrency() );
+#endif
+}
+
+namespace
+{
+ const char* cstr( bool b )
+ {
+ return b ? "true" : "false";
+ }
+}
+
+void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params )
+{
+ fprintf( stderr, "\n" );
+ fprintf( stderr, "usage: %S [options] file0.wav file1.wav ...\n", argv[ 0 ] );
+ fprintf( stderr, "\n" );
+ fprintf( stderr, "options:\n" );
+ fprintf( stderr, " -h, --help [default] show this help message and exit\n" );
+ fprintf( stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads );
+ fprintf( stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors );
+ fprintf( stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms );
+ fprintf( stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n );
+ fprintf( stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms );
+ fprintf( stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context );
+ fprintf( stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len );
+ fprintf( stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold );
+ fprintf( stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", cstr( params.speed_up ) );
+ fprintf( stderr, " -tr, --translate [%-7s] translate from source language to english\n", cstr( params.translate ) );
+ fprintf( stderr, " -di, --diarize [%-7s] stereo audio diarization\n", cstr( params.diarize ) );
+ fprintf( stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", cstr( params.output_txt ) );
+ fprintf( stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", cstr( params.output_vtt ) );
+ fprintf( stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", cstr( params.output_srt ) );
+ fprintf( stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", cstr( params.output_wts ) );
+ fprintf( stderr, " -ps, --print-special [%-7s] print special tokens\n", cstr( params.print_special ) );
+ fprintf( stderr, " -nc, --no-colors [%-7s] do not print colors\n", cstr( !params.print_colors ) );
+ fprintf( stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", cstr( params.no_timestamps ) );
+ fprintf( stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str() );
+ fprintf( stderr, " -m FNAME, --model FNAME [%-7S] model path\n", params.model.c_str() );
+ fprintf( stderr, " -f FNAME, --file FNAME [%-7s] path of the input audio file\n", "" );
+ fprintf( stderr, "\n" );
+}
+
+bool whisper_params::parse( int argc, wchar_t* argv[] )
+{
+ for( int i = 1; i < argc; i++ )
+ {
+ std::wstring arg = argv[ i ];
+
+ if( arg[ 0 ] != '-' )
+ {
+ fname_inp.push_back( arg );
+ continue;
+ }
+
+ if( arg == L"-h" || arg == L"--help" )
+ {
+ whisper_print_usage( argc, argv, *this );
+ return false;
+ }
+
+ else if( arg == L"-t" || arg == L"--threads" ) { n_threads = std::stoul( argv[ ++i ] ); }
+ else if( arg == L"-p" || arg == L"--processors" ) { n_processors = std::stoul( argv[ ++i ] ); }
+ else if( arg == L"-ot" || arg == L"--offset-t" ) { offset_t_ms = std::stoul( argv[ ++i ] ); }
+ else if( arg == L"-on" || arg == L"--offset-n" ) { offset_n = std::stoul( argv[ ++i ] ); }
+ else if( arg == L"-d" || arg == L"--duration" ) { duration_ms = std::stoul( argv[ ++i ] ); }
+ else if( arg == L"-mc" || arg == L"--max-context" ) { max_context = std::stoul( argv[ ++i ] ); }
+ else if( arg == L"-ml" || arg == L"--max-len" ) { max_len = std::stoul( argv[ ++i ] ); }
+ else if( arg == L"-wt" || arg == L"--word-thold" ) { word_thold = std::stof( argv[ ++i ] ); }
+ else if( arg == L"-su" || arg == L"--speed-up" ) { speed_up = true; }
+ else if( arg == L"-tr" || arg == L"--translate" ) { translate = true; }
+ else if( arg == L"-di" || arg == L"--diarize" ) { diarize = true; }
+ else if( arg == L"-otxt" || arg == L"--output-txt" ) { output_txt = true; }
+ else if( arg == L"-ovtt" || arg == L"--output-vtt" ) { output_vtt = true; }
+ else if( arg == L"-osrt" || arg == L"--output-srt" ) { output_srt = true; }
+ else if( arg == L"-owts" || arg == L"--output-words" ) { output_wts = true; }
+ else if( arg == L"-ps" || arg == L"--print-special" ) { print_special = true; }
+ else if( arg == L"-nc" || arg == L"--no-colors" ) { print_colors = false; }
+ else if( arg == L"-nt" || arg == L"--no-timestamps" ) { no_timestamps = true; }
+ else if( arg == L"-l" || arg == L"--language" ) { language = utf8( argv[ ++i ] ); }
+ else if( arg == L"-m" || arg == L"--model" ) { model = argv[ ++i ]; }
+ else if( arg == L"-f" || arg == L"--file" ) { fname_inp.push_back( argv[ ++i ] ); }
+ else
+ {
+ fprintf( stderr, "error: unknown argument: %S\n", arg.c_str() );
+ whisper_print_usage( argc, argv, *this );
+ return false;
+ }
+ }
+ return true;
+} \ No newline at end of file
diff --git a/Examples/main/params.h b/Examples/main/params.h
new file mode 100644
index 0000000..9eb2b04
--- /dev/null
+++ b/Examples/main/params.h
@@ -0,0 +1,38 @@
+#pragma once
+#include <vector>
+#include <string>
+
+// command-line parameters
+struct whisper_params
+{
+ uint32_t n_threads;
+ uint32_t n_processors = 1;
+ uint32_t offset_t_ms = 0;
+ uint32_t offset_n = 0;
+ uint32_t duration_ms = 0;
+ uint32_t max_context = UINT_MAX;
+ uint32_t max_len = 0;
+
+ float word_thold = 0.01f;
+
+ bool speed_up = false;
+ bool translate = false;
+ bool diarize = false;
+ bool output_txt = false;
+ bool output_vtt = false;
+ bool output_srt = false;
+ bool output_wts = false;
+ bool print_special = false;
+ bool print_colors = true;
+ bool no_timestamps = false;
+
+ std::string language = "en";
+ std::wstring model = L"models/ggml-base.en.bin";
+ std::vector<std::wstring> fname_inp;
+
+ whisper_params();
+
+ bool parse( int argc, wchar_t* argv[] );
+};
+
+void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params ); \ No newline at end of file
diff --git a/Tools/CompressShaders/Cabinet.cs b/Tools/CompressShaders/Cabinet.cs
new file mode 100644
index 0000000..b53fd18
--- /dev/null
+++ b/Tools/CompressShaders/Cabinet.cs
@@ -0,0 +1,60 @@
+using System.ComponentModel;
+using System.Runtime.InteropServices;
+
+namespace CompressShaders
+{
+ /// <summary>Lossless data compressor implemented by <c>Cabinet.dll</c> Windows component</summary>
+ /// <remarks>
+ /// <para>Whisper.dll consumes that component in runtime, to decompress these shader binaries</para>
+ /// <para>If you wonder why not gzip — because the OS doesn’t include an API for that, at least not an API usable from C or C++.<br/>
+ /// .NET standard library includes gzip algorithm, but we don't want Whisper.dll to depend on .NET.</para>
+ /// </remarks>
+ static class Cabinet
+ {
+ /// <summary>Compression algorithm</summary>
+ /// <seealso href="https://learn.microsoft.com/en-us/windows/win32/cmpapi/using-the-compression-api#selecting-the-compression-algorithm" />
+ enum eCompressionAlgorithm: uint
+ {
+ MSZIP = 2,
+ XPRESS = 3,
+ XPRESS_HUFF = 4,
+ LZMS = 5,
+ }
+ /// <summary>The value should match <c>constexpr DWORD compressionAlgorithm</c> constant,<br/>in <c>Whisper/D3D/shaders.cpp</c> source file</summary>
+ const eCompressionAlgorithm algo = eCompressionAlgorithm.MSZIP;
+
+ [DllImport( "Cabinet.dll", SetLastError = true )]
+ static extern bool CreateCompressor( eCompressionAlgorithm Algorithm, IntPtr AllocationRoutines, out IntPtr CompressorHandle );
+
+ [DllImport( "Cabinet.dll", SetLastError = true )]
+ static extern bool CloseCompressor( IntPtr CompressorHandle );
+
+ [DllImport( "Cabinet.dll", SetLastError = true )]
+ static extern bool Compress( IntPtr CompressorHandle, [In] byte[] UncompressedData, IntPtr UncompressedDataSize, [Out] byte[] CompressedBuffer, IntPtr CompressedBufferSize, out IntPtr CompressedDataSize );
+
+ /// <summary>Compress an array of bytes into another, smaller array of bytes</summary>
+ /// <remarks>In practice, the compression ratio is about 7.1 for the shader binaries in Release configuration.</remarks>
+ public static byte[] compressBuffer( byte[] src )
+ {
+ if( src.Length <= 0 )
+ throw new ArgumentException( "The source buffer is empty" );
+ IntPtr hCompressor;
+ if( !CreateCompressor( algo, IntPtr.Zero, out hCompressor ) )
+ throw new Win32Exception( "Unable to create the compressor" );
+ try
+ {
+ byte[] dest = new byte[ src.Length * 2 ];
+ IntPtr srcSize = new IntPtr( src.Length );
+ IntPtr destSize = new IntPtr( src.Length * 2 );
+ if( !Compress( hCompressor, src, srcSize, dest, destSize, out destSize ) )
+ throw new Win32Exception( "Compress failed" );
+ Array.Resize( ref dest, (int)destSize );
+ return dest;
+ }
+ finally
+ {
+ CloseCompressor( hCompressor );
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/CompressShaders/CompressShaders.cs b/Tools/CompressShaders/CompressShaders.cs
new file mode 100644
index 0000000..814f966
--- /dev/null
+++ b/Tools/CompressShaders/CompressShaders.cs
@@ -0,0 +1,244 @@
+using System.Runtime.CompilerServices;
+namespace CompressShaders;
+
+record struct sShaderBinary
+{
+ public string name;
+ public byte[] data;
+
+ public sShaderBinary( string path )
+ {
+ name = Path.GetFileNameWithoutExtension( path );
+ data = File.ReadAllBytes( path );
+ }
+
+ public bool wave64 => name.EndsWith( "64" );
+ public string uniqueName => wave64 ? name.Substring( 0, name.Length - 2 ) : name;
+}
+
+sealed class FoundShaders
+{
+ public readonly sShaderBinary[] binaries;
+ public readonly string[] names;
+ public readonly int[] wave32, wave64;
+
+ public FoundShaders( IEnumerable<sShaderBinary> found )
+ {
+ binaries = found
+ .OrderBy( b => b.name )
+ .ToArray();
+
+ names = binaries
+ .Select( b => b.uniqueName )
+ .Distinct()
+ .ToArray();
+
+ wave32 = new int[ names.Length ];
+ wave64 = new int[ names.Length ];
+ for( int i = 0; i < names.Length; i++ )
+ {
+ int i32 = findIndex( names[ i ], false );
+ int i64 = findIndex( names[ i ], true );
+ if( i32 >= 0 && i64 >= 0 )
+ {
+ wave32[ i ] = i32;
+ wave64[ i ] = i64;
+ continue;
+ }
+ if( i32 >= 0 )
+ {
+ wave32[ i ] = wave64[ i ] = i32;
+ continue;
+ }
+ throw new ApplicationException( $"Wave64 shader {names[ i ]} doesn't have the corresponding Wave32 one" );
+ }
+ }
+
+ int findIndex( string name, bool wave64 )
+ {
+ for( int i = 0; i < binaries.Length; i++ )
+ {
+ sShaderBinary sb = binaries[ i ];
+ if( sb.uniqueName != name )
+ continue;
+ if( sb.wave64 == wave64 )
+ return i;
+ }
+ return -1;
+ }
+}
+
+class Program
+{
+ static string getSolutionRoot( [CallerFilePath] string? path = null )
+ {
+ string? dir = Path.GetDirectoryName( path );
+ dir = Path.GetDirectoryName( dir );
+ dir = Path.GetDirectoryName( dir );
+ return dir ?? throw new ApplicationException();
+ }
+
+#if DEBUG
+ const string config = "Debug";
+#else
+ const string config = "Release";
+#endif
+
+ static string shadersBinDir( string root )
+ {
+ return Path.Combine( root, "ComputeShaders", "x64", config );
+ }
+
+ static IEnumerable<sShaderBinary> readShaders( string root )
+ {
+ string dir = shadersBinDir( root );
+ foreach( string path in Directory.EnumerateFiles( dir, "*.cso" ) )
+ yield return new sShaderBinary( path );
+ }
+
+ static void writeHeader( string root, IEnumerable<string> names )
+ {
+ string path = Path.Combine( root, "Whisper", "D3D", "shaderNames.h" );
+ using var stream = File.CreateText( path );
+ stream.WriteLine( @"// This header is generated by a tool
+#pragma once
+#include <stdint.h>
+
+namespace DirectCompute
+{
+ enum struct eComputeShader: uint16_t
+ {" );
+
+ int id = 0;
+ foreach( string name in names )
+ {
+ stream.WriteLine( "\t\t{0} = {1},", name, id );
+ id++;
+ }
+ stream.Write( @" };
+
+ const char* computeShaderName( eComputeShader cs );
+}" );
+ }
+
+ static void writeCpp( string root, IEnumerable<string> names )
+ {
+ string path = Path.Combine( root, "Whisper", "D3D", "shaderNames.cpp" );
+ ShaderNames.write( path, names );
+ }
+
+ static void writePayloadIDs( StreamWriter stream, string varName, int[] ids )
+ {
+ stream.Write( @"
+static const std::array<uint8_t, {0}> {1} = {{", ids.Length, varName );
+
+ for( int i = 0; i < ids.Length; i++ )
+ {
+ if( 0 == i % 16 )
+ stream.Write( "\r\n\t" );
+ else
+ stream.Write( ' ' );
+ stream.Write( "{0},", ids[ i ] );
+ }
+ stream.Write( @"
+};" );
+ }
+
+ static void writePayload( string root, FoundShaders shaders, out int cbSource, out int cbCompressed )
+ {
+ MemoryStream ms = new MemoryStream();
+ List<int> offsets = new List<int>();
+ foreach( var bin in shaders.binaries )
+ {
+ offsets.Add( (int)ms.Length );
+ ms.Write( bin.data );
+ }
+ offsets.Add( (int)ms.Length );
+
+ byte[] dxbc = ms.ToArray();
+ byte[] compressed = Cabinet.compressBuffer( dxbc );
+ cbSource = dxbc.Length;
+ cbCompressed = compressed.Length;
+
+ string path = Path.Combine( root, "Whisper", "D3D", $"shaderData-{config}.inl" );
+ using var stream = File.CreateText( path );
+ stream.Write( @"// This source file is generated by a tool
+
+// This array contains concatenated and compressed DXBC binaries for all compiled compute shaders
+static const std::array<uint8_t, {0}> s_compressedShaders =
+{{", compressed.Length );
+
+ for( int i = 0; i < compressed.Length; i++ )
+ {
+ if( 0 == i % 16 )
+ stream.Write( "\r\n\t" );
+ else
+ stream.Write( ' ' );
+ stream.Write( "0x{0:X02},", compressed[ i ] );
+ }
+
+ stream.Write( @"
+}};
+
+// This array contains start offsets of shader binaries in the decompressed DXBC blob.
+// It includes one more entry for the end of the complete decompressed blob.
+static const std::array<uint32_t, {0}> s_shaderOffsets = {{", offsets.Count );
+
+ for( int i = 0; i < offsets.Count; i++ )
+ {
+ if( 0 == i % 16 )
+ stream.Write( "\r\n\t" );
+ else
+ stream.Write( ' ' );
+ stream.Write( "{0},", offsets[ i ] );
+ }
+ stream.Write( @"
+};" );
+
+ stream.Write( @"
+// Index = eComputeShader enum value, value = index of the shader binary to use on nVidia and Intel GPUs" );
+ writePayloadIDs( stream, "s_shaderBlobs32", shaders.wave32 );
+ stream.Write( @"
+// Index = eComputeShader enum value, value = index of the shader binary to use on AMD GPUs" );
+ writePayloadIDs( stream, "s_shaderBlobs64", shaders.wave64 );
+
+ ulong fp64Flags = 0;
+ for( int i = 0; i < shaders.binaries.Length; i++ )
+ {
+ bool fp64 = DetectFp64.usesFp64( shaders.binaries[ i ].data );
+ if( fp64 )
+ fp64Flags |= (ulong)1 << i;
+ }
+
+ stream.Write( @"
+// Bitmap of the shader binaries which use FP64 arithmetic instructions
+constexpr uint64_t fp64ShadersBitmap = 0x{0:X}ull;", fp64Flags );
+ }
+
+ static void mainImpl()
+ {
+ string root = getSolutionRoot();
+ LanguageCodes.produce( root );
+
+ FoundShaders shaders = new FoundShaders( readShaders( root ) );
+
+ writeHeader( root, shaders.names );
+ writeCpp( root, shaders.names );
+ writePayload( root, shaders, out int cbIn, out int cbOut );
+ Console.WriteLine( "Compressed {0} compute shaders, {1:F1} kb -> {2:F1} kb", shaders.binaries.Length, cbIn / 1024.0, cbOut / 1024.0 );
+ }
+
+ static int Main( string[] args )
+ {
+ try
+ {
+ mainImpl();
+ return 0;
+ }
+ catch( Exception ex )
+ {
+ Console.WriteLine( ex.Message );
+ return ex.HResult;
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/CompressShaders/CompressShaders.csproj b/Tools/CompressShaders/CompressShaders.csproj
new file mode 100644
index 0000000..dee1710
--- /dev/null
+++ b/Tools/CompressShaders/CompressShaders.csproj
@@ -0,0 +1,10 @@
+<Project Sdk="Microsoft.NET.Sdk">
+ <PropertyGroup>
+ <OutputType>Exe</OutputType>
+ <TargetFramework>net6.0</TargetFramework>
+ <ImplicitUsings>enable</ImplicitUsings>
+ <Nullable>enable</Nullable>
+ <CheckForOverflowUnderflow>true</CheckForOverflowUnderflow>
+ <AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath>
+ </PropertyGroup>
+</Project> \ No newline at end of file
diff --git a/Tools/CompressShaders/DetectFp64.cs b/Tools/CompressShaders/DetectFp64.cs
new file mode 100644
index 0000000..1d75126
--- /dev/null
+++ b/Tools/CompressShaders/DetectFp64.cs
@@ -0,0 +1,43 @@
+#pragma warning disable CS0649
+using System.Runtime.InteropServices;
+
+namespace CompressShaders
+{
+ static class DetectFp64
+ {
+ struct DXBCHeader
+ {
+ public uint FourCC; // Four character code "DXBC"
+ public uint Hash0; // 32-bit hash of the DXBC file
+ public uint Hash1; // 32-bit hash of the DXBC file
+ public uint Hash2; // 32-bit hash of the DXBC file
+ public uint Hash3; // 32-bit hash of the DXBC file
+ public uint unknownOne;
+ public uint TotalSize; // Total size of the DXBC file in bytes
+ public int NumChunks; // Number of chunks in the DXBC file
+ };
+
+ public static bool usesFp64( ReadOnlySpan<byte> dxbc )
+ {
+ ReadOnlySpan<DXBCHeader> dxbcHeaderSpan = MemoryMarshal.Cast<byte, DXBCHeader>( dxbc );
+ DXBCHeader dxbcHeader = dxbcHeaderSpan[ 0 ];
+
+ int cbHeader = Marshal.SizeOf<DXBCHeader>();
+ int nChunks = dxbcHeader.NumChunks;
+ ReadOnlySpan<int> chunkOffsets = MemoryMarshal.Cast<byte, int>( dxbc.Slice( cbHeader, nChunks * 4 ) );
+ foreach( int off in chunkOffsets )
+ {
+ uint id = MemoryMarshal.Cast<byte, uint>( dxbc.Slice( off, 4 ) )[ 0 ];
+ const uint SFI0 = 0x30494653;
+ if( id != SFI0 )
+ continue;
+ int size = MemoryMarshal.Cast<byte, int>( dxbc.Slice( off + 4, 4 ) )[ 0 ];
+ if( size < 4 )
+ throw new ApplicationException();
+ uint data = MemoryMarshal.Cast<byte, uint>( dxbc.Slice( off + 8, 4 ) )[ 0 ];
+ return 0 != ( data & 1u );
+ }
+ return false;
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/CompressShaders/LanguageCodes.cs b/Tools/CompressShaders/LanguageCodes.cs
new file mode 100644
index 0000000..71a9909
--- /dev/null
+++ b/Tools/CompressShaders/LanguageCodes.cs
@@ -0,0 +1,103 @@
+using System.Globalization;
+using System.Text.RegularExpressions;
+
+namespace CompressShaders
+{
+ static class LanguageCodes
+ {
+ record struct Row
+ {
+ public string keySource;
+ public uint keyValue;
+ public int code;
+ public string name;
+ }
+
+ static uint makeKey( string str )
+ {
+ if( str.Length > 4 )
+ throw new ArgumentException();
+ uint k = 0;
+ int shift = 0;
+ foreach( char c in str )
+ {
+ if( c >= 0x80 )
+ throw new ArgumentException();
+ uint u = (uint)c;
+ k |= ( u << shift );
+ shift += 8;
+ }
+ return k;
+ }
+
+ static IEnumerable<Row> load( string path )
+ {
+ using var stm = File.OpenText( path );
+ while( true )
+ {
+ string? line = stm.ReadLine();
+ if( null == line )
+ break;
+ if( string.IsNullOrWhiteSpace( line ) )
+ continue;
+ string[] fields = line.Split( '\t' );
+ yield return new Row()
+ {
+ keySource = fields[ 0 ],
+ keyValue = makeKey( fields[ 0 ] ),
+ code = int.Parse( fields[ 1 ] ),
+ name = fields[ 2 ]
+ };
+ }
+ }
+
+ static void writeCpp( string inl, Row[] data )
+ {
+ // TODO [very low]: sort them by the key here, then in C++ use binary search instead of the hash map
+ using var stm = File.CreateText( inl );
+ stm.WriteLine( "// This file is generated by a tool, from the `languageCodez.tsv` file in this repository" );
+ foreach( Row row in data )
+ stm.WriteLine( "Lang{{ 0x{0:X}, {1}, \"{2}\" }},", row.keyValue, row.code, row.name );
+ }
+
+ static readonly CultureInfo ci = new CultureInfo( "en-US", false );
+ static string titleCase( this string name ) =>
+ ci.TextInfo.ToTitleCase( name.ToLower( ci ) );
+
+ static void writeCs( string cs, Row[] data )
+ {
+ using var stm = File.CreateText( cs );
+ stm.WriteLine( @"// This file is generated by a tool, from the `languageCodez.tsv` file in this repository
+namespace Whisper
+{
+ /// <summary>Supported languages</summary>
+ public enum eLanguage: uint
+ {" );
+
+ foreach( Row row in data )
+ {
+ string tc = row.name.titleCase();
+ stm.WriteLine( " /// <summary>{0}</summary>", tc );
+ tc = Regex.Replace( tc, @"\s+", string.Empty );
+ stm.WriteLine( " {0} = 0x{1:X},", tc, row.keyValue );
+ }
+ stm.Write( @" }
+}" );
+ }
+
+ static void produce( string tsv, string inl, string cs )
+ {
+ Row[] data = load( tsv ).OrderBy( r => r.name ).ToArray();
+ writeCpp( inl, data );
+ writeCs( cs, data );
+ }
+
+ public static void produce( string solutionRoot )
+ {
+ string tsv = Path.Combine( solutionRoot, "Whisper\\Whisper\\languageCodez.tsv" );
+ string inl = Path.Combine( solutionRoot, "Whisper\\Whisper\\languageCodez.inl" );
+ string cs = Path.Combine( solutionRoot, "WhisperNet\\API\\eLanguage.cs" );
+ produce( tsv, inl, cs );
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/CompressShaders/Readme.txt b/Tools/CompressShaders/Readme.txt
new file mode 100644
index 0000000..69ef35a
--- /dev/null
+++ b/Tools/CompressShaders/Readme.txt
@@ -0,0 +1,10 @@
+This project builds a C# console app which serves as a code generator for a few pieces of Whisper.dll and WhisperNet.dll.
+
+Specifically, it generates two things.
+
+1. It compresses the compiled DXBC shaders into a blob of bytes, and prints std::array with these bytes into shaderData-Release.inl and shaderData-Debug.inl C++ files.
+
+2. It parses the `languageCodez.tsv`, and generates both C++ and C# code with the data from that table.
+
+The tool uses relative paths across source files.
+These paths will break if you move the source of the tool, or the source data of the tool. \ No newline at end of file
diff --git a/Tools/CompressShaders/ShaderNames.cs b/Tools/CompressShaders/ShaderNames.cs
new file mode 100644
index 0000000..81ba46e
--- /dev/null
+++ b/Tools/CompressShaders/ShaderNames.cs
@@ -0,0 +1,27 @@
+static class ShaderNames
+{
+ public static void write( string path, IEnumerable<string> names )
+ {
+ string[] arr = names.ToArray();
+ using var stream = File.CreateText( path );
+ stream.WriteLine( @"// This source file is generated by a tool
+#include ""stdafx.h""
+#include ""shaderNames.h""
+" );
+
+ stream.WriteLine( "static const std::array<const char*, {0}> s_shaderNames = ", arr.Length );
+ stream.WriteLine( "{" );
+ foreach( string name in arr )
+ stream.WriteLine( @" ""{0}"",", name );
+
+ stream.Write( @"};
+
+const char* DirectCompute::computeShaderName( eComputeShader cs )
+{
+ const uint16_t i = (uint16_t)cs;
+ if( i < s_shaderNames.size() )
+ return s_shaderNames[ i ];
+ return nullptr;
+}" );
+ }
+} \ No newline at end of file
diff --git a/Tools/compareTraces/CommandLineArgs.cpp b/Tools/compareTraces/CommandLineArgs.cpp
new file mode 100644
index 0000000..5f26bdb
--- /dev/null
+++ b/Tools/compareTraces/CommandLineArgs.cpp
@@ -0,0 +1,51 @@
+#include "stdafx.h"
+#include "CommandLineArgs.h"
+#include <charconv>
+
+static bool printUsage()
+{
+ fprintf( stderr, "Usage: compareTraces.exe trace1.bin trace2.bin [-diff N]\n" );
+ return false;
+}
+
+bool CommandLineArgs::parse( int argc, wchar_t* argv[] )
+{
+ size_t idx = 0;
+ CString sw;
+ CStringA tmp;
+ for( int i = 1; i < argc; i++ )
+ {
+ if( argv[ i ][ 0 ] != L'-' )
+ {
+ if( idx >= 2 )
+ return printUsage();
+ inputs[ idx ] = argv[ i ];
+ idx++;
+ continue;
+ }
+ sw = argv[ i ];
+ if( 0 == sw.CompareNoCase( L"-diff" ) )
+ {
+ i++;
+ if( i >= argc )
+ return printUsage();
+ tmp.Format( "%S", argv[ i ] );
+ tmp.Trim();
+ uint64_t v;
+ auto res = std::from_chars( tmp, cstr( tmp ) + tmp.GetLength(), v );
+ if( res.ec != (std::errc)0 )
+ {
+ fprintf( stderr, "Unable to parse string into number\n" );
+ return false;
+ }
+ printDiff = v;
+ continue;
+ }
+ return printUsage();
+ }
+
+ if( idx != 2 )
+ return printUsage();
+
+ return true;
+} \ No newline at end of file
diff --git a/Tools/compareTraces/CommandLineArgs.h b/Tools/compareTraces/CommandLineArgs.h
new file mode 100644
index 0000000..d434e76
--- /dev/null
+++ b/Tools/compareTraces/CommandLineArgs.h
@@ -0,0 +1,9 @@
+#pragma once
+
+struct CommandLineArgs
+{
+ int64_t printDiff = -1;
+ std::array<CString, 2> inputs;
+
+ bool parse( int argc, wchar_t* argv[] );
+}; \ No newline at end of file
diff --git a/Tools/compareTraces/Readme.txt b/Tools/compareTraces/Readme.txt
new file mode 100644
index 0000000..035d658
--- /dev/null
+++ b/Tools/compareTraces/Readme.txt
@@ -0,0 +1,9 @@
+This project builds a C++ console tool which compares debug traces of the model.
+
+Tracing files easily exceed 1GB, and by default they’re disabled with a preprocessor macro in stdafx.h of the Whisper project.
+
+When enabled, the main GPU implementation saves a trace into C:\Temp\2remove\Whisper\gpu.bin
+
+The reference CPU implementation saves a trace into C:\Temp\2remove\Whisper\ref.bin
+
+This code in this project is optimized for development speed. For this reason it requires AVX2 CPU, uses memory-mapped IO instead of proper parsing, and checks little to no errors. \ No newline at end of file
diff --git a/Tools/compareTraces/TraceReader.cpp b/Tools/compareTraces/TraceReader.cpp
new file mode 100644
index 0000000..b4b9681
--- /dev/null
+++ b/Tools/compareTraces/TraceReader.cpp
@@ -0,0 +1,46 @@
+#include "stdafx.h"
+#include "TraceReader.h"
+using namespace Tracing;
+
+const sTraceItem& TraceReader::operator[]( size_t idx ) const
+{
+ if( idx >= countItems )
+ throw E_BOUNDS;
+ return items[ idx ];
+}
+
+CStringA TraceReader::getName( const sTraceItem& item ) const
+{
+ const size_t idx = item.stringIndex;
+ if( idx >= countStrings )
+ throw E_BOUNDS;
+ const char* const source = stringData + stringIndex[ idx ];
+ CStringA res;
+ res.Format( source, item.formatArgs[ 0 ], item.formatArgs[ 1 ], item.formatArgs[ 2 ], item.formatArgs[ 3 ] );
+ return res;
+}
+
+HRESULT TraceReader::open( LPCTSTR path )
+{
+ CHECK( file.Create( path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING ) );
+ CHECK( mapping.MapFile( file ) );
+
+ const uint8_t* rsi = mapping;
+ const sFileHeader& header = *(const sFileHeader*)rsi;
+ if( header.magic != header.correctMagic )
+ return E_INVALIDARG;
+ countItems = header.countItems;
+ countStrings = header.countStrings;
+
+ rsi += sizeof( sFileHeader );
+ payloadPointer = rsi;
+
+ rsi += header.bytesPayload;
+ stringIndex = (const uint32_t*)( rsi );
+ stringData = (const char*)( rsi + countStrings * 4 );
+
+ rsi += header.bytesStrings;
+ items = (const sTraceItem*)rsi;
+
+ return S_OK;
+} \ No newline at end of file
diff --git a/Tools/compareTraces/TraceReader.h b/Tools/compareTraces/TraceReader.h
new file mode 100644
index 0000000..8d1e1f2
--- /dev/null
+++ b/Tools/compareTraces/TraceReader.h
@@ -0,0 +1,35 @@
+#pragma once
+#include "../../Whisper/Utils/Trace/TraceStructures.h"
+#include <atlstr.h>
+#include <atlfile.h>
+
+namespace Tracing
+{
+ class TraceReader
+ {
+ const uint8_t* payloadPointer = nullptr;
+ const sTraceItem* items = nullptr;
+ size_t countItems = 0;
+ size_t countStrings = 0;
+ const uint32_t* stringIndex = nullptr;
+ const char* stringData = nullptr;
+
+ CAtlFile file;
+ CAtlFileMapping<uint8_t> mapping;
+
+ public:
+
+ TraceReader() = default;
+ ~TraceReader() = default;
+
+ HRESULT open( LPCTSTR path );
+ size_t size() const { return countItems; }
+ const sTraceItem& operator[]( size_t idx ) const;
+ CStringA getName( const sTraceItem& item ) const;
+
+ const void* payload( const sTraceItem& item ) const
+ {
+ return payloadPointer + item.payloadOffset;
+ }
+ };
+} \ No newline at end of file
diff --git a/Tools/compareTraces/compare.cpp b/Tools/compareTraces/compare.cpp
new file mode 100644
index 0000000..ec2a6ef
--- /dev/null
+++ b/Tools/compareTraces/compare.cpp
@@ -0,0 +1,364 @@
+#include "stdafx.h"
+#include "../../Whisper/API/iContext.cl.h"
+#include "TraceReader.h"
+#include "../../Whisper/ML/testUtils.h"
+#include "compare.h"
+using namespace Tracing;
+using namespace DirectCompute;
+
+namespace
+{
+ inline const char* cstr( eItemType it )
+ {
+ switch( it )
+ {
+ case eItemType::Buffer: return "Buffer";
+ case eItemType::Tensor: return "Tensor";
+ }
+ throw E_INVALIDARG;
+ }
+ inline const char* cstr( const CStringA& s ) { return s; }
+
+ inline int tensorDims( __m128i vec )
+ {
+ const __m128i one = _mm_set1_epi32( 1 );
+ const uint32_t bitmapOnes = (uint32_t)_mm_movemask_ps( _mm_castsi128_ps( _mm_cmpeq_epi32( vec, one ) ) );
+ const uint32_t bitmapNotOnes = bitmapOnes ^ 0b1111u;
+ unsigned long idx;
+ if( !_BitScanReverse( &idx, bitmapNotOnes ) )
+ return 0;
+ return idx + 1;
+ }
+
+ int printSize( __m128i vec )
+ {
+ const int sz = tensorDims( vec );
+ switch( sz )
+ {
+ case 0:
+ printf( "[ scalar ]" );
+ break;
+ case 1:
+ printf( "[ %i ]", _mm_cvtsi128_si32( vec ) );
+ break;
+ case 2:
+ printf( "[ %i, %i ]", _mm_cvtsi128_si32( vec ), _mm_extract_epi32( vec, 1 ) );
+ break;
+ case 3:
+ printf( "[ %i, %i, %i ]", _mm_cvtsi128_si32( vec ), _mm_extract_epi32( vec, 1 ), _mm_extract_epi32( vec, 2 ) );
+ break;
+ case 4:
+ printf( "[ %i, %i, %i, %i ]", _mm_cvtsi128_si32( vec ), _mm_extract_epi32( vec, 1 ), _mm_extract_epi32( vec, 2 ), _mm_extract_epi32( vec, 3 ) );
+ break;
+ default:
+ throw E_UNEXPECTED;
+ }
+ return sz;
+ }
+
+ class Comparer
+ {
+ TraceReader& readerA;
+ TraceReader& readerB;
+
+ bool diffBuffers( size_t i, const sTraceItem& a, const sTraceItem& b, const CStringA& name )
+ {
+ const size_t lenA = *(const uint64_t*)a.size.data();
+ const size_t lenB = *(const uint64_t*)b.size.data();
+ if( lenA != lenB )
+ {
+ printf( "Buffer %zu \"%s\": different size, %zu in trace A, %zu in trace B\n", i, cstr( name ), lenA, lenB );
+ return false;
+ }
+ if( a.dataType != b.dataType )
+ {
+ printf( "Buffer %zu \"%s\": different data types\n", i, cstr( name ) );
+ return false;
+ }
+
+ switch( a.dataType )
+ {
+ case eDataType::FP32:
+ return buffersFp32( i, name, (const float*)readerA.payload( a ), (const float*)readerB.payload( b ), lenA );
+ }
+ throw E_NOTIMPL;
+ }
+
+ bool diffTensors( size_t i, const sTraceItem& a, const sTraceItem& b, const CStringA& name )
+ {
+ const __m128i ne1 = load( a.size );
+ const __m128i ne2 = load( b.size );
+ if( !vectorEqual( ne1, ne2 ) )
+ {
+ printf( "Tensor %zu \"%s\" - different size: trace A size is ", i, cstr( name ) );
+ printSize( ne1 );
+ printf( ", trace B size is " );
+ printSize( ne2 );
+ printf( "\n" );
+ return false;
+ }
+
+ const __m128i stride1 = load( a.stride );
+ const __m128i stride2 = load( b.stride );
+ if( !vectorEqual( stride1, stride2 ) )
+ {
+ printf( "Tensor %zu \"%s\" - different memory layout\n", i, cstr( name ) );
+ return false;
+ }
+
+ if( a.dataType != b.dataType )
+ {
+ printf( "Tensor %zu \"%s\": different data types\n", i, cstr( name ) );
+ return false;
+ }
+
+ size_t elements = (uint32_t)_mm_cvtsi128_si32( ne1 );
+ elements *= (uint32_t)_mm_extract_epi32( ne1, 1 );
+ elements *= (uint32_t)_mm_extract_epi32( ne1, 2 );
+ elements *= (uint32_t)_mm_extract_epi32( ne1, 3 );
+
+ switch( a.dataType )
+ {
+ case eDataType::FP32:
+ return tensorsFp32( i, name, (const float*)readerA.payload( a ), (const float*)readerB.payload( b ), elements, ne1, stride1 );
+ }
+ throw E_NOTIMPL;
+ }
+
+ protected:
+ virtual bool buffersFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length ) = 0;
+ virtual bool tensorsFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length, __m128i ne, __m128i nb ) = 0;
+
+ public:
+
+ Comparer( TraceReader& t1, TraceReader& t2 ) :
+ readerA( t1 ), readerB( t2 ) { }
+
+ bool compare( size_t i )
+ {
+ const sTraceItem& a = readerA[ i ];
+ const sTraceItem& b = readerB[ i ];
+ CStringA name1 = readerA.getName( a );
+ CStringA name2 = readerB.getName( b );
+
+ if( a.itemType != b.itemType )
+ {
+ printf( "Item %zu: different type, trace A %s \"%s\", trace B %s \"%s\"\n", i,
+ cstr( a.itemType ), cstr( name1 ), cstr( b.itemType ), cstr( name2 ) );
+ return false;
+ }
+
+ if( name1 != name2 )
+ {
+ printf( "%s %zu: different names, they are \"%s\" and \"%s\"\n", cstr( a.itemType ), i, cstr( name1 ), cstr( name2 ) );
+ return false;
+ }
+
+ switch( a.itemType )
+ {
+ case eItemType::Buffer:
+ return diffBuffers( i, a, b, name1 );
+ case eItemType::Tensor:
+ return diffTensors( i, a, b, name1 );
+ default:
+ throw E_INVALIDARG;
+ }
+ }
+ };
+
+ class PrintSummary : public Comparer
+ {
+ bool buffersFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length ) override;
+ bool tensorsFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length, __m128i ne, __m128i nb ) override;
+
+ public:
+ PrintSummary( TraceReader& a, TraceReader& b ) : Comparer( a, b ) { }
+ };
+
+ bool PrintSummary::buffersFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length )
+ {
+ sTensorDiff diff = computeDiff( a, b, length );
+ printf( "%s %zu \"%s\": ", cstr( eItemType::Buffer ), idx, cstr( name ) );
+ diff.print();
+ return true;
+ }
+
+ bool PrintSummary::tensorsFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length, __m128i ne, __m128i nb )
+ {
+ printSize( ne );
+ printf( " " );
+ sTensorDiff diff = computeDiff( a, b, length );
+ printf( "%s %zu \"%s\": ", cstr( eItemType::Tensor ), idx, cstr( name ) );
+ diff.print();
+ return true;
+ }
+
+ class PrintDiff : public Comparer
+ {
+ bool buffersFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length ) override;
+ bool tensorsFp32( size_t idx, const CStringA& name, const float* a, const float* b, size_t length, __m128i ne, __m128i nb ) override;
+ public:
+ PrintDiff( TraceReader& a, TraceReader& b ) : Comparer( a, b ) { }
+ };
+
+ bool PrintDiff::buffersFp32( size_t idx, const CStringA& name, const float* A, const float* B, size_t length )
+ {
+ printf( "idx\tA\tB\tA(hex)\tB(hex)\tdiff\n" );
+ for( size_t i = 0; i < length; i++ )
+ {
+ const float a = *A;
+ const float b = *B;
+ __m128 vf = _mm_setr_ps( a, b, 0, 0 );
+ __m128i vi = _mm_castps_si128( vf );
+ const float diff = std::abs( a - b );
+ printf( "%zu\t%g\t%g\t0x%08X\t0x%08X\t%g\n",
+ i, a, b, _mm_cvtsi128_si32( vi ), _mm_extract_epi32( vi, 1 ), diff );
+ }
+ return true;
+ }
+
+ std::array<uint32_t, 4> storeSize( __m128i v )
+ {
+ std::array<uint32_t, 4> a;
+ _mm_storeu_si128( ( __m128i* )a.data(), v );
+ return a;
+ }
+
+ std::array<size_t, 4> storeStrides( __m128i v )
+ {
+ const __m128i zero = _mm_setzero_si128();
+ std::array<size_t, 4> a;
+ _mm_storeu_si128( ( __m128i* ) & a[ 0 ], _mm_unpacklo_epi32( v, zero ) );
+ _mm_storeu_si128( ( __m128i* ) & a[ 2 ], _mm_unpackhi_epi32( v, zero ) );
+ return a;
+ }
+
+ bool PrintDiff::tensorsFp32( size_t idx, const CStringA& name, const float* A, const float* B, size_t length, __m128i ne, __m128i nb )
+ {
+ const int dims = tensorDims( ne );
+ const std::array<uint32_t, 4> size = storeSize( ne );
+ const std::array<size_t, 4> strides = storeStrides( ne );
+ CStringA line;
+ if( dims > 4 )
+ throw E_UNEXPECTED;
+
+ for( int i = 0; i < dims; i++ )
+ {
+ const char c = "xyzw"[ i ];
+ line.AppendChar( c );
+ line.AppendChar( '\t' );
+ }
+ line += "A\tB\tA(hex)\tB(hex)\tdiff\n";
+ printf( "%s", cstr( line ) );
+
+ if( 0 == dims )
+ {
+ const float a = *A;
+ const float b = *B;
+ __m128 vf = _mm_setr_ps( a, b, 0, 0 );
+ __m128i vi = _mm_castps_si128( vf );
+ const float diff = std::abs( a - b );
+ printf( "%g\t%g\t0x%08X\t0x%08X\t%g\n",
+ a, b, _mm_cvtsi128_si32( vi ), _mm_extract_epi32( vi, 1 ), diff );
+ return true;
+ }
+
+ size_t offLayer2 = 0;
+ for( uint32_t w = 0; w < size[ 3 ]; w++, offLayer2 += strides[ 3 ] )
+ {
+ size_t offLayer = offLayer2;
+ for( uint32_t z = 0; z < size[ 2 ]; z++, offLayer += strides[ 2 ] )
+ {
+ size_t offRow = offLayer;
+ for( uint32_t y = 0; y < size[ 1 ]; y++, offRow += strides[ 1 ] )
+ {
+ size_t off = offRow;
+ for( uint32_t x = 0; x < size[ 0 ]; x++, off += strides[ 0 ] )
+ {
+ line.Format( "%i\t", x );
+ if( dims > 1 )
+ line.AppendFormat( "%i\t", y );
+ if( dims > 2 )
+ line.AppendFormat( "%i\t", z );
+ if( dims > 3 )
+ line.AppendFormat( "%i\t", w );
+
+ const float a = A[ off ];
+ const float b = B[ off ];
+ __m128 vf = _mm_setr_ps( a, b, 0, 0 );
+ __m128i vi = _mm_castps_si128( vf );
+ const float diff = std::abs( a - b );
+ line.AppendFormat( "%g\t%g\t0x%08X\t0x%08X\t%g\n",
+ a, b, _mm_cvtsi128_si32( vi ), _mm_extract_epi32( vi, 1 ), diff );
+ printf( "%s", cstr( line ) );
+ }
+ }
+ }
+ }
+ return true;
+ }
+}
+
+HRESULT compareTraces( const CommandLineArgs& arguments )
+{
+ const wchar_t* pathA = arguments.inputs[ 0 ];
+ const wchar_t* pathB = arguments.inputs[ 1 ];
+
+ TraceReader a, b;
+ HRESULT hr = a.open( pathA );
+ if( FAILED( hr ) )
+ {
+ fwprintf( stderr, L"Unable to load trace A from \"%s\"", pathA );
+ printError( hr );
+ return hr;
+ }
+
+ hr = b.open( pathB );
+ if( FAILED( hr ) )
+ {
+ fwprintf( stderr, L"Unable to load trace B from \"%s\"", pathA );
+ printError( hr );
+ return hr;
+ }
+
+ wprintf( L"Trace A: %s\n", pathA );
+ wprintf( L"Trace B: %s\n", pathB );
+ const size_t sizeA = a.size();
+ const size_t sizeB = b.size();
+ const size_t count = std::min( sizeA, sizeB );
+
+ if( arguments.printDiff >= 0 )
+ {
+ if( arguments.printDiff >= (int64_t)count )
+ {
+ fprintf( stderr, "Trace A has %zu entries, trace B %zu entries; entry %zu ain't there\n",
+ sizeA, sizeB, (size_t)arguments.printDiff );
+ return E_INVALIDARG;
+ }
+ try
+ {
+ PrintDiff print{ a, b };
+ print.compare( arguments.printDiff );
+ return S_OK;
+ }
+ catch( HRESULT hr )
+ {
+ return hr;
+ }
+ }
+
+ printf( "Trace A has %zu entries, trace B %zu entries, comparing first %zu\n", sizeA, sizeB, count );
+
+ try
+ {
+ PrintSummary print{ a, b };
+ for( size_t i = 0; i < count; i++ )
+ if( !print.compare( i ) )
+ return S_FALSE;
+ return S_OK;
+ }
+ catch( HRESULT hr )
+ {
+ return hr;
+ }
+} \ No newline at end of file
diff --git a/Tools/compareTraces/compare.h b/Tools/compareTraces/compare.h
new file mode 100644
index 0000000..2a4cd86
--- /dev/null
+++ b/Tools/compareTraces/compare.h
@@ -0,0 +1,4 @@
+#pragma once
+#include "CommandLineArgs.h"
+
+HRESULT compareTraces( const CommandLineArgs& arguments ); \ No newline at end of file
diff --git a/Tools/compareTraces/compareTraces.cpp b/Tools/compareTraces/compareTraces.cpp
new file mode 100644
index 0000000..8813500
--- /dev/null
+++ b/Tools/compareTraces/compareTraces.cpp
@@ -0,0 +1,16 @@
+#include "stdafx.h"
+#include <stdio.h>
+#include "compare.h"
+#include "CommandLineArgs.h"
+
+int wmain( int argc, wchar_t* argv[] )
+{
+ CommandLineArgs cla;
+ if( !cla.parse( argc, argv ) )
+ return 1;
+
+ HRESULT hr = compareTraces( cla );
+ if( SUCCEEDED( hr ) )
+ return 0;
+ return hr;
+} \ No newline at end of file
diff --git a/Tools/compareTraces/compareTraces.vcxproj b/Tools/compareTraces/compareTraces.vcxproj
new file mode 100644
index 0000000..a9670b3
--- /dev/null
+++ b/Tools/compareTraces/compareTraces.vcxproj
@@ -0,0 +1,103 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <VCProjectVersion>16.0</VCProjectVersion>
+ <Keyword>Win32Proj</Keyword>
+ <ProjectGuid>{8478a77c-d851-4c63-9511-1770cc82d33e}</ProjectGuid>
+ <RootNamespace>compareTraces</RootNamespace>
+ <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="Shared">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ <PrecompiledHeader>Use</PrecompiledHeader>
+ <EnableEnhancedInstructionSet>AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ <PrecompiledHeader>Use</PrecompiledHeader>
+ <EnableEnhancedInstructionSet>AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemGroup>
+ <ClCompile Include="CommandLineArgs.cpp" />
+ <ClCompile Include="compareTraces.cpp" />
+ <ClCompile Include="compare.cpp" />
+ <ClCompile Include="stdafx.cpp">
+ <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
+ <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
+ </ClCompile>
+ <ClCompile Include="testUtils.cpp" />
+ <ClCompile Include="TraceReader.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="CommandLineArgs.h" />
+ <ClInclude Include="compare.h" />
+ <ClInclude Include="stdafx.h" />
+ <ClInclude Include="TraceReader.h" />
+ </ItemGroup>
+ <ItemGroup>
+ <Text Include="Readme.txt" />
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/Tools/compareTraces/compareTraces.vcxproj.filters b/Tools/compareTraces/compareTraces.vcxproj.filters
new file mode 100644
index 0000000..1d0d4c3
--- /dev/null
+++ b/Tools/compareTraces/compareTraces.vcxproj.filters
@@ -0,0 +1,20 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <ClCompile Include="compareTraces.cpp" />
+ <ClCompile Include="TraceReader.cpp" />
+ <ClCompile Include="compare.cpp" />
+ <ClCompile Include="stdafx.cpp" />
+ <ClCompile Include="testUtils.cpp" />
+ <ClCompile Include="CommandLineArgs.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="TraceReader.h" />
+ <ClInclude Include="stdafx.h" />
+ <ClInclude Include="compare.h" />
+ <ClInclude Include="CommandLineArgs.h" />
+ </ItemGroup>
+ <ItemGroup>
+ <Text Include="Readme.txt" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/Tools/compareTraces/stdafx.cpp b/Tools/compareTraces/stdafx.cpp
new file mode 100644
index 0000000..5c7f6c9
--- /dev/null
+++ b/Tools/compareTraces/stdafx.cpp
@@ -0,0 +1,30 @@
+#include "stdafx.h"
+
+namespace
+{
+ wchar_t* formatMessage( HRESULT hr )
+ {
+ wchar_t* err;
+ if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM,
+ NULL,
+ hr,
+ MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ),
+ (LPTSTR)&err,
+ 0,
+ NULL ) )
+ return err;
+ return nullptr;
+ }
+}
+
+void printError( HRESULT hr )
+{
+ const wchar_t* err = formatMessage( hr );
+ if( nullptr != err )
+ {
+ fwprintf( stderr, L"%s\n", err );
+ LocalFree( (HLOCAL)err );
+ }
+ else
+ fprintf( stderr, "Error code %i (0x%08X)\n", hr, hr );
+} \ No newline at end of file
diff --git a/Tools/compareTraces/stdafx.h b/Tools/compareTraces/stdafx.h
new file mode 100644
index 0000000..1e496f3
--- /dev/null
+++ b/Tools/compareTraces/stdafx.h
@@ -0,0 +1,40 @@
+#pragma once
+#include <stdint.h>
+#include <assert.h>
+
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <windows.h>
+#include <atlstr.h>
+#include <d3d11.h>
+
+#include <vector>
+#include <array>
+#include <emmintrin.h>
+#include <smmintrin.h>
+
+#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; }
+
+inline __m128i load16( const int* rsi )
+{
+ return _mm_loadu_si128( ( const __m128i* )rsi );
+}
+inline __m128i load16( const uint32_t* rsi )
+{
+ return _mm_loadu_si128( ( const __m128i* )rsi );
+}
+inline __m128i load( const std::array<uint32_t, 4>& arr )
+{
+ return load16( arr.data() );
+}
+
+inline bool vectorEqual( __m128i a, __m128i b )
+{
+ __m128i xx = _mm_xor_si128( a, b );
+ return (bool)_mm_testz_si128( xx, xx );
+}
+
+void printError( HRESULT hr );
+
+inline const char* cstr( const CStringA& s ) { return s; }
+inline const wchar_t* cstr( const CString& s ) { return s; } \ No newline at end of file
diff --git a/Tools/compareTraces/testUtils.cpp b/Tools/compareTraces/testUtils.cpp
new file mode 100644
index 0000000..f9fa465
--- /dev/null
+++ b/Tools/compareTraces/testUtils.cpp
@@ -0,0 +1,224 @@
+#include "stdafx.h"
+#include "../../Whisper/ML/testUtils.h"
+#include <immintrin.h>
+using namespace DirectCompute;
+
+namespace
+{
+ using DirectCompute::sTensorDiff;
+
+ __forceinline __m256 load( const float* rsi )
+ {
+ return _mm256_loadu_ps( rsi );
+ }
+
+ __forceinline __m256 load( const uint16_t* rsi )
+ {
+ const __m128i iv = _mm_load_si128( ( const __m128i* )rsi );
+ return _mm256_cvtph_ps( iv );
+ }
+
+ __forceinline void loadPartial( const uint16_t* x, const uint16_t* y, size_t count, __m256& fx, __m256& fy )
+ {
+ __m128i ix, iy;
+ switch( count )
+ {
+ case 1: // load 2 bytes
+ ix = _mm_cvtsi32_si128( *x );
+ iy = _mm_cvtsi32_si128( *y );
+ break;
+ case 2: // load 4 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ iy = _mm_cvtsi32_si128( *(const int*)y );
+ break;
+ case 3: // load 6 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ iy = _mm_cvtsi32_si128( *(const int*)y );
+ ix = _mm_insert_epi16( ix, x[ 2 ], 2 );
+ iy = _mm_insert_epi16( iy, y[ 2 ], 2 );
+ break;
+ case 4: // load 8 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ break;
+ case 5: // load 10 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ ix = _mm_insert_epi16( ix, x[ 4 ], 4 );
+ iy = _mm_insert_epi16( iy, y[ 4 ], 4 );
+ break;
+ case 6: // load 12 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 );
+ break;
+ case 7: // load 14 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 );
+ ix = _mm_insert_epi16( ix, x[ 6 ], 6 );
+ iy = _mm_insert_epi16( iy, y[ 6 ], 6 );
+ break;
+ default:
+ fx = fy = _mm256_setzero_ps();
+ return;
+ }
+
+ fx = _mm256_cvtph_ps( ix );
+ fy = _mm256_cvtph_ps( iy );
+ }
+
+ inline __m128 loadFloat2( const float* rsi )
+ {
+ return _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) );
+ }
+ inline __m128 loadFloat3( const float* rsi )
+ {
+ __m128 f = loadFloat2( rsi );
+ f = _mm_insert_ps( f, _mm_load_ss( rsi + 2 ), 0x20 );
+ return f;
+ }
+ __forceinline void loadPartial( const float* x, const float* y, size_t count, __m256& fx, __m256& fy )
+ {
+ __m128 low1, high1;
+ __m128 low2, high2;
+ high1 = high2 = _mm_setzero_ps();
+ switch( count )
+ {
+ case 1:
+ low1 = _mm_load_ss( x );
+ low2 = _mm_load_ss( y );
+ break;
+ case 2:
+ low1 = loadFloat2( x );
+ low2 = loadFloat2( y );
+ break;
+ case 3:
+ low1 = loadFloat3( x );
+ low2 = loadFloat3( y );
+ break;
+ case 4:
+ low1 = _mm_loadu_ps( x );
+ low2 = _mm_loadu_ps( y );
+ break;
+ case 5:
+ low1 = _mm_loadu_ps( x );
+ low2 = _mm_loadu_ps( y );
+ high1 = _mm_load_ss( x + 4 );
+ high2 = _mm_load_ss( y + 4 );
+ break;
+ case 6:
+ low1 = _mm_loadu_ps( x );
+ low2 = _mm_loadu_ps( y );
+ high1 = loadFloat2( x + 4 );
+ high2 = loadFloat2( y + 4 );
+ break;
+ case 7: // load 14 bytes
+ low1 = _mm_loadu_ps( x );
+ low2 = _mm_loadu_ps( y );
+ high1 = loadFloat3( x + 4 );
+ high2 = loadFloat3( y + 4 );
+ break;
+ default:
+ fx = fy = _mm256_setzero_ps();
+ return;
+ }
+
+ fx = _mm256_setr_m128( low1, high1 );
+ fy = _mm256_setr_m128( low2, high2 );
+ }
+
+ __forceinline float horizontalMaximum( __m256 v )
+ {
+ __m128 s = _mm256_extractf128_ps( v, 1 );
+ s = _mm_max_ps( s, _mm256_castps256_ps128( v ) );
+ s = _mm_max_ps( s, _mm_movehl_ps( s, s ) );
+ s = _mm_max_ss( s, _mm_movehdup_ps( s ) );
+ return _mm_cvtss_f32( s );
+ }
+
+ __forceinline double horizontalSum( __m256 v )
+ {
+ __m256d d = _mm256_cvtps_pd( _mm256_extractf128_ps( v, 1 ) );
+ d = _mm256_add_pd( d, _mm256_cvtps_pd( _mm256_castps256_ps128( v ) ) );
+
+ __m128d s = _mm256_extractf128_pd( d, 1 );
+ s = _mm_add_pd( s, _mm256_castpd256_pd128( d ) );
+ s = _mm_add_sd( s, _mm_unpackhi_pd( s, s ) );
+ return _mm_cvtsd_f64( s );
+ }
+
+ __m256 maskInfNan( __m256 diff, __m256 a, __m256 b )
+ {
+ __m256i ai = _mm256_castps_si256( a );
+ __m256i bi = _mm256_castps_si256( b );
+ __m256i eqi = _mm256_cmpeq_epi32( ai, bi );
+ __m256 eq = _mm256_castsi256_ps( eqi );
+ return _mm256_andnot_ps( eq, diff );
+ }
+
+ class DiffAcc
+ {
+ __m256 maxAbs = _mm256_setzero_ps();
+ __m256 sumSquares = _mm256_setzero_ps();
+
+ public:
+
+ __forceinline void add( __m256 a, __m256 b )
+ {
+ const __m256 neg0 = _mm256_set1_ps( -0.0f );
+ __m256 diff = _mm256_sub_ps( b, a );
+ diff = maskInfNan( diff, a, b );
+ sumSquares = _mm256_fmadd_ps( diff, diff, sumSquares );
+ const __m256 absDiff = _mm256_andnot_ps( neg0, diff );
+ maxAbs = _mm256_max_ps( maxAbs, absDiff );
+ }
+
+ __forceinline sTensorDiff reduce( size_t count )
+ {
+ sTensorDiff res;
+ res.maxAbsDiff = horizontalMaximum( maxAbs );
+ res.avgDiffSquared = (float)( horizontalSum( sumSquares ) / (double)(int64_t)count );
+ res.length = count;
+ return res;
+ }
+ };
+
+ template<class E>
+ static sTensorDiff __declspec( noinline ) diffVectors( const E* a, const E* b, size_t length )
+ {
+ // const E* const aEnd = a + length;
+ const E* const aEndAligned = a + ( length / 8 ) * 8;
+ const size_t remainder = length % 8;
+
+ DiffAcc acc;
+ for( ; a < aEndAligned; a += 8, b += 8 )
+ acc.add( load( a ), load( b ) );
+
+ if( remainder != 0 )
+ {
+ __m256 va, vb;
+ loadPartial( a, b, remainder, va, vb );
+ acc.add( va, vb );
+ }
+
+ return acc.reduce( length );
+ }
+}
+
+sTensorDiff DirectCompute::computeDiff( const float* a, const float* b, size_t length )
+{
+ return diffVectors( a, b, length );
+}
+
+sTensorDiff DirectCompute::computeDiff( const uint16_t* a, const uint16_t* b, size_t length )
+{
+ return diffVectors( a, b, length );
+}
+
+void DirectCompute::sTensorDiff::print() const
+{
+ printf( "%zu elements, maxAbsDiff = %g, avgDiffSquared = %g\n", length, maxAbsDiff, avgDiffSquared );
+} \ No newline at end of file
diff --git a/Whisper/API/MfStructs.h b/Whisper/API/MfStructs.h
new file mode 100644
index 0000000..cd27659
--- /dev/null
+++ b/Whisper/API/MfStructs.h
@@ -0,0 +1,51 @@
+#pragma once
+
+namespace Whisper
+{
+ struct sCaptureDevice
+ {
+ // The display name is suitable for showing to the user, but might not be unique.
+ const wchar_t* displayName;
+
+ // Endpoint ID for an audio capture device
+ // It uniquely identifies the device on the system, but is not a readable string.
+ const wchar_t* endpoint;
+ };
+
+ using pfnFoundCaptureDevices = HRESULT( __stdcall* )( int len, const sCaptureDevice* buffer, void* pv );
+
+ // Flags for the audio capture
+ enum struct eCaptureFlags : uint32_t
+ {
+ // When the capture device supports stereo, keep stereo PCM samples in addition to mono
+ Stereo = 1,
+ };
+
+ // Parameters for audio capture
+ struct sCaptureParams
+ {
+ float minDuration = 2.0f;
+ float maxDuration = 3.0f;
+ float dropStartSilence = 0.25f;
+ float pauseDuration = 0.333f;
+ // Flags for the audio capture
+ uint32_t flags = 0;
+ };
+
+ enum struct eCaptureStatus : uint8_t
+ {
+ Listening = 1,
+ Voice = 2,
+ Transcribing = 4,
+ Stalled = 0x80,
+ };
+
+ using pfnShouldCancel = HRESULT( __stdcall* )( void* pv ) noexcept;
+ using pfnCaptureStatus = HRESULT( __stdcall* )( void* pv, eCaptureStatus status ) noexcept;
+ struct sCaptureCallbacks
+ {
+ pfnShouldCancel shouldCancel;
+ pfnCaptureStatus captureStatus;
+ void* pv;
+ };
+} \ No newline at end of file
diff --git a/Whisper/API/Readme.txt b/Whisper/API/Readme.txt
new file mode 100644
index 0000000..7d40494
--- /dev/null
+++ b/Whisper/API/Readme.txt
@@ -0,0 +1,15 @@
+The headers in this folder define the complete public API of Whisper.dll.
+
+To consume the library in your C++ software, include exactly one of the following headers.
+
+1. If you’re building a windows app, include whisperWindows.h header, and you'll get traditional Win32 COM projection of the API.
+
+2. If you’re porting to other OS, or porting to different C++ compiler, or already using ComLight support library, include whisperComLight.h header.
+If you do that, in addition to this "Whisper/API" folder you also gonna need the "ComLightLib" dependency.
+This will get you the ComLight flavor of these COM interfaces.
+
+Internally, the actual implementation uses the ComLight flavour of the interfaces, but that’s fine because they are binary compatible.
+
+The reason for the difference between these flavors — Visual Studio’s CComPtr<T> and other related utilities expect interface IDs specified with __declspec(uuid) directive.
+
+That language extension is specific to Visual C++, not supported in GCC nor Clang compilers. \ No newline at end of file
diff --git a/Whisper/API/SpecialTokens.h b/Whisper/API/SpecialTokens.h
new file mode 100644
index 0000000..67fd020
--- /dev/null
+++ b/Whisper/API/SpecialTokens.h
@@ -0,0 +1,25 @@
+#pragma once
+
+namespace Whisper
+{
+ struct SpecialTokens
+ {
+ // The end of a transcription, token_eot
+ int TranscriptionEnd;
+ // Start of a transcription, token_sot
+ int TranscriptionStart;
+ // Represents the previous word in the transcription. It is used to help the model predict the current word based on the context of the words that came before it.
+ int PreviousWord; // token_prev
+ // Start of a sentence
+ int SentenceStart; // token_solm
+ //Represents the word "not" in the transcription
+ int Not; // token_not
+ //New transcription
+ int TranscriptionBegin; // token_beg
+
+ // token_translate
+ int TaskTranslate;
+ // token_transcribe
+ int TaskTranscribe;
+ };
+} \ No newline at end of file
diff --git a/Whisper/API/TranscribeStructs.h b/Whisper/API/TranscribeStructs.h
new file mode 100644
index 0000000..ac28357
--- /dev/null
+++ b/Whisper/API/TranscribeStructs.h
@@ -0,0 +1,127 @@
+#pragma once
+#include <stdint.h>
+#include <assert.h>
+
+namespace Whisper
+{
+ enum struct eModelImplementation : uint32_t
+ {
+ GPU = 1,
+ Hybrid = 2,
+ Reference = 3,
+ };
+
+ struct sTimeSpanFields
+ {
+ uint32_t days;
+ uint8_t hours, minutes, seconds;
+ uint32_t ticks;
+
+ sTimeSpanFields( uint64_t tt )
+ {
+ ticks = (uint32_t)( tt % 10'000'000 );
+ tt /= 10'000'000;
+ seconds = (uint8_t)( tt % 60 );
+ tt /= 60;
+ minutes = (uint8_t)( tt % 60 );
+ tt /= 60;
+ hours = (uint8_t)( tt % 24 );
+ tt /= 24;
+ days = (uint32_t)tt;
+ }
+ };
+
+ struct sTimeSpan
+ {
+ uint64_t ticks;
+
+ operator sTimeSpanFields() const
+ {
+ return sTimeSpanFields{ ticks };
+ }
+ void operator=( uint64_t tt )
+ {
+ ticks = tt;
+ }
+ void operator=( int64_t tt )
+ {
+ assert( tt >= 0 );
+ ticks = (uint64_t)tt;
+ }
+ };
+
+ // Start and end times of the segment or token, expressed in 100-nanosecond ticks
+ struct sTimeInterval
+ {
+ sTimeSpan begin, end;
+ };
+
+ // Segment data
+ struct sSegment
+ {
+ // Segment text, null-terminated, and probably UTF-8 encoded
+ const char* text;
+ // Start and end times of the segment
+ sTimeInterval time;
+ uint32_t firstToken, countTokens;
+ };
+
+ enum eTokenFlags : uint32_t
+ {
+ None = 0,
+ Special = 1,
+ };
+ inline bool operator &( eTokenFlags a, eTokenFlags b )
+ {
+ return 0 != ( (uint32_t)a & (uint32_t)b );
+ }
+
+ // Token data
+ struct sToken
+ {
+ // Token text, null-terminated, and probably UTF-8 encoded
+ const char* text;
+ // Start and end times of the token
+ sTimeInterval time;
+ // Probability of the token
+ float probability;
+ // Probability of the timestamp token
+ float probabilityTimestamp;
+ // Sum of probabilities of all timestamp tokens
+ float ptsum;
+ // Voice length of the token
+ float vlen;
+ // Token id
+ int id;
+ eTokenFlags flags;
+ };
+
+ struct sTranscribeLength
+ {
+ uint32_t countSegments, countTokens;
+ };
+
+ enum struct eResultFlags : uint32_t
+ {
+ None = 0,
+ // Return individual tokens in addition to the segments
+ Tokens = 1,
+ // Return timestamps
+ Timestamps = 2,
+
+ // Create a new COM object for the results.
+ // Without this flag, the context returns a pointer to the COM object stored in the context.
+ // The content of that object is replaced every time you call iContext.getResults method
+ NewObject = 0x100,
+ };
+
+ inline eResultFlags operator |( eResultFlags a, eResultFlags b )
+ {
+ return (eResultFlags)( (uint32_t)a | (uint32_t)b );
+ }
+
+ inline bool operator &( eResultFlags a, eResultFlags b )
+ {
+ return 0 != ( (uint32_t)a & (uint32_t)b );
+ }
+} \ No newline at end of file
diff --git a/Whisper/API/iContext.cl.h b/Whisper/API/iContext.cl.h
new file mode 100644
index 0000000..97d34c7
--- /dev/null
+++ b/Whisper/API/iContext.cl.h
@@ -0,0 +1,66 @@
+#pragma once
+#include "../../ComLightLib/comLightCommon.h"
+#include "iTranscribeResult.cl.h"
+#include "SpecialTokens.h"
+#include "loggerApi.h"
+#include "sLanguageList.h"
+#include "sLoadModelCallbacks.h"
+
+namespace Whisper
+{
+ struct iModel;
+ struct iAudioBuffer;
+ struct iAudioReader;
+ struct iAudioCapture;
+ struct sCaptureCallbacks;
+ struct sFullParams;
+ enum struct eModelImplementation : uint32_t;
+ enum struct eSamplingStrategy : int;
+ using whisper_token = int;
+ struct sProgressSink;
+
+ struct DECLSPEC_NOVTABLE iContext : public ComLight::IUnknown
+ {
+ DEFINE_INTERFACE_ID( "{b9956374-3b18-4943-90f2-2ab18a404537}" );
+
+ // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+ // Uses the specified decoding strategy to obtain the text.
+ virtual HRESULT COMLIGHTCALL runFull( const sFullParams& params, const iAudioBuffer* buffer ) = 0;
+ virtual HRESULT COMLIGHTCALL runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ) = 0;
+ virtual HRESULT COMLIGHTCALL runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ) = 0;
+
+ virtual HRESULT COMLIGHTCALL getResults( eResultFlags flags, iTranscribeResult** pp ) const = 0;
+
+ virtual HRESULT COMLIGHTCALL getModel( iModel** pp ) = 0;
+
+ virtual HRESULT COMLIGHTCALL fullDefaultParams( eSamplingStrategy strategy, sFullParams* rdi ) = 0;
+
+ // Performance information
+ virtual HRESULT COMLIGHTCALL timingsPrint() = 0;
+ virtual HRESULT COMLIGHTCALL timingsReset() = 0;
+ };
+
+ struct DECLSPEC_NOVTABLE iModel : public ComLight::IUnknown
+ {
+ DEFINE_INTERFACE_ID( "{abefb4c9-e8d8-46a3-8747-5afbadef1adb}" );
+
+ virtual HRESULT COMLIGHTCALL createContext( iContext** pp ) = 0;
+
+ virtual HRESULT COMLIGHTCALL isMultilingual() = 0;
+
+ virtual HRESULT COMLIGHTCALL getSpecialTokens( SpecialTokens& rdi ) = 0;
+
+ // Token Id -> String
+ virtual const char* COMLIGHTCALL stringFromToken( whisper_token token ) = 0;
+ };
+
+ HRESULT COMLIGHTCALL setupLogger( const sLoggerSetup& setup );
+ HRESULT COMLIGHTCALL loadModel( const wchar_t* path, eModelImplementation impl, const sLoadModelCallbacks* callbacks, iModel** pp );
+
+ uint32_t COMLIGHTCALL findLanguageKeyW( const wchar_t* lang );
+ uint32_t COMLIGHTCALL findLanguageKeyA( const char* lang );
+
+ HRESULT COMLIGHTCALL getSupportedLanguages( sLanguageList& rdi );
+}
+
+#include "sFullParams.h" \ No newline at end of file
diff --git a/Whisper/API/iContext.h b/Whisper/API/iContext.h
new file mode 100644
index 0000000..9661093
--- /dev/null
+++ b/Whisper/API/iContext.h
@@ -0,0 +1,61 @@
+#pragma once
+#include "iTranscribeResult.h"
+#include "SpecialTokens.h"
+#include "loggerApi.h"
+#include "sLanguageList.h"
+#include "sLoadModelCallbacks.h"
+
+namespace Whisper
+{
+ __interface iModel;
+ __interface iAudioBuffer;
+ __interface iAudioReader;
+ __interface iAudioCapture;
+ struct sCaptureCallbacks;
+ struct sFullParams;
+ enum struct eModelImplementation : uint32_t;
+ enum struct eSamplingStrategy : int;
+ using whisper_token = int;
+ struct sProgressSink;
+
+ __interface __declspec( novtable, uuid( "b9956374-3b18-4943-90f2-2ab18a404537" ) ) iContext : public IUnknown
+ {
+ // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+ // Uses the specified decoding strategy to obtain the text.
+ HRESULT __stdcall runFull( const sFullParams& params, const iAudioBuffer* buffer );
+ HRESULT __stdcall runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader );
+ HRESULT __stdcall runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader );
+
+ HRESULT __stdcall getResults( eResultFlags flags, iTranscribeResult** pp ) const;
+
+ HRESULT __stdcall getModel( iModel** pp );
+
+ HRESULT __stdcall fullDefaultParams( eSamplingStrategy strategy, sFullParams* rdi );
+
+ // Performance information
+ HRESULT __stdcall timingsPrint();
+ HRESULT __stdcall timingsReset();
+ };
+
+ __interface __declspec( novtable, uuid( "abefb4c9-e8d8-46a3-8747-5afbadef1adb" ) ) iModel : public IUnknown
+ {
+ HRESULT __stdcall createContext( iContext** pp );
+
+ HRESULT __stdcall isMultilingual();
+
+ HRESULT __stdcall getSpecialTokens( SpecialTokens& rdi );
+
+ // Token Id -> String
+ const char* __stdcall stringFromToken( whisper_token token );
+ };
+
+ HRESULT __stdcall setupLogger( const sLoggerSetup& setup );
+ HRESULT __stdcall loadModel( const wchar_t* path, eModelImplementation impl, const sLoadModelCallbacks* callbacks, iModel** pp );
+
+ uint32_t __stdcall findLanguageKeyW( const wchar_t* lang );
+ uint32_t __stdcall findLanguageKeyA( const char* lang );
+
+ HRESULT __stdcall getSupportedLanguages( sLanguageList& rdi );
+}
+
+#include "sFullParams.h" \ No newline at end of file
diff --git a/Whisper/API/iMediaFoundation.cl.h b/Whisper/API/iMediaFoundation.cl.h
new file mode 100644
index 0000000..516b67f
--- /dev/null
+++ b/Whisper/API/iMediaFoundation.cl.h
@@ -0,0 +1,48 @@
+#pragma once
+#include "../../ComLightLib/comLightCommon.h"
+#include "MfStructs.h"
+
+struct IMFSourceReader;
+
+namespace Whisper
+{
+ struct DECLSPEC_NOVTABLE iAudioBuffer : public ComLight::IUnknown
+ {
+ DEFINE_INTERFACE_ID( "{013583aa-c9eb-42bc-83db-633c2c317051}" );
+
+ virtual uint32_t COMLIGHTCALL countSamples() const = 0;
+ virtual const float* COMLIGHTCALL getPcmMono() const = 0;
+ virtual const float* COMLIGHTCALL getPcmStereo() const = 0;
+ virtual HRESULT COMLIGHTCALL getTime( int64_t& rdi ) const = 0;
+ };
+
+ struct DECLSPEC_NOVTABLE iAudioReader : public ComLight::IUnknown
+ {
+ DEFINE_INTERFACE_ID( "{35b988da-04a6-476a-a193-d8891d5dc390}" );
+
+ virtual HRESULT COMLIGHTCALL getDuration( int64_t& rdi ) const = 0;
+ virtual HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const = 0;
+ virtual HRESULT COMLIGHTCALL requestedStereo() const = 0;
+ };
+
+ struct DECLSPEC_NOVTABLE iAudioCapture : public ComLight::IUnknown
+ {
+ DEFINE_INTERFACE_ID( "{747752c2-d9fd-40df-8847-583c781bf013}" );
+
+ virtual HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const = 0;
+ virtual const sCaptureParams& COMLIGHTCALL getParams() const = 0;
+ };
+
+ struct DECLSPEC_NOVTABLE iMediaFoundation : public ComLight::IUnknown
+ {
+ DEFINE_INTERFACE_ID( "{fb9763a5-d77d-4b6e-aff8-f494813cebd8}" );
+
+ virtual HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const = 0;
+ virtual HRESULT COMLIGHTCALL openAudioFile( LPCTSTR path, bool stereo, iAudioReader** pp ) = 0;
+
+ virtual HRESULT COMLIGHTCALL listCaptureDevices( pfnFoundCaptureDevices pfn, void* pv ) = 0;
+ virtual HRESULT COMLIGHTCALL openCaptureDevice( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) = 0;
+ };
+
+ HRESULT COMLIGHTCALL initMediaFoundation( iMediaFoundation** pp );
+} \ No newline at end of file
diff --git a/Whisper/API/iMediaFoundation.h b/Whisper/API/iMediaFoundation.h
new file mode 100644
index 0000000..93dc287
--- /dev/null
+++ b/Whisper/API/iMediaFoundation.h
@@ -0,0 +1,39 @@
+#pragma once
+#include <stdint.h>
+#include "MfStructs.h"
+struct IMFSourceReader;
+
+namespace Whisper
+{
+ __interface __declspec( novtable, uuid( "013583aa-c9eb-42bc-83db-633c2c317051" ) ) iAudioBuffer : public IUnknown
+ {
+ uint32_t __stdcall countSamples() const;
+ const float* __stdcall getPcmMono() const;
+ const float* __stdcall getPcmStereo() const;
+ HRESULT __stdcall getTime( int64_t& rdi ) const;
+ };
+
+ __interface __declspec( novtable, uuid( "35b988da-04a6-476a-a193-d8891d5dc390" ) ) iAudioReader : public IUnknown
+ {
+ HRESULT __stdcall getDuration( int64_t& rdi ) const;
+ HRESULT __stdcall getReader( IMFSourceReader** pp ) const;
+ HRESULT __stdcall requestedStereo() const;
+ };
+
+ __interface __declspec( novtable, uuid( "747752c2-d9fd-40df-8847-583c781bf013" ) ) iAudioCapture : public IUnknown
+ {
+ HRESULT __stdcall getReader( IMFSourceReader** pp ) const;
+ const sCaptureParams& __stdcall getParams() const;
+ };
+
+ __interface __declspec( novtable, uuid( "fb9763a5-d77d-4b6e-aff8-f494813cebd8" ) ) iMediaFoundation : public IUnknown
+ {
+ HRESULT __stdcall loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const;
+ HRESULT __stdcall openAudioFile( LPCTSTR path, bool stereo, iAudioReader** pp );
+
+ HRESULT __stdcall listCaptureDevices( pfnFoundCaptureDevices pfn, void* pv );
+ HRESULT __stdcall openCaptureDevice( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp );
+ };
+
+ HRESULT __stdcall initMediaFoundation( iMediaFoundation** pp );
+} \ No newline at end of file
diff --git a/Whisper/API/iTranscribeResult.cl.h b/Whisper/API/iTranscribeResult.cl.h
new file mode 100644
index 0000000..ab65178
--- /dev/null
+++ b/Whisper/API/iTranscribeResult.cl.h
@@ -0,0 +1,15 @@
+#pragma once
+#include "TranscribeStructs.h"
+#include "../../ComLightLib/comLightCommon.h"
+
+namespace Whisper
+{
+ struct iTranscribeResult : public ComLight::IUnknown
+ {
+ DEFINE_INTERFACE_ID( "{2871a73f-5ce3-48f8-8779-6582ee11935e}" );
+
+ virtual HRESULT COMLIGHTCALL getSize( sTranscribeLength& rdi ) const = 0;
+ virtual const sSegment* COMLIGHTCALL getSegments() const = 0;
+ virtual const sToken* COMLIGHTCALL getTokens() const = 0;
+ };
+} \ No newline at end of file
diff --git a/Whisper/API/iTranscribeResult.h b/Whisper/API/iTranscribeResult.h
new file mode 100644
index 0000000..27e0c0d
--- /dev/null
+++ b/Whisper/API/iTranscribeResult.h
@@ -0,0 +1,12 @@
+#pragma once
+#include "TranscribeStructs.h"
+
+namespace Whisper
+{
+ __interface __declspec( novtable, uuid( "2871a73f-5ce3-48f8-8779-6582ee11935e" ) ) iTranscribeResult : public IUnknown
+ {
+ HRESULT __stdcall getSize( sTranscribeLength& rdi ) const;
+ const sSegment* __stdcall getSegments() const;
+ const sToken* __stdcall getTokens() const;
+ };
+} \ No newline at end of file
diff --git a/Whisper/API/loggerApi.h b/Whisper/API/loggerApi.h
new file mode 100644
index 0000000..6af1c4e
--- /dev/null
+++ b/Whisper/API/loggerApi.h
@@ -0,0 +1,35 @@
+#pragma once
+#include <stdint.h>
+
+namespace Whisper
+{
+ // Log level for messages
+ enum struct eLogLevel : uint8_t
+ {
+ Error = 0,
+ Warning = 1,
+ Info = 2,
+ Debug = 3
+ };
+ enum struct eLoggerFlags : uint8_t
+ {
+ UseStandardError = 1,
+ SkipFormatMessage = 2,
+ };
+
+ // C function pointer to receive log messages from the library. The messages are encoded in UTF-8.
+ using pfnLoggerSink = void( __stdcall* )( void* context, eLogLevel lvl, const char* message );
+
+ // A sink to receive log messages produced by MeshRepair.dll
+ struct sLoggerSetup
+ {
+ // C function pointer to receive log messages from the library
+ pfnLoggerSink sink = nullptr;
+ // Optional context parameter for the sink function; when consuming from C# you don't need that, pass IntPtr.Zero, delegates can capture things.
+ void* context = nullptr;
+ // Maximum log level to produce
+ eLogLevel level;
+ // Flags about the logger
+ eLoggerFlags flags = (eLoggerFlags)0;
+ };
+} \ No newline at end of file
diff --git a/Whisper/API/sFullParams.h b/Whisper/API/sFullParams.h
new file mode 100644
index 0000000..0a1d352
--- /dev/null
+++ b/Whisper/API/sFullParams.h
@@ -0,0 +1,136 @@
+#pragma once
+#include <stdint.h>
+#include <assert.h>
+
+namespace Whisper
+{
+ // Available sampling strategies
+ enum struct eSamplingStrategy : int
+ {
+ // Always select the most probable token
+ Greedy,
+ // TODO: not implemented yet!
+ BeamSearch,
+ };
+
+ using pfnNewSegment = HRESULT( __cdecl* )( iContext* ctx, uint32_t n_new, void* user_data ) noexcept;
+ using pfnEncoderBegin = HRESULT( __cdecl* )( iContext* ctx, void* user_data ) noexcept;
+
+ enum struct eFullParamsFlags : uint32_t
+ {
+ Translate = 1,
+ NoContext = 2,
+ SingleSegment = 4,
+ PrintSpecial = 8,
+ PrintProgress = 0x10,
+ PrintRealtime = 0x20,
+ PrintTimestamps = 0x40,
+
+ // Experimental
+ TokenTimestamps = 0x100,
+ SpeedupAudio = 0x200,
+ };
+
+ inline eFullParamsFlags operator | ( eFullParamsFlags a, eFullParamsFlags b )
+ {
+ return (eFullParamsFlags)( (uint32_t)a | (uint32_t)b );
+ }
+ inline void operator |= ( eFullParamsFlags& a, eFullParamsFlags b )
+ {
+ a = a | b;
+ }
+
+ struct sFullParams
+ {
+ eSamplingStrategy strategy;
+ // Count of CPU threads
+ int cpuThreads;
+ int n_max_text_ctx;
+ int offset_ms; // start offset in ms
+ int duration_ms; // audio duration to process in ms
+ eFullParamsFlags flags;
+ uint32_t language;
+
+ // [EXPERIMENTAL] token-level timestamps
+ float thold_pt; // timestamp token probability threshold (~0.01)
+ float thold_ptsum; // timestamp token sum probability threshold (~0.01)
+ int max_len; // max segment length in characters
+ int max_tokens; // max tokens per segment (0 = no limit)
+
+ struct
+ {
+ int n_past;
+ } greedy;
+
+ struct
+ {
+ int n_past;
+ int beam_width;
+ int n_best;
+ } beam_search;
+
+ // [EXPERIMENTAL] speed-up techniques
+ int audio_ctx; // overwrite the audio context size (0 = use default)
+
+ // tokens to provide the whisper model as initial prompt
+ // these are prepended to any existing text context from a previous call
+ const whisper_token* prompt_tokens;
+ int prompt_n_tokens;
+
+ pfnNewSegment new_segment_callback;
+ void* new_segment_callback_user_data;
+
+ pfnEncoderBegin encoder_begin_callback;
+ void* encoder_begin_callback_user_data;
+
+ // Couple utility methods, they workaround the lack of bit fields in C++
+ inline bool flag( eFullParamsFlags f ) const
+ {
+ return 0 != ( (uint32_t)flags & (uint32_t)f );
+ }
+ inline void resetFlag( eFullParamsFlags bit )
+ {
+ uint32_t f = (uint32_t)flags;
+ f &= ~(uint32_t)bit;
+ flags = (eFullParamsFlags)f;
+ }
+ inline void setFlag( eFullParamsFlags bit, bool set = true )
+ {
+ uint32_t f = (uint32_t)flags;
+ if( set )
+ f |= (uint32_t)bit;
+ else
+ f &= ~(uint32_t)bit;
+ flags = (eFullParamsFlags)f;
+ }
+ };
+
+ struct sSegmentTime
+ {
+ int64_t begin, end;
+ };
+
+ inline uint32_t makeLanguageKey( const char* code )
+ {
+ assert( strlen( code ) <= 4 );
+ uint32_t res = 0;
+ uint32_t shift = 0;
+ for( size_t i = 0; i < 4; i++, code++, shift += 8 )
+ {
+ const char c = *code;
+ if( c == '\0' )
+ return res;
+ uint32_t u32 = (uint8_t)c;
+ u32 = u32 << shift;
+ res |= u32;
+ }
+ return res;
+ }
+
+ using pfnReportProgress = HRESULT( __stdcall* )( double val, iContext* ctx, void* pv ) noexcept;
+ struct sProgressSink
+ {
+ pfnReportProgress pfn;
+ void* pv;
+ };
+} \ No newline at end of file
diff --git a/Whisper/API/sLanguageList.h b/Whisper/API/sLanguageList.h
new file mode 100644
index 0000000..49ca596
--- /dev/null
+++ b/Whisper/API/sLanguageList.h
@@ -0,0 +1,18 @@
+#pragma once
+#include <stdint.h>
+
+namespace Whisper
+{
+ struct sLanguageEntry
+ {
+ uint32_t key;
+ int id;
+ const char* name;
+ };
+
+ struct sLanguageList
+ {
+ uint32_t length;
+ const sLanguageEntry* pointer;
+ };
+} \ No newline at end of file
diff --git a/Whisper/API/sLoadModelCallbacks.h b/Whisper/API/sLoadModelCallbacks.h
new file mode 100644
index 0000000..f5248c6
--- /dev/null
+++ b/Whisper/API/sLoadModelCallbacks.h
@@ -0,0 +1,14 @@
+#pragma once
+
+namespace Whisper
+{
+ using pfnLoadProgress = HRESULT( __stdcall* )( double val, void* pv ) noexcept;
+ using pfnCancel = HRESULT( __stdcall* )( void* pv ) noexcept;
+
+ struct sLoadModelCallbacks
+ {
+ pfnLoadProgress progress;
+ pfnCancel cancel;
+ void* pv;
+ };
+} \ No newline at end of file
diff --git a/Whisper/API/whisperComLight.h b/Whisper/API/whisperComLight.h
new file mode 100644
index 0000000..c7f0b93
--- /dev/null
+++ b/Whisper/API/whisperComLight.h
@@ -0,0 +1,4 @@
+#pragma once
+#include "iMediaFoundation.cl.h"
+#include "iContext.cl.h"
+#include "iTranscribeResult.cl.h" \ No newline at end of file
diff --git a/Whisper/API/whisperWindows.h b/Whisper/API/whisperWindows.h
new file mode 100644
index 0000000..925e307
--- /dev/null
+++ b/Whisper/API/whisperWindows.h
@@ -0,0 +1,4 @@
+#pragma once
+#include "iMediaFoundation.h"
+#include "iContext.h"
+#include "iTranscribeResult.h" \ No newline at end of file
diff --git a/Whisper/CPU/BufferAllocator.cpp b/Whisper/CPU/BufferAllocator.cpp
new file mode 100644
index 0000000..156382c
--- /dev/null
+++ b/Whisper/CPU/BufferAllocator.cpp
@@ -0,0 +1,145 @@
+#include <stdafx.h>
+#include "BufferAllocator.h"
+#include <immintrin.h>
+#include <ammintrin.h>
+using namespace CpuCompute;
+
+HRESULT BufferAllocator::create( size_t cb )
+{
+ CHECK( buffer.allocate( cb ) );
+ head = 0;
+ size = cb;
+ dbgMarkUninitializedMemory( buffer.pointer(), cb );
+ return S_OK;
+}
+
+namespace
+{
+ // Round up the integer by 32 bytes
+ __forceinline size_t roundUpAlloc( size_t cb )
+ {
+ const size_t mask = 31;
+ cb += mask;
+ // We require AVX1+FMA3 support, might as well use BMI1
+ return _andn_u64( mask, cb );
+ }
+}
+
+void* BufferAllocator::allocate( size_t cb, size_t align ) noexcept
+{
+ assert( align <= 32 );
+ cb = roundUpAlloc( cb );
+
+ uint8_t* pointer = buffer.pointer();
+ if( head + cb > size || nullptr == pointer )
+ {
+ logError( u8"BufferAllocator.allocate, not enough capacity" );
+ return nullptr;
+ }
+
+ void* const res = pointer + head;
+ head += cb;
+ assert( head <= size );
+ dbgMarkUninitializedMemory( res, cb );
+ return res;
+}
+
+namespace
+{
+ // 2 MB of memory, we hope the OS kernel will then be smart enough to give us large pages.
+ constexpr size_t virtualAllocGranularityExp2 = 21;
+
+ constexpr size_t virtualAllocGranularityMask = ( ( (size_t)1 ) << virtualAllocGranularityExp2 ) - 1;
+
+ // Round up the integer by 2 megabytes
+ __forceinline size_t roundUpVirtualAlloc( size_t cb )
+ {
+ const size_t mask = virtualAllocGranularityMask;
+ cb += mask;
+ return _andn_u64( mask, cb );
+ }
+}
+
+HRESULT VirtualAllocator::create( size_t cb )
+{
+ if( nullptr != pointer )
+ return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
+ cb = roundUpVirtualAlloc( cb );
+ pointer = (uint8_t*)VirtualAlloc( NULL, cb, MEM_RESERVE, PAGE_READWRITE );
+ if( nullptr != pointer )
+ {
+ head = 0;
+ sizeAllocated = 0;
+ sizeVirtual = cb;
+ return S_OK;
+ }
+
+ const HRESULT hr = getLastHr();
+ logErrorHr( hr, u8"VirtualAlloc failed" );
+ return hr;
+}
+
+void* VirtualAllocator::allocate( size_t cb, size_t align ) noexcept
+{
+ assert( align <= 32 );
+ cb = roundUpAlloc( cb );
+
+ const size_t newHead = head + cb;
+ if( newHead <= sizeAllocated )
+ {
+ void* const res = pointer + head;
+ head = newHead;
+ dbgMarkUninitializedMemory( res, cb );
+ return res;
+ }
+
+ if( newHead <= sizeVirtual )
+ {
+ uint8_t* const ptrCommit = pointer + sizeAllocated;
+ const size_t cbCommit = roundUpVirtualAlloc( newHead ) - sizeAllocated;
+ void* const res = VirtualAlloc( ptrCommit, cbCommit, MEM_COMMIT, PAGE_READWRITE );
+ if( nullptr != res )
+ {
+ sizeAllocated += cbCommit;
+ assert( sizeAllocated <= sizeVirtual );
+ void* const res = pointer + head;
+ head = newHead;
+ dbgMarkUninitializedMemory( res, cb );
+ return res;
+ }
+
+ const HRESULT hr = getLastHr();
+ logErrorHr( hr, u8"VirtualAllocator.allocate, VirtualAlloc failed" );
+ return nullptr;
+ }
+
+ logError( u8"VirtualAllocator.allocate, not enough arena capacity" );
+ return nullptr;
+}
+
+VirtualAllocator::~VirtualAllocator()
+{
+ if( nullptr == pointer )
+ return;
+
+ if( VirtualFree( pointer, 0, MEM_RELEASE ) )
+ {
+ pointer = nullptr;
+ return;
+ }
+
+ const HRESULT hr = getLastHr();
+ logErrorHr( hr, u8"VirtualFree failed" );
+}
+
+#ifndef NDEBUG
+// Reusing Microsoft's magic numbers: https://asawicki.info/news_1292_magic_numbers_in_visual_c
+void CpuCompute::dbgMarkUninitializedMemory( void* pv, size_t cb )
+{
+ __stosb( (uint8_t*)pv, 0xCD, cb );
+}
+void CpuCompute::dbgMarkFreedMemory( void* pv, size_t cb )
+{
+ __stosd( (DWORD*)pv, 0xFEEEFEEEu, cb / 4 );
+}
+#endif \ No newline at end of file
diff --git a/Whisper/CPU/BufferAllocator.h b/Whisper/CPU/BufferAllocator.h
new file mode 100644
index 0000000..750a565
--- /dev/null
+++ b/Whisper/CPU/BufferAllocator.h
@@ -0,0 +1,64 @@
+#pragma once
+#include "LargeBuffer.h"
+#include "Tensor.h"
+
+namespace CpuCompute
+{
+#ifdef NDEBUG
+ inline void dbgMarkUninitializedMemory( void* pv, size_t cb ) { }
+ inline void dbgMarkFreedMemory( void* pv, size_t cb ) { }
+#else
+ void dbgMarkUninitializedMemory( void* pv, size_t cb );
+ void dbgMarkFreedMemory( void* pv, size_t cb );
+#endif
+
+ // An implementation of arena allocator which slices pieces of a large buffer allocated in advance
+ class BufferAllocator : public iArenaAllocator
+ {
+ LargeBuffer buffer;
+ size_t head = 0;
+ size_t size = 0;
+
+ void resetArena() noexcept override final
+ {
+ head = 0;
+ dbgMarkFreedMemory( buffer.pointer(), size );
+ }
+
+ void* allocate( size_t cb, size_t align ) noexcept override final;
+
+ public:
+ BufferAllocator() = default;
+ BufferAllocator( const BufferAllocator& ) = delete;
+ ~BufferAllocator() = default;
+
+ // Allocate a large buffer with the specified count of bytes
+ HRESULT create( size_t cb );
+ };
+
+ // An implementation of arena allocator which allocates a large chunk of virtual memory, and maps new physical pages into that memory region as needed.
+ class VirtualAllocator : public iArenaAllocator
+ {
+ uint8_t* pointer = nullptr;
+ size_t head = 0;
+ size_t sizeAllocated = 0;
+ size_t sizeVirtual = 0;
+
+ void resetArena() noexcept override final
+ {
+ head = 0;
+ dbgMarkFreedMemory( pointer, sizeAllocated );
+ }
+
+ void* allocate( size_t cb, size_t align ) noexcept override final;
+
+ public:
+
+ VirtualAllocator() = default;
+ VirtualAllocator( const VirtualAllocator& ) = delete;
+ ~VirtualAllocator();
+
+ // Reserve virtual memory space for the specified count of bytes in the arena, but don't allocate any pages
+ HRESULT create( size_t cb );
+ };
+} \ No newline at end of file
diff --git a/Whisper/CPU/DecoderTensors.cpp b/Whisper/CPU/DecoderTensors.cpp
new file mode 100644
index 0000000..22de476
--- /dev/null
+++ b/Whisper/CPU/DecoderTensors.cpp
@@ -0,0 +1,68 @@
+#include "stdafx.h"
+#include "DecoderTensors.h"
+using namespace CpuCompute;
+
+#if TENSOR_GGML_COMPAT
+namespace
+{
+ class CompatContext
+ {
+ std::vector<ggml_tensor>& vec;
+ size_t index;
+
+ public:
+ CompatContext( std::vector<ggml_tensor>& dest, size_t layers ) :
+ vec( dest )
+ {
+ constexpr size_t tensorsPerLayer = 21;
+ const size_t count = tensorsPerLayer * layers + 4;
+ vec.resize( count );
+ index = 0;
+ }
+
+ void add( const Tensor& rsi, ggml_tensor*& res )
+ {
+ ggml_tensor& ten = vec[ index ];
+ index++;
+ ten = rsi.ggml();
+ res = &ten;
+ }
+
+ void add2( const TensorPair& rsi, ggml_tensor*& w, ggml_tensor*& b )
+ {
+ add( rsi.w, w );
+ add( rsi.b, b );
+ }
+
+ bool isComplete() const
+ {
+ return index == vec.size();
+ }
+ };
+}
+
+void DecoderTensors::makeCompatTensors()
+{
+ CompatContext ctx( ggml, layers.size() );
+
+ ctx.add( positionalEmbedding, d_pe );
+ ctx.add( tokenEmbedding, d_te );
+ ctx.add2( ln, d_ln_w, d_ln_b );
+
+ for( auto& i : layers )
+ {
+ ctx.add2( i.attnLn0, i.attn_ln_0_w, i.attn_ln_0_b );
+ ctx.add2( i.attnLn1, i.attn_ln_1_w, i.attn_ln_1_b );
+ ctx.add2( i.attnQuery, i.attn_q_w, i.attn_q_b );
+ ctx.add( i.attnKey, i.attn_k_w );
+ ctx.add2( i.attnValue, i.attn_v_w, i.attn_v_b );
+ ctx.add2( i.crossAttnLn0, i.cross_attn_ln_0_w, i.cross_attn_ln_0_b );
+ ctx.add2( i.crossAttnLn1, i.cross_attn_ln_1_w, i.cross_attn_ln_1_b );
+ ctx.add2( i.crossAttnQuery, i.cross_attn_q_w, i.cross_attn_q_b );
+ ctx.add2( i.mlpLn, i.mlp_ln_w, i.mlp_ln_b );
+ ctx.add2( i.mlp0, i.mlp_0_w, i.mlp_0_b );
+ ctx.add2( i.mlp1, i.mlp_1_w, i.mlp_1_b );
+ }
+ assert( ctx.isComplete() );
+}
+#endif \ No newline at end of file
diff --git a/Whisper/CPU/DecoderTensors.h b/Whisper/CPU/DecoderTensors.h
new file mode 100644
index 0000000..2efa519
--- /dev/null
+++ b/Whisper/CPU/DecoderTensors.h
@@ -0,0 +1,131 @@
+#pragma once
+#include <vector>
+#include "Tensor.h"
+#include "LargeBuffer.h"
+#if TENSOR_GGML_COMPAT
+#include "../source/ggml.h"
+#endif
+
+namespace CpuCompute
+{
+ // A set of tensors for one decoder's layer
+ struct LayerDecoder
+ {
+ // decoder.blocks.*.attn_ln
+ TensorPair attnLn0;
+ // decoder.blocks.*.attn.out
+ TensorPair attnLn1;
+ // decoder.blocks.*.attn.query
+ TensorPair attnQuery;
+ // decoder.blocks.*.attn.key
+ Tensor attnKey;
+ // decoder.blocks.*.attn.value
+ TensorPair attnValue;
+ // decoder.blocks.*.cross_attn_ln
+ TensorPair crossAttnLn0;
+ // decoder.blocks.*.cross_attn.out
+ TensorPair crossAttnLn1;
+ // decoder.blocks.*.cross_attn.query
+ TensorPair crossAttnQuery;
+
+ // decoder.blocks.*.cross_attn.key
+ // Tensor crossAttnKey;
+ // decoder.blocks.*.cross_attn.value
+ // TensorPair crossAttnValue;
+
+ // decoder.blocks.*.mlp_ln
+ TensorPair mlpLn;
+ // decoder.blocks.*.mlp.0
+ TensorPair mlp0;
+ // decoder.blocks.*.mlp.2
+ TensorPair mlp1;
+
+#if TENSOR_GGML_COMPAT
+ // decoder.blocks.*.attn_ln
+ ggml_tensor* attn_ln_0_w;
+ ggml_tensor* attn_ln_0_b;
+
+ // decoder.blocks.*.attn.out
+ ggml_tensor* attn_ln_1_w;
+ ggml_tensor* attn_ln_1_b;
+
+ // decoder.blocks.*.attn.query
+ ggml_tensor* attn_q_w;
+ ggml_tensor* attn_q_b;
+
+ // decoder.blocks.*.attn.key
+ ggml_tensor* attn_k_w;
+
+ // decoder.blocks.*.attn.value
+ ggml_tensor* attn_v_w;
+ ggml_tensor* attn_v_b;
+
+ // decoder.blocks.*.cross_attn_ln
+ ggml_tensor* cross_attn_ln_0_w;
+ ggml_tensor* cross_attn_ln_0_b;
+
+ // decoder.blocks.*.cross_attn.out
+ ggml_tensor* cross_attn_ln_1_w;
+ ggml_tensor* cross_attn_ln_1_b;
+
+ // decoder.blocks.*.cross_attn.query
+ ggml_tensor* cross_attn_q_w;
+ ggml_tensor* cross_attn_q_b;
+
+ // decoder.blocks.*.mlp_ln
+ ggml_tensor* mlp_ln_w;
+ ggml_tensor* mlp_ln_b;
+
+ // decoder.blocks.*.mlp.0
+ ggml_tensor* mlp_0_w;
+ ggml_tensor* mlp_0_b;
+
+ // decoder.blocks.*.mlp.2
+ ggml_tensor* mlp_1_w;
+ ggml_tensor* mlp_1_b;
+#endif
+ };
+
+ struct DecoderTensors
+ {
+ // decoder.positional_embedding
+ Tensor positionalEmbedding;
+
+ // decoder.token_embedding
+ Tensor tokenEmbedding;
+
+ // decoder.ln
+ TensorPair ln;
+ // A vector of layers
+ std::vector<LayerDecoder> layers;
+
+ void setMemoryBuffer( LargeBuffer&& mem ) noexcept
+ {
+ memory = std::move( mem );
+#if TENSOR_GGML_COMPAT
+ makeCompatTensors();
+#endif
+ }
+
+#if TENSOR_GGML_COMPAT
+ void makeCompatTensors();
+
+ // decoder.positional_embedding
+ ggml_tensor* d_pe; // DD
+
+ // decoder.token_embedding
+ ggml_tensor* d_te; // DD
+
+ // decoder.ln
+ ggml_tensor* d_ln_w; // DD
+ ggml_tensor* d_ln_b; // DD
+#endif
+
+ private:
+ // A smart pointer which owns the memory for all the above tensors
+ LargeBuffer memory;
+#if TENSOR_GGML_COMPAT
+ std::vector<ggml_tensor> ggml;
+#endif
+ };
+} \ No newline at end of file
diff --git a/Whisper/CPU/HybridLoader.cpp b/Whisper/CPU/HybridLoader.cpp
new file mode 100644
index 0000000..e96e3ac
--- /dev/null
+++ b/Whisper/CPU/HybridLoader.cpp
@@ -0,0 +1,140 @@
+#include "stdafx.h"
+#include "HybridLoader.h"
+using namespace CpuCompute;
+using namespace ComLight;
+
+static void populateDecodeTensorsMap( CAtlMap<CStringA, Tensor*>& map, int layersDec, DecoderTensors& dec )
+{
+ dec.layers.resize( layersDec );
+
+ map[ "decoder.positional_embedding" ] = &dec.positionalEmbedding;
+ map[ "decoder.token_embedding.weight" ] = &dec.tokenEmbedding;
+ map[ "decoder.ln.weight" ] = &dec.ln.w;
+ map[ "decoder.ln.bias" ] = &dec.ln.b;
+
+ CStringA tempString;
+ auto add = [ & ]( const char* name, int i, Tensor& t )
+ {
+ tempString.Format( "decoder.blocks.%i.%s", i, name );
+ map[ tempString ] = &t;
+ };
+
+ auto add2 = [ & ]( const char* name, int i, TensorPair& tensors )
+ {
+ tempString.Format( "decoder.blocks.%i.%s.weight", i, name );
+ map[ tempString ] = &tensors.w;
+ tempString.Format( "decoder.blocks.%i.%s.bias", i, name );
+ map[ tempString ] = &tensors.b;
+ };
+
+ for( int i = 0; i < layersDec; i++ )
+ {
+ auto& gpu = dec.layers[ i ];
+ add2( "mlp_ln", i, gpu.mlpLn );
+ add2( "mlp.0", i, gpu.mlp0 );
+ add2( "mlp.2", i, gpu.mlp1 );
+ add2( "attn_ln", i, gpu.attnLn0 );
+ add2( "attn.query", i, gpu.attnQuery );
+ add( "attn.key.weight", i, gpu.attnKey );
+
+ add2( "attn.value", i, gpu.attnValue );
+ add2( "attn.out", i, gpu.attnLn1 );
+
+ add2( "cross_attn_ln", i, gpu.crossAttnLn0 );
+ add2( "cross_attn.query", i, gpu.crossAttnQuery );
+
+ // These 3 tensors are used by the encode() method, to compute cross-attention buffers
+ // Need them in VRAM even for the hybrid model
+ // add( "cross_attn.key.weight", i, gpu.cross_attn_k_w );
+ // add2( "cross_attn.value", i, gpu.cross_attn_v_w, gpu.cross_attn_v_b );
+ add2( "cross_attn.out", i, gpu.crossAttnLn1 );
+ }
+}
+
+HybridLoader::HybridLoader( DecoderTensors& m, int countLayers ) :
+ destination( m )
+{
+ populateDecodeTensorsMap( map, countLayers, destination );
+ pending.reserve( map.GetCount() );
+}
+
+HRESULT HybridLoader::setupTensor( const CStringA& name, int n_dims, int ftype, const std::array<int, 4>& ne, ComLight::iReadStream* stream, int64_t& postponedBytes )
+{
+ auto p = map.Lookup( name );
+ if( nullptr == p )
+ return S_FALSE;
+
+ Tensor& rdi = *p->m_value;
+ PendingTensor& pt = pending.emplace_back();
+
+ __m128i vec = load16( ne.data() );
+ vec = _mm_insert_epi32( vec, 1, 3 );
+ store16( &rdi.ne, vec );
+ rdi.setDenseStrides();
+
+ pt.destPointer = p->m_value;
+ CHECK( stream->getPosition( pt.streamOffset ) );
+ pt.bufferOffset = bufferBytes;
+
+ size_t cbElement;
+ if( ftype == 0 )
+ {
+ rdi.setType( eDataType::FP32 );
+ cbElement = 4;
+ }
+ else
+ {
+ rdi.setType( eDataType::FP16 );
+ cbElement = 2;
+ }
+
+ const size_t totalElts = (size_t)(uint32_t)ne[ 0 ] * (uint32_t)ne[ 1 ] * (uint32_t)ne[ 2 ];
+ if( totalElts * cbElement > UINT_MAX )
+ return DISP_E_OVERFLOW;
+
+ size_t payloadBytes = cbElement * totalElts;
+ pt.payloadBytes = payloadBytes;
+ CHECK( stream->seek( payloadBytes, eSeekOrigin::Current ) );
+ postponedBytes += (int64_t)payloadBytes;
+
+ payloadBytes = ( payloadBytes + 31 ) & ( ~( (size_t)31 ) );
+ bufferBytes += payloadBytes;
+ return S_OK;
+}
+
+HRESULT HybridLoader::completeLoad( ComLight::iReadStream* stream, iLoaderProgressSink& progressSink )
+{
+ if( pending.size() != map.GetCount() )
+ {
+ logError( u8"Not all tensors loaded from model file - expected %zu, got %zu", map.GetCount(), pending.size() );
+ return E_INVALIDARG;
+ }
+
+ LargeBuffer buffer;
+ CHECK( buffer.allocate( bufferBytes ) );
+
+ uint8_t* rdi = buffer.pointer();
+
+ for( const auto& pt : pending )
+ {
+ if( pt.payloadBytes > INT_MAX )
+ return DISP_E_OVERFLOW;
+ CHECK( stream->seek( pt.streamOffset, eSeekOrigin::Begin ) );
+
+ int written = 0;
+ CHECK( stream->read( rdi, (int)pt.payloadBytes, written ) );
+ CHECK( progressSink.gotBytes( (int64_t)pt.payloadBytes ) );
+
+ pt.destPointer->setDataPointer( rdi );
+
+ const size_t cb = ( pt.payloadBytes + 31 ) & ( ~( (size_t)31 ) );
+ rdi += cb;
+ }
+
+ CHECK( buffer.setReadOnly( bufferBytes ) );
+ destination.setMemoryBuffer( std::move( buffer ) );
+
+ constexpr double mulMb = 1.0 / ( 1 << 20 );
+ logDebug( u8"Loaded %zu decoder tensors, %g MB RAM", pending.size(), mulMb * (double)(int64_t)bufferBytes );
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/CPU/HybridLoader.h b/Whisper/CPU/HybridLoader.h
new file mode 100644
index 0000000..8b12804
--- /dev/null
+++ b/Whisper/CPU/HybridLoader.h
@@ -0,0 +1,37 @@
+#pragma once
+#include "DecoderTensors.h"
+#include <atlstr.h>
+#include <atlcoll.h>
+#include "../../ComLightLib/streams.h"
+
+namespace CpuCompute
+{
+ __interface iLoaderProgressSink
+ {
+ HRESULT gotBytes( int64_t cb );
+ };
+
+ class HybridLoader
+ {
+ DecoderTensors& destination;
+ CAtlMap<CStringA, Tensor*> map;
+ size_t bufferBytes = 0;
+
+ struct alignas( 32 ) PendingTensor
+ {
+ Tensor* destPointer = nullptr;
+ int64_t streamOffset = 0;
+ size_t bufferOffset = 0;
+ size_t payloadBytes = 0;
+ };
+ std::vector<PendingTensor> pending;
+
+ public:
+
+ HybridLoader( DecoderTensors& m, int countLayers );
+
+ HRESULT setupTensor( const CStringA& name, int n_dims, int ftype, const std::array<int, 4>& ne, ComLight::iReadStream* stream, int64_t& postponedBytes );
+
+ HRESULT completeLoad( ComLight::iReadStream* stream, iLoaderProgressSink& progressSink );
+ };
+} \ No newline at end of file
diff --git a/Whisper/CPU/KvTensors.h b/Whisper/CPU/KvTensors.h
new file mode 100644
index 0000000..a7897d3
--- /dev/null
+++ b/Whisper/CPU/KvTensors.h
@@ -0,0 +1,36 @@
+#pragma once
+#include "Tensor.h"
+#include "LargeBuffer.h"
+#include "../Whisper/sModelParams.h"
+
+namespace CpuCompute
+{
+ class KvTensors
+ {
+ uint16_t* keys = nullptr;
+ uint16_t* values = nullptr;
+ uint32_t size = 0;
+
+ CpuCompute::LargeBuffer memory;
+
+ public:
+ // Create these two large tensors, FP16 precision
+ HRESULT create( const Whisper::sModelParams& mp );
+
+ // A slice of model.memory_cross_k tensor
+ Tensor keysView( uint32_t len, uint32_t off ) const
+ {
+ if( len + off <= size )
+ return Tensor::fromData( keys + off, eDataType::FP16, len );
+ throw E_BOUNDS;
+ }
+
+ // A slice of model.memory_cross_v tensor
+ Tensor valuesView( uint32_t len, uint32_t off ) const
+ {
+ if( len + off <= size )
+ return Tensor::fromData( values + off, eDataType::FP16, len );
+ throw E_BOUNDS;
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/CPU/KvTensorsCpu.cpp b/Whisper/CPU/KvTensorsCpu.cpp
new file mode 100644
index 0000000..88a70ff
--- /dev/null
+++ b/Whisper/CPU/KvTensorsCpu.cpp
@@ -0,0 +1,19 @@
+#include "stdafx.h"
+#include "KvTensors.h"
+using namespace CpuCompute;
+
+// Create these two large tensors, FP16 precision
+HRESULT KvTensors::create( const Whisper::sModelParams& mp )
+{
+ const uint32_t n_mem = mp.n_text_layer * mp.n_text_ctx;
+ const uint32_t n_elements = mp.n_text_state * n_mem;
+
+ const size_t cb = sizeof( uint16_t ) * (size_t)n_elements * 2;
+ CHECK( memory.allocate( cb ) );
+
+ uint16_t* pointer = (uint16_t*)memory.pointer();
+ keys = pointer;
+ values = pointer + n_elements;
+ size = n_elements;
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/CPU/LargeBuffer.cpp b/Whisper/CPU/LargeBuffer.cpp
new file mode 100644
index 0000000..e124686
--- /dev/null
+++ b/Whisper/CPU/LargeBuffer.cpp
@@ -0,0 +1,34 @@
+#include "stdafx.h"
+#include "LargeBuffer.h"
+using namespace CpuCompute;
+
+void LargeBuffer::deallocate()
+{
+ if( nullptr == pv )
+ return;
+ VirtualFree( pv, 0, MEM_RELEASE );
+ pv = nullptr;
+}
+
+HRESULT LargeBuffer::allocate( size_t cb )
+{
+ deallocate();
+
+ pv = VirtualAlloc( nullptr, cb, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE );
+ if( nullptr != pv )
+ return S_OK;
+ return HRESULT_FROM_WIN32( GetLastError() );
+}
+
+HRESULT LargeBuffer::setReadOnly( size_t cb )
+{
+ if( nullptr != pv )
+ {
+ DWORD op = 0;
+ if( VirtualProtect( pv, cb, PAGE_READONLY, &op ) )
+ return S_OK;
+ return HRESULT_FROM_WIN32( GetLastError() );
+ }
+ else
+ return OLE_E_BLANK;
+} \ No newline at end of file
diff --git a/Whisper/CPU/LargeBuffer.h b/Whisper/CPU/LargeBuffer.h
new file mode 100644
index 0000000..d9a9a22
--- /dev/null
+++ b/Whisper/CPU/LargeBuffer.h
@@ -0,0 +1,44 @@
+#pragma once
+
+namespace CpuCompute
+{
+ // A large memory buffer allocated with VirtualAlloc kernel API, bypassing the heap.
+ class LargeBuffer
+ {
+ void* pv = nullptr;
+ public:
+ LargeBuffer() = default;
+ LargeBuffer( const LargeBuffer& ) = delete;
+ LargeBuffer( LargeBuffer&& that ) noexcept
+ {
+ pv = that.pv;
+ that.pv = nullptr;
+ }
+ ~LargeBuffer()
+ {
+ deallocate();
+ }
+ void operator=( LargeBuffer&& that ) noexcept
+ {
+ std::swap( pv, that.pv );
+ }
+ void operator=( const LargeBuffer& that ) = delete;
+
+ // Allocate buffer with specified count of bytes, and read+write memory protection
+ // The OS kernel guarantees zero-initialization of that memory.
+ HRESULT allocate( size_t cb );
+
+ // Change memory protection of the buffer to read only
+ HRESULT setReadOnly( size_t cb );
+
+ // Unless the pointer is nullptr, deallocate the buffer
+ void deallocate();
+
+ // Pointer to the start of the buffer, aligned by memory page = 4 kilobytes
+ uint8_t* pointer() const
+ {
+ assert( nullptr != pv );
+ return (uint8_t*)pv;
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/CPU/MlContext.h b/Whisper/CPU/MlContext.h
new file mode 100644
index 0000000..062a7b8
--- /dev/null
+++ b/Whisper/CPU/MlContext.h
@@ -0,0 +1,71 @@
+#pragma once
+#include "Tensor.h"
+#include "ParallelForRunner.h"
+
+namespace CpuCompute
+{
+ class MlContext
+ {
+ ParallelForRunner pfor;
+ iMemoryAllocator* allocator = nullptr;
+
+ public:
+ MlContext( int threads );
+ MlContext( const MlContext& ) = delete;
+ ~MlContext() = default;
+
+ HRESULT setThreadsCount( int threads )
+ {
+ return pfor.setThreadsCount( threads );
+ }
+
+ iMemoryAllocator* setAllocator( iMemoryAllocator* alloc )
+ {
+ iMemoryAllocator* const ret = allocator;
+ allocator = alloc;
+ return ret;
+ }
+
+ Tensor createTensor( eDataType type, const std::array<uint32_t, 4>& size );
+ Tensor createTensor( eDataType type, std::initializer_list<uint32_t> size );
+
+ Tensor addRows( const Tensor& d_te, const Tensor& d_pe, const int* tokens, const int n_tokens, const int n_past );
+
+ Tensor norm( const Tensor& arg );
+
+ // cur = add( mul( repeat( w, cur ), cur ), repeat( b, cur ) );
+ void fmaRepeat( Tensor& cur, const Tensor& w, const Tensor& b );
+
+ inline void fmaRepeat( Tensor& cur, const TensorPair wb )
+ {
+ fmaRepeat( cur, wb.w, wb.b );
+ }
+
+ // Multiply two matrices
+ Tensor mulMat( const Tensor& a, const Tensor& b );
+
+ // cur = add( repeat( b, cur ), cur ); cur = scale(cur, scaling)
+ void addRepeatScale( Tensor& cur, const Tensor& b, float scaling );
+
+ void addRepeat( Tensor& cur, const Tensor& b );
+
+ Tensor add( const Tensor& a, const Tensor& b );
+ void addInPlace( Tensor& a, const Tensor& b );
+ void addRepeatGelu( Tensor& cur, const Tensor& b );
+
+ // cur = scale(cur, scaling)
+ void scale( Tensor& cur, float scaling );
+
+ void diagMaskInf( Tensor& cur, uint32_t n_past );
+
+ void softMax( Tensor& cur, float inputScale = 1.0f );
+
+ Tensor copy( const Tensor& a, eDataType type, std::initializer_list<uint32_t> size );
+
+ HRESULT copyImpl( Tensor& result, const Tensor& source );
+
+ Tensor permute( const Tensor& a, uint8_t axis0, uint8_t axis1, uint8_t axis2, uint8_t axis3 );
+
+ void copyInPlace( Tensor& dest, const Tensor& a, eDataType type, std::initializer_list<uint32_t> size );
+ };
+} \ No newline at end of file
diff --git a/Whisper/CPU/MlContextCpu.cpp b/Whisper/CPU/MlContextCpu.cpp
new file mode 100644
index 0000000..e34a823
--- /dev/null
+++ b/Whisper/CPU/MlContextCpu.cpp
@@ -0,0 +1,597 @@
+#include "stdafx.h"
+#include "MlContext.h"
+#include "simdUtils.h"
+#include "mulMat.h"
+using namespace CpuCompute;
+
+MlContext::MlContext( int threads ) : pfor( threads )
+{
+}
+
+Tensor MlContext::createTensor( eDataType type, const std::array<uint32_t, 4>& size )
+{
+ Tensor res;
+ check( res.create( type, size, allocator ) );
+ return res;
+}
+
+Tensor MlContext::createTensor( eDataType type, std::initializer_list<uint32_t> size )
+{
+ Tensor res;
+ check( res.create( type, size, allocator ) );
+ return res;
+}
+
+namespace
+{
+ inline const uint16_t* getRow16( const Tensor& t, size_t index )
+ {
+ const uint16_t* rsi = t.fp16();
+ rsi += index * t.nb[ 1 ];
+ return rsi;
+ }
+ inline const float* getRow32( const Tensor& t, size_t index )
+ {
+ const float* rsi = t.fp32();
+ rsi += index * t.nb[ 1 ];
+ return rsi;
+ }
+}
+
+Tensor MlContext::addRows( const Tensor& d_te, const Tensor& d_pe, const int* tokens, const int n_tokens, const int n_past )
+{
+ if( d_te.type() != eDataType::FP16 || d_pe.type() != eDataType::FP32 )
+ throw E_INVALIDARG;
+ if( d_te.ne[ 0 ] != d_pe.ne[ 0 ] )
+ throw E_INVALIDARG;
+ if( n_tokens <= 0 )
+ throw E_BOUNDS;
+
+ Tensor res = createTensor( eDataType::FP32, { d_te.ne[ 0 ], (uint32_t)n_tokens } );
+
+ const size_t inner = (size_t)d_te.ne[ 0 ];
+ const size_t outer = (size_t)n_tokens;
+ float* rdi = res.fp32();
+ for( size_t i = 0; i < outer; i++, rdi += inner, tokens++ )
+ {
+ const uint16_t* const source1 = getRow16( d_te, *(const uint32_t*)tokens );
+ const float* const source2 = getRow32( d_pe, i + (size_t)n_past );
+ addF16to32( rdi, source1, source2, inner );
+ }
+ return res;
+}
+
+namespace
+{
+ class DispatchHelper3
+ {
+ std::array<uint32_t, 3> ne;
+
+ public:
+ DispatchHelper3() = default;
+ DispatchHelper3( uint32_t x, uint32_t y, uint32_t z )
+ {
+ assert( x > 0 && y > 0 && z > 0 );
+ ne[ 0 ] = x;
+ ne[ 1 ] = y;
+ ne[ 2 ] = z;
+ }
+ size_t groupsCount() const
+ {
+ size_t res = ne[ 0 ];
+ res *= ne[ 1 ];
+ res *= ne[ 2 ];
+ return res;
+ }
+ std::array<uint32_t, 3> unpack( size_t idx ) const
+ {
+ assert( idx < groupsCount() );
+ std::array<uint32_t, 3> res;
+ res[ 0 ] = (uint32_t)( idx % ne[ 0 ] );
+ idx = idx / ne[ 0 ];
+ res[ 1 ] = (uint32_t)( idx % ne[ 1 ] );
+ res[ 2 ] = (uint32_t)( idx / ne[ 1 ] );
+ return res;
+ }
+ void next( std::array<uint32_t, 3>& i ) const
+ {
+ i[ 0 ]++;
+ if( i[ 0 ] < ne[ 0 ] )
+ return;
+ i[ 0 ] = 0;
+ i[ 1 ]++;
+ if( i[ 1 ] < ne[ 1 ] )
+ return;
+ i[ 1 ] = 0;
+ i[ 2 ]++;
+ }
+ };
+
+ inline const float* sourceRow( const float* rsi, const std::array<uint32_t, 3>& idx, size_t nb0, size_t nb1, size_t nb2 )
+ {
+ const size_t r0 = idx[ 0 ] * nb0;
+ const size_t r1 = idx[ 1 ] * nb1;
+ const size_t r2 = idx[ 2 ] * nb2;
+ rsi = rsi + r0 + r1 + r2;
+ return rsi;
+ }
+
+ struct NormContext : public iComputeRange
+ {
+ const float* source;
+ float* result;
+ size_t inner;
+ DispatchHelper3 threads;
+ std::array<uint32_t, 3> nbInput;
+
+ HRESULT __stdcall compute( size_t i, size_t end ) const override final
+ {
+ ALIGNED_SPAN( temp, inner );
+
+ std::array<uint32_t, 3> idx = threads.unpack( i );
+ float* rdi = result + i * inner;
+ for( ; i < end; i++, rdi += inner, threads.next( idx ) )
+ {
+ const float* rsi = sourceRow( source, idx, nbInput[ 0 ], nbInput[ 1 ], nbInput[ 2 ] );
+ norm( rdi, temp, rsi, inner );
+ }
+ return S_OK;
+ }
+ };
+}
+
+Tensor MlContext::norm( const Tensor& arg )
+{
+ if( arg.type() != eDataType::FP32 || arg.nb[ 0 ] != 1 )
+ throw E_INVALIDARG;
+ Tensor res = createTensor( eDataType::FP32, arg.ne );
+
+ NormContext context;
+ context.source = arg.fp32();
+ context.result = res.fp32();
+ context.inner = arg.ne[ 0 ];
+ context.threads = DispatchHelper3( arg.ne[ 1 ], arg.ne[ 2 ], arg.ne[ 3 ] );
+ context.nbInput = { arg.nb[ 1 ], arg.nb[ 2 ], arg.nb[ 3 ] };
+
+ check( pfor.parallelFor( context, context.threads.groupsCount() ) );
+ return res;
+}
+
+void MlContext::fmaRepeat( Tensor& cur, const Tensor& w, const Tensor& b )
+{
+ if( !( cur.isContinuous() && w.isContinuous() && b.isContinuous() ) )
+ throw E_INVALIDARG;
+
+ if( !( cur.type() == eDataType::FP32 && w.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) )
+ throw E_INVALIDARG;
+
+ if( !isSameShape( w, b ) )
+ throw E_INVALIDARG;
+
+ DispatchHelper3 helper{ cur.ne[ 1 ], cur.ne[ 2 ], cur.ne[ 3 ] };
+ std::array<uint32_t, 3> idx = { 0, 0, 0 };
+ const size_t countRows = helper.groupsCount();
+
+ const size_t innerRes = cur.ne[ 0 ];
+ const size_t innerPattern = w.ne[ 0 ];
+
+ float* rdi = cur.fp32();
+ for( size_t i = 0; i < countRows; i++, helper.next( idx ), rdi += innerRes )
+ {
+ std::array<uint32_t, 3> idxPattern;
+ idxPattern[ 0 ] = idx[ 0 ] % w.ne[ 1 ];
+ idxPattern[ 1 ] = idx[ 1 ] % w.ne[ 2 ];
+ idxPattern[ 2 ] = idx[ 2 ] % w.ne[ 3 ];
+
+ const float* s1 = sourceRow( w.fp32(), idxPattern, w.nb[ 1 ], w.nb[ 2 ], w.nb[ 3 ] );
+ const float* s2 = sourceRow( b.fp32(), idxPattern, b.nb[ 1 ], b.nb[ 2 ], b.nb[ 3 ] );
+ fmaRepeatRow( rdi, innerRes, s1, s2, innerPattern );
+ }
+}
+
+Tensor MlContext::mulMat( const Tensor& a, const Tensor& b )
+{
+ if( !DirectCompute::canMulMat( a, b ) )
+ throw E_INVALIDARG;
+
+ std::array<uint32_t, 4> ne{ a.ne[ 1 ], b.ne[ 1 ], a.ne[ 2 ], b.ne[ 3 ] };
+ Tensor result = createTensor( eDataType::FP32, ne );
+
+ check( CpuCompute::mulMat( result, a, b, pfor ) );
+ return result;
+}
+
+// cur = add( repeat( b, cur ), cur ); cur = scale(cur, scaling)
+void MlContext::addRepeatScale( Tensor& cur, const Tensor& b, float scaling )
+{
+ if( !( cur.isContinuous() && b.isContinuous() ) )
+ throw E_INVALIDARG;
+ if( !( cur.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) )
+ throw E_INVALIDARG;
+
+ DispatchHelper3 helper{ cur.ne[ 1 ], cur.ne[ 2 ], cur.ne[ 3 ] };
+ std::array<uint32_t, 3> idx = { 0, 0, 0 };
+ const size_t countRows = helper.groupsCount();
+
+ const size_t innerRes = (uint32_t)cur.ne[ 0 ];
+ const size_t innerPattern = (uint32_t)b.ne[ 0 ];
+
+ float* rdi = cur.fp32();
+ const __m256 scale = _mm256_set1_ps( scaling );
+ for( size_t i = 0; i < countRows; i++, helper.next( idx ), rdi += innerRes )
+ {
+ std::array<uint32_t, 3> idxPattern;
+ idxPattern[ 0 ] = idx[ 0 ] % (uint32_t)b.ne[ 1 ];
+ idxPattern[ 1 ] = idx[ 1 ] % (uint32_t)b.ne[ 2 ];
+ idxPattern[ 2 ] = idx[ 2 ] % (uint32_t)b.ne[ 3 ];
+
+ const float* source = sourceRow( b.fp32(), idxPattern, b.nb[ 1 ], b.nb[ 2 ], b.nb[ 3 ] );
+ addRepeatScaleRow( rdi, innerRes, source, innerPattern, scale );
+ }
+}
+
+void MlContext::addRepeat( Tensor& cur, const Tensor& b )
+{
+ if( !( cur.isContinuous() && b.isContinuous() ) )
+ throw E_INVALIDARG;
+ if( !( cur.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) )
+ throw E_INVALIDARG;
+
+ DispatchHelper3 helper{ cur.ne[ 1 ], cur.ne[ 2 ], cur.ne[ 3 ] };
+ std::array<uint32_t, 3> idx = { 0, 0, 0 };
+ const size_t countRows = helper.groupsCount();
+
+ const size_t innerRes = (uint32_t)cur.ne[ 0 ];
+ const size_t innerPattern = (uint32_t)b.ne[ 0 ];
+
+ float* rdi = cur.fp32();
+ for( size_t i = 0; i < countRows; i++, helper.next( idx ), rdi += innerRes )
+ {
+ std::array<uint32_t, 3> idxPattern;
+ idxPattern[ 0 ] = idx[ 0 ] % (uint32_t)b.ne[ 1 ];
+ idxPattern[ 1 ] = idx[ 1 ] % (uint32_t)b.ne[ 2 ];
+ idxPattern[ 2 ] = idx[ 2 ] % (uint32_t)b.ne[ 3 ];
+
+ const float* source = sourceRow( b.fp32(), idxPattern, b.nb[ 1 ], b.nb[ 2 ], b.nb[ 3 ] );
+ addRepeatRow( rdi, innerRes, source, innerPattern );
+ }
+}
+
+// cur = scale(cur, scaling)
+void MlContext::scale( Tensor& cur, float scaling )
+{
+ if( !( cur.isContinuous() && cur.type() == eDataType::FP32 ) )
+ throw E_INVALIDARG;
+
+ const size_t len = cur.countElements();
+ const __m256 scale = _mm256_set1_ps( scaling );
+ scaleRow( cur.fp32(), len, scale );
+}
+
+void MlContext::diagMaskInf( Tensor& cur, uint32_t n_past )
+{
+ if( !( cur.isContinuous() && cur.type() == eDataType::FP32 ) )
+ throw E_INVALIDARG;
+
+ const size_t n = cur.countRows();
+ const size_t nc = cur.ne[ 0 ];
+ const size_t nr = cur.ne[ 1 ];
+ const size_t nz = n / nr;
+
+ for( size_t k = 0; k < nz; k++ )
+ {
+ for( size_t j = 0; j < nr; j++ )
+ {
+ float* const rdi = cur.fp32() + k * cur.nb[ 2 ] + j * cur.nb[ 1 ];
+ // +1 because the original code checked for `if( i > n_past + j )`
+ // That's why the first index to write is ( n_past + j + 1 )
+ const size_t start = n_past + j + 1;
+ const ptrdiff_t len = (ptrdiff_t)nc - (ptrdiff_t)start;
+ if( len <= 0 )
+ continue;
+
+ // Generates a store string instruction (rep stosd).
+ // The magic number is negative infinity in FP32: https://www.h-schmidt.net/FloatConverter/IEEE754.html
+ __stosd( (DWORD*)( rdi + start ), 0xff800000u, (size_t)len );
+ }
+ }
+}
+
+void MlContext::softMax( Tensor& cur, float inputScale )
+{
+ if( !( cur.isContinuous() && cur.type() == eDataType::FP32 ) )
+ throw E_INVALIDARG;
+
+ struct SoftMaxContext : public iComputeRange
+ {
+ float* data;
+ float inputScale;
+ size_t length, stride;
+
+ HRESULT __stdcall compute( size_t i, size_t end ) const override final
+ {
+ float* rdi = data + stride * i;
+ for( ; i < end; i++, rdi += stride )
+ ::softMax( rdi, length, inputScale );
+ return S_OK;
+ }
+ };
+
+ SoftMaxContext context;
+ context.data = cur.fp32();
+ context.inputScale = inputScale;
+ context.length = cur.ne[ 0 ];
+ context.stride = cur.nb[ 1 ];
+
+ const size_t n = cur.countRows();
+ pfor.parallelFor( context, n );
+}
+
+namespace
+{
+ template<class R, class S>
+ __forceinline void copyElement( R* rdi, const S* rsi )
+ {
+ static_assert( std::is_same<R, S>() );
+ *rdi = *rsi;
+ }
+ template<>
+ __forceinline void copyElement<float, uint16_t>( float* rdi, const uint16_t* rsi )
+ {
+ __m128i iv = _mm_cvtsi32_si128( *rsi );
+ __m128 fv = _mm_cvtph_ps( iv );
+ _mm_store_ss( rdi, fv );
+ }
+ template<>
+ __forceinline void copyElement<uint16_t, float>( uint16_t* rdi, const float* rsi )
+ {
+ __m128 fv = _mm_load_ss( rsi );
+ __m128i iv = _mm_cvtps_ph( fv, 0 );
+ *rdi = (uint16_t)(uint32_t)_mm_cvtsi128_si32( iv );
+ }
+
+ template<class R, class S>
+ __forceinline void copyRow( R* rdi, const S* rsi, size_t length )
+ {
+ static_assert( std::is_same<R, S>() );
+ memcpy( rdi, rsi, length * sizeof( R ) );
+ }
+ template<>
+ __forceinline void copyRow<uint16_t, float>( uint16_t* rdi, const float* rsi, size_t length )
+ {
+ floatsDowncast( rdi, rsi, length );
+ }
+ template<>
+ __forceinline void copyRow<float, uint16_t>( float* rdi, const uint16_t* rsi, size_t length )
+ {
+ floatsUpcast( rdi, rsi, length );
+ }
+
+ template<class R, class S>
+ static void __declspec( noinline ) copyImpl( R* rdi, const S* rsi, const TensorShape& shape )
+ {
+ const bool continuousRows = shape.nb[ 0 ] == 1;
+
+ for( size_t i03 = 0; i03 < shape.ne[ 3 ]; i03++, rsi += shape.nb[ 3 ] )
+ {
+ const S* source2 = rsi;
+ for( size_t i02 = 0; i02 < shape.ne[ 2 ]; i02++, source2 += shape.nb[ 2 ] )
+ {
+ const S* source1 = source2;
+ for( size_t i01 = 0; i01 < shape.ne[ 1 ]; i01++, source1 += shape.nb[ 1 ] )
+ {
+ // Performance optimization here: when the rows are dense, we can copy them much faster with memcpy()
+ // Or at least with AVX, when we need to convert between numeric types
+ if( continuousRows )
+ {
+ // This branch is very predictable, same outcome for all loop iterations
+ copyRow( rdi, source1, shape.ne[ 0 ] );
+ rdi += shape.ne[ 0 ];
+ }
+ else
+ {
+ const S* source0 = source1;
+ for( size_t i00 = 0; i00 < shape.ne[ 0 ]; i00++, source0 += shape.nb[ 0 ] )
+ {
+ copyElement( rdi, source0 );
+ rdi++;
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+HRESULT MlContext::copyImpl( Tensor& result, const Tensor& source )
+{
+ if( !( result.isContinuous() && ( result.countElements() == source.countElements() ) ) )
+ return E_INVALIDARG;
+
+ const eDataType typeResult = result.type();
+ const eDataType typeSource = source.type();
+ if( source.isContinuous() )
+ {
+ const size_t elts = result.countElements();
+ if( typeResult == typeSource )
+ {
+ const size_t bytes = elts * elementSize( typeResult );
+ memcpy( result.data(), source.data(), bytes );
+ return S_OK;
+ }
+ if( typeSource == eDataType::FP16 && typeResult == eDataType::FP32 )
+ {
+ floatsUpcast( result.fp32(), source.fp16(), elts );
+ return S_OK;
+ }
+ if( typeSource == eDataType::FP32 && typeResult == eDataType::FP16 )
+ {
+ floatsDowncast( result.fp16(), source.fp32(), elts );
+ return S_OK;
+ }
+ return E_UNEXPECTED;
+ }
+ else
+ {
+ if( typeSource == eDataType::FP16 && typeResult == eDataType::FP16 )
+ {
+ ::copyImpl( result.fp16(), source.fp16(), source );
+ return S_OK;
+ }
+ if( typeSource == eDataType::FP32 && typeResult == eDataType::FP32 )
+ {
+ ::copyImpl( result.fp32(), source.fp32(), source );
+ return S_OK;
+ }
+ if( typeSource == eDataType::FP16 && typeResult == eDataType::FP32 )
+ {
+ ::copyImpl( result.fp32(), source.fp16(), source );
+ return S_OK;
+ }
+ if( typeSource == eDataType::FP32 && typeResult == eDataType::FP16 )
+ {
+ ::copyImpl( result.fp16(), source.fp32(), source );
+ return S_OK;
+ }
+ return E_UNEXPECTED;
+ }
+}
+
+Tensor MlContext::copy( const Tensor& a, eDataType type, std::initializer_list<uint32_t> size )
+{
+ const size_t dims = size.size();
+ if( 0 == dims || dims > 4 )
+ throw E_BOUNDS;
+
+ size_t nRequested = 1;
+ for( size_t i = 0; i < dims; i++ )
+ {
+ uint32_t n = size.begin()[ i ];
+ nRequested *= n;
+ }
+ if( nRequested != a.countElements() )
+ throw E_INVALIDARG;
+
+ if( a.type() == type && a.isContinuous() )
+ {
+ // Same type, and it's dense - no need to move data, equal to reshape
+ Tensor res{ a };
+ for( size_t i = 0; i < dims; i++ )
+ res.ne[ i ] = size.begin()[ i ];;
+ for( size_t i = dims; i < 4; i++ )
+ res.ne[ i ] = 1;
+ res.setDenseStrides();
+ return res;
+ }
+ else
+ {
+ // Need to convert types, and/or transpose the tensor. Make another tensor for the output
+ Tensor res = createTensor( type, size );
+ check( copyImpl( res, a ) );
+ return res;
+ }
+}
+
+Tensor MlContext::permute( const Tensor& a, uint8_t axis0, uint8_t axis1, uint8_t axis2, uint8_t axis3 )
+{
+ assert( axis0 < 4 );
+ assert( axis1 < 4 );
+ assert( axis2 < 4 );
+ assert( axis3 < 4 );
+
+ assert( axis0 != axis1 );
+ assert( axis0 != axis2 );
+ assert( axis0 != axis3 );
+ assert( axis1 != axis2 );
+ assert( axis1 != axis3 );
+ assert( axis2 != axis3 );
+
+ Tensor res = a;
+ res.ne[ axis0 ] = a.ne[ 0 ];
+ res.ne[ axis1 ] = a.ne[ 1 ];
+ res.ne[ axis2 ] = a.ne[ 2 ];
+ res.ne[ axis3 ] = a.ne[ 3 ];
+
+ res.nb[ axis0 ] = a.nb[ 0 ];
+ res.nb[ axis1 ] = a.nb[ 1 ];
+ res.nb[ axis2 ] = a.nb[ 2 ];
+ res.nb[ axis3 ] = a.nb[ 3 ];
+
+ return res;
+}
+
+void MlContext::copyInPlace( Tensor& dest, const Tensor& a, eDataType type, std::initializer_list<uint32_t> size )
+{
+ assert( type == dest.type() );
+
+ const size_t dims = size.size();
+ if( 0 == dims || dims > 4 )
+ throw E_BOUNDS;
+
+ size_t nRequested = 1;
+ for( size_t i = 0; i < dims; i++ )
+ {
+ uint32_t n = size.begin()[ i ];
+ nRequested *= n;
+ }
+ if( nRequested != a.countElements() || nRequested != dest.countElements() )
+ throw E_INVALIDARG;
+
+ // Reshape the destination
+ for( size_t i = 0; i < dims; i++ )
+ dest.ne[ i ] = size.begin()[ i ];
+ for( size_t i = dims; i < 4; i++ )
+ dest.ne[ i ] = 1;
+ dest.setDenseStrides();
+
+ // Copy the data
+ check( copyImpl( dest, a ) );
+}
+
+void MlContext::addInPlace( Tensor& a, const Tensor& b )
+{
+ if( !( a.isContinuous() && b.isContinuous() && a.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) )
+ throw E_NOTIMPL;
+
+ const size_t length = a.countElements();
+ addRowInPlace( a.fp32(), b.fp32(), length );
+}
+
+Tensor MlContext::add( const Tensor& a, const Tensor& b )
+{
+ if( !( a.isContinuous() && b.isContinuous() && a.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) )
+ throw E_NOTIMPL;
+
+ Tensor res = createTensor( eDataType::FP32, a.ne );
+ const size_t length = a.countElements();
+ addRow( res.fp32(), a.fp32(), b.fp32(), length );
+ return res;
+}
+
+void MlContext::addRepeatGelu( Tensor& cur, const Tensor& b )
+{
+ if( !( cur.isContinuous() && b.isContinuous() ) )
+ throw E_INVALIDARG;
+ if( !( cur.type() == eDataType::FP32 && b.type() == eDataType::FP32 ) )
+ throw E_INVALIDARG;
+
+ DispatchHelper3 helper{ cur.ne[ 1 ], cur.ne[ 2 ], cur.ne[ 3 ] };
+ std::array<uint32_t, 3> idx = { 0, 0, 0 };
+ const size_t countRows = helper.groupsCount();
+
+ const size_t innerRes = (uint32_t)cur.ne[ 0 ];
+ const size_t innerPattern = (uint32_t)b.ne[ 0 ];
+ float* rdi = cur.fp32();
+ auto& lookupTables = getLookupTables();
+ for( size_t i = 0; i < countRows; i++, helper.next( idx ), rdi += innerRes )
+ {
+ std::array<uint32_t, 3> idxPattern;
+ idxPattern[ 0 ] = idx[ 0 ] % (uint32_t)b.ne[ 1 ];
+ idxPattern[ 1 ] = idx[ 1 ] % (uint32_t)b.ne[ 2 ];
+ idxPattern[ 2 ] = idx[ 2 ] % (uint32_t)b.ne[ 3 ];
+
+ const float* source = sourceRow( b.fp32(), idxPattern, b.nb[ 1 ], b.nb[ 2 ], b.nb[ 3 ] );
+ addRepeatGeluRow( rdi, innerRes, source, innerPattern, lookupTables );
+ }
+ return;
+} \ No newline at end of file
diff --git a/Whisper/CPU/ParallelForRunner.cpp b/Whisper/CPU/ParallelForRunner.cpp
new file mode 100644
index 0000000..7151a23
--- /dev/null
+++ b/Whisper/CPU/ParallelForRunner.cpp
@@ -0,0 +1,149 @@
+#include "stdafx.h"
+#include "ParallelForRunner.h"
+using namespace CpuCompute;
+
+ParallelForRunner::ParallelForRunner( int threads ) :
+ maxThreads( threads )
+{
+ if( maxThreads <= 1 )
+ {
+ threadBuffers.resize( 1 );
+ return;
+ }
+
+ work = CreateThreadpoolWork( &workCallbackStatic, this, nullptr );
+ if( nullptr == work )
+ throw getLastHr();
+ threadBuffers.resize( maxThreads );
+}
+
+HRESULT ParallelForRunner::setThreadsCount( int threads )
+{
+ maxThreads = threads;
+ if( threads <= 1 )
+ {
+ threadBuffers.resize( 1 );
+ return S_OK;
+ }
+
+ threadBuffers.resize( maxThreads );
+ if( nullptr == work )
+ {
+ work = CreateThreadpoolWork( &workCallbackStatic, this, nullptr );
+ if( nullptr == work )
+ return getLastHr();
+ }
+ return S_OK;
+}
+
+ParallelForRunner::~ParallelForRunner()
+{
+ if( nullptr != work )
+ {
+ if( S_FALSE == status )
+ WaitForThreadpoolWorkCallbacks( work, FALSE );
+ CloseThreadpoolWork( work );
+ }
+}
+
+namespace
+{
+ thread_local uint32_t currentThreadIndex = UINT_MAX;
+}
+
+void ParallelForRunner::runBatch( size_t ith ) noexcept
+{
+ currentThreadIndex = (uint32_t)ith;
+ const size_t begin = ( ith * countItems ) / countThreads;
+ const size_t end = ( ( ith + 1 ) * countItems ) / countThreads;
+
+ HRESULT hr = E_UNEXPECTED;
+ try
+ {
+ hr = computeRange->compute( begin, end );
+ }
+ catch( HRESULT code )
+ {
+ hr = code;
+ }
+ catch( const std::bad_alloc& )
+ {
+ hr = E_OUTOFMEMORY;
+ }
+ catch( const std::exception& )
+ {
+ hr = E_FAIL;
+ }
+ currentThreadIndex = UINT_MAX;
+ if( SUCCEEDED( hr ) )
+ return;
+ InterlockedCompareExchange( &status, hr, S_FALSE );
+}
+
+void* ParallelForRunner::threadLocalBuffer( size_t cb )
+{
+ const uint32_t idx = currentThreadIndex;
+ if( idx < threadBuffers.size() )
+ {
+ ThreadBuffer& tb = threadBuffers[ idx ];
+ if( tb.cb >= cb )
+ {
+ // We already have large enough buffer for the current thread
+ return tb.memory.pointer();
+ }
+ tb.memory.deallocate();
+ check( tb.memory.allocate( cb ) );
+ tb.cb = cb;
+ return tb.memory.pointer();
+ }
+ if( idx != UINT_MAX )
+ throw E_BOUNDS;
+ else
+ {
+ logError( u8"threadLocalBuffer() method only works from inside a pool callback" );
+ throw E_UNEXPECTED;
+ }
+}
+
+void __stdcall ParallelForRunner::workCallbackStatic( PTP_CALLBACK_INSTANCE Instance, void* pv, PTP_WORK Work ) noexcept
+{
+ ParallelForRunner& context = *(ParallelForRunner*)pv;
+ const size_t ith = (uint32_t)( InterlockedIncrement( &context.threadIndex ) );
+ context.runBatch( ith );
+}
+
+HRESULT ParallelForRunner::parallelFor( iComputeRange& compute, size_t length, size_t minBatch )
+{
+ if( maxThreads <= 1 )
+ {
+ currentThreadIndex = 0;
+ const HRESULT hr1 = compute.compute( 0, length );
+ currentThreadIndex = UINT_MAX;
+ return hr1;
+ }
+ assert( minBatch > 0 );
+
+ size_t nth = length / minBatch;
+ nth = std::min( nth, (size_t)(uint32_t)maxThreads );
+
+ computeRange = &compute;
+ countItems = length;
+ countThreads = nth;
+ threadIndex = 0;
+ status = S_FALSE;
+
+ for( size_t i = 1; i < nth; i++ )
+ SubmitThreadpoolWork( work );
+ runBatch( 0 );
+
+ if( nth > 1 )
+ WaitForThreadpoolWorkCallbacks( work, FALSE );
+
+ computeRange = nullptr;
+ const HRESULT hr = status;
+ status = S_OK;
+ if( SUCCEEDED( hr ) )
+ return S_OK;
+
+ return hr;
+} \ No newline at end of file
diff --git a/Whisper/CPU/ParallelForRunner.h b/Whisper/CPU/ParallelForRunner.h
new file mode 100644
index 0000000..baef647
--- /dev/null
+++ b/Whisper/CPU/ParallelForRunner.h
@@ -0,0 +1,52 @@
+#pragma once
+#include "LargeBuffer.h"
+
+namespace CpuCompute
+{
+ // Callback interface for the parallel `for`
+ __interface iComputeRange
+ {
+ // The implementation calls this method on multiple thread pool threads in parallel, and aggregates status codes.
+ HRESULT __stdcall compute( size_t begin, size_t end ) const;
+ };
+
+ // Similar to ThreadPoolWork in parallelFor.h, optimized to be used as a direct replacement of OpenMP pool.
+ class alignas( 64 ) ParallelForRunner
+ {
+ public:
+ ParallelForRunner( int threads );
+ ~ParallelForRunner();
+
+ HRESULT setThreadsCount( int threads );
+
+ HRESULT parallelFor( iComputeRange& compute, size_t length, size_t minBatch = 1 );
+
+ // Allocate a temporary buffer for the calling thread.
+ // The pointer is guaranteed to be aligned by page size = 4kb
+ void* threadLocalBuffer( size_t cb );
+
+ private:
+
+ int maxThreads;
+ PTP_WORK work = nullptr;
+ iComputeRange* computeRange = nullptr;
+ size_t countItems = 0;
+ size_t countThreads = 0;
+
+ // Aligning by cache lines.
+ // Avoiding cache line sharing between CPU cores improves performance, despite wasting a few bytes of memory.
+ struct alignas( 64 ) ThreadBuffer
+ {
+ LargeBuffer memory;
+ size_t cb = 0;
+ };
+ std::vector<ThreadBuffer> threadBuffers;
+
+ alignas( 64 ) volatile long threadIndex = 0;
+ volatile HRESULT status = S_OK;
+
+ void runBatch( size_t ith ) noexcept;
+
+ static void __stdcall workCallbackStatic( PTP_CALLBACK_INSTANCE Instance, void* pv, PTP_WORK Work ) noexcept;
+ };
+} \ No newline at end of file
diff --git a/Whisper/CPU/Readme.txt b/Whisper/CPU/Readme.txt
new file mode 100644
index 0000000..702813d
--- /dev/null
+++ b/Whisper/CPU/Readme.txt
@@ -0,0 +1 @@
+The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change BUILD_HYBRID_VERSION macro in stdafx.h \ No newline at end of file
diff --git a/Whisper/CPU/Tensor.h b/Whisper/CPU/Tensor.h
new file mode 100644
index 0000000..2fcde63
--- /dev/null
+++ b/Whisper/CPU/Tensor.h
@@ -0,0 +1,139 @@
+#pragma once
+#include "../D3D/enums.h"
+#include "../ML/TensorShape.h"
+// 1 = new tensors can be allocated with `nullptr` iMemoryAllocator, by allocating memory internally and counting these references
+// 0 = memory allocator is mandatory, create() methods will fail with E_POINTER if the allocator is `nullptr`
+#define TENSOR_INTERNAL_ALLOC 0
+
+// 1 = expose compatibility API for GGML interop
+#define TENSOR_GGML_COMPAT 0
+
+#if TENSOR_GGML_COMPAT
+#include "../source/ggml.h"
+#endif
+
+namespace CpuCompute
+{
+ using DirectCompute::TensorShape;
+ using DirectCompute::eDataType;
+
+ __interface iMemoryAllocator
+ {
+ void* allocate( size_t cb, size_t align );
+ };
+ __interface iArenaAllocator : public iMemoryAllocator
+ {
+ void resetArena();
+ };
+
+#if TENSOR_GGML_COMPAT
+ class Tensor;
+ class GgmlTensorView
+ {
+ ggml_tensor tensor;
+ public:
+ GgmlTensorView( const Tensor& t );
+ operator ggml_tensor* ( ) { return &tensor; }
+ };
+#endif
+
+ // A functional equivalent of ggml_tensor structure, designed for use from C++
+ class Tensor : public TensorShape
+ {
+ void* m_data = nullptr;
+
+ eDataType m_type = (eDataType)0xFF;
+
+#if TENSOR_INTERNAL_ALLOC
+ // True when the memory block was allocated internally by this class
+ // In this case, this class does reference counting to support cheap copies.
+ // False when it's owned by someone else, such as iMemoryAllocator object, or a GGML's tensor
+ bool ownsMemory = false;
+ void deallocate();
+#endif
+
+ // Private constructors for fromData() methods
+ Tensor( void* pointer, eDataType type, std::initializer_list<uint32_t> size );
+ Tensor( void* pointer, eDataType type, uint32_t length ) noexcept;
+ public:
+ // Trivial constructors
+ Tensor() = default;
+#if TENSOR_INTERNAL_ALLOC
+ ~Tensor()
+ {
+ deallocate();
+ }
+#else
+ ~Tensor() = default;
+#endif
+ Tensor( Tensor&& that ) noexcept;
+ void operator=( Tensor&& that ) noexcept;
+ Tensor( const Tensor& that );
+ void operator=( const Tensor& that );
+
+ // Allocate a new tensor
+ HRESULT create( eDataType type, const std::array<uint32_t, 4>& sizeElements, iMemoryAllocator* alloc = nullptr );
+ // Allocate a new tensor
+ HRESULT create( eDataType type, std::initializer_list<uint32_t> sizeElements, iMemoryAllocator* alloc = nullptr );
+ // Attach to pre-existing block of memory, interpreting the data as a dense tensor of the specified type and size
+ HRESULT attach( void* pointer, eDataType type, std::initializer_list<uint32_t> sizeElements );
+ // Attach to pre-existing block of memory, interpret the data as a dense vector of the specified type and length
+ static Tensor fromData( void* pointer, eDataType type, uint32_t length );
+
+ eDataType type() const { return m_type; }
+ void* data() const { return m_data; }
+
+ uint16_t* fp16()
+ {
+ assert( m_type == eDataType::FP16 );
+ assert( nullptr != m_data );
+ return (uint16_t*)m_data;
+ }
+ const uint16_t* fp16() const
+ {
+ assert( m_type == eDataType::FP16 );
+ assert( nullptr != m_data );
+ return (uint16_t*)m_data;
+ }
+ float* fp32()
+ {
+ assert( m_type == eDataType::FP32 );
+ assert( nullptr != m_data );
+ return (float*)m_data;
+ }
+ const float* fp32() const
+ {
+ assert( m_type == eDataType::FP32 );
+ assert( nullptr != m_data );
+ return (float*)m_data;
+ }
+
+ Tensor reshape3d( uint32_t ne0, uint32_t ne1, uint32_t ne2 ) const;
+
+ void setType( eDataType dt )
+ {
+ m_type = dt;
+ }
+ void setDataPointer( void* pv )
+ {
+ m_data = pv;
+ }
+
+#if TENSOR_GGML_COMPAT
+ // Compatibility with GGML's tensors, for testing and lulz
+ Tensor( const ggml_tensor* ggml );
+ ggml_tensor ggml() const;
+
+ operator GgmlTensorView() const
+ {
+ return GgmlTensorView( *this );
+ }
+#endif
+ };
+
+ // A pair of tensors containing weights and biases; apparently, both tensors are of the same shape
+ struct TensorPair
+ {
+ Tensor w, b;
+ };
+} \ No newline at end of file
diff --git a/Whisper/CPU/TensorCpu.cpp b/Whisper/CPU/TensorCpu.cpp
new file mode 100644
index 0000000..dc34464
--- /dev/null
+++ b/Whisper/CPU/TensorCpu.cpp
@@ -0,0 +1,401 @@
+#include <stdafx.h>
+#include <atomic>
+#include "Tensor.h"
+using namespace CpuCompute;
+
+#if TENSOR_INTERNAL_ALLOC
+namespace
+{
+ // This structure is immediately before the payload of every tensor which has an internally-allocated memory buffer
+ class alignas( 32 ) sTensorMemoryHeader
+ {
+ std::atomic_ptrdiff_t refCounter;
+ public:
+ // Reset the counter to the specified value
+ void reset( ptrdiff_t rc )
+ {
+ refCounter = rc;
+ }
+ // Increment the ref.counter
+ void increment()
+ {
+ refCounter++;
+ }
+ // Decrement the ref.counter, and return true if it reached zero as the result
+ bool decrement()
+ {
+ ptrdiff_t val = --refCounter;
+ assert( val >= 0 );
+ return 0 == val;
+ }
+ };
+
+ inline sTensorMemoryHeader* getMemBlockHeader( void* pv )
+ {
+ assert( nullptr != pv );
+ uint8_t* pb = (uint8_t*)pv;
+ static_assert( sizeof( sTensorMemoryHeader ) == 32 );
+ return (sTensorMemoryHeader*)( pb - sizeof( sTensorMemoryHeader ) );
+ }
+
+ inline void releaseBlock( sTensorMemoryHeader* pointer )
+ {
+ assert( nullptr != pointer );
+ _aligned_free( pointer );
+ }
+
+ inline void* allocateBlock( size_t cb, ptrdiff_t initialRefCounter = 1 )
+ {
+ cb += sizeof( sTensorMemoryHeader );
+ void* pv = _aligned_malloc( cb, 32 );
+ if( nullptr == pv )
+ return nullptr;
+
+ sTensorMemoryHeader* header = (sTensorMemoryHeader*)pv;
+ header->reset( initialRefCounter );
+ return ( (uint8_t*)pv ) + sizeof( sTensorMemoryHeader );
+ }
+}
+
+void Tensor::deallocate()
+{
+ if( ownsMemory && nullptr != m_data )
+ {
+ sTensorMemoryHeader* const header = getMemBlockHeader( m_data );
+ if( header->decrement() )
+ {
+ // This tensor is the last one which had a reference to that block of memory
+ // Release the memory back to the heap
+ releaseBlock( header );
+ }
+ }
+ ownsMemory = false;
+
+ TensorShape::setZero();
+ m_data = nullptr;
+ m_type = (eDataType)0xFF;
+}
+#endif
+
+Tensor::Tensor( const Tensor& that )
+{
+ store( ne, that.sizeVec() );
+ store( nb, that.stridesVec() );
+ m_data = that.m_data;
+ m_type = that.m_type;
+#if TENSOR_INTERNAL_ALLOC
+ if( that.ownsMemory && nullptr != m_data )
+ {
+ getMemBlockHeader( m_data )->increment();
+ ownsMemory = true;
+ }
+ else
+ ownsMemory = false;
+#endif
+}
+
+Tensor::Tensor( Tensor&& that ) noexcept
+{
+ store( ne, that.sizeVec() );
+ store( nb, that.stridesVec() );
+ m_data = that.m_data;
+ m_type = that.m_type;
+#if TENSOR_INTERNAL_ALLOC
+ ownsMemory = that.ownsMemory;
+ that.ownsMemory = false;
+#endif
+ that.m_data = nullptr;
+}
+
+void Tensor::operator=( const Tensor& that )
+{
+ assert( this != &that );
+#if TENSOR_INTERNAL_ALLOC
+ deallocate();
+#endif
+
+ store( ne, that.sizeVec() );
+ store( nb, that.stridesVec() );
+ m_data = that.m_data;
+ m_type = that.m_type;
+#if TENSOR_INTERNAL_ALLOC
+ if( that.ownsMemory && nullptr != m_data )
+ {
+ getMemBlockHeader( m_data )->increment();
+ ownsMemory = true;
+ }
+ else
+ ownsMemory = false;
+#endif
+}
+
+void Tensor::operator=( Tensor&& that ) noexcept
+{
+ assert( this != &that );
+#if TENSOR_INTERNAL_ALLOC
+ deallocate();
+#endif
+ store( ne, that.sizeVec() );
+ store( nb, that.stridesVec() );
+ m_data = that.m_data;
+ m_type = that.m_type;
+ that.m_data = nullptr;
+#if TENSOR_INTERNAL_ALLOC
+ ownsMemory = that.ownsMemory;
+ that.ownsMemory = false;
+#endif
+}
+
+HRESULT Tensor::create( eDataType type, const std::array<uint32_t, 4>& sizeElements, iMemoryAllocator* alloc )
+{
+ const size_t len = (size_t)sizeElements[ 0 ] * sizeElements[ 1 ] * sizeElements[ 2 ] * sizeElements[ 3 ];
+ const size_t cbElement = DirectCompute::elementSize( type );
+ const size_t cb = len * cbElement;
+
+#if TENSOR_INTERNAL_ALLOC
+ deallocate();
+#endif
+
+ store( ne, load( sizeElements ) );
+ TensorShape::setDenseStrides();
+ this->m_type = type;
+
+ if( nullptr != alloc )
+ {
+#if TENSOR_INTERNAL_ALLOC
+ ownsMemory = false;
+#endif
+ m_data = alloc->allocate( cb, 32 );
+ if( nullptr == m_data )
+ return E_OUTOFMEMORY;
+ return S_OK;
+ }
+ else
+ {
+#if TENSOR_INTERNAL_ALLOC
+ m_data = allocateBlock( cb, 1 );
+ if( nullptr == m_data )
+ return E_OUTOFMEMORY;
+ ownsMemory = true;
+ return S_OK;
+#else
+ return E_POINTER;
+#endif
+ }
+}
+
+namespace
+{
+ static HRESULT arrayFromList( std::array<uint32_t, 4>& arr, std::initializer_list<uint32_t> list )
+ {
+ const size_t dims = list.size();
+ if( dims == 0 || dims > 4 )
+ return E_INVALIDARG;
+
+ for( size_t i = 0; i < dims; i++ )
+ {
+ uint32_t u = list.begin()[ i ];
+ if( u == 0 )
+ return E_INVALIDARG;
+ arr[ i ] = u;
+ }
+
+ for( size_t i = dims; i < 4; i++ )
+ arr[ i ] = 1;
+
+ return S_OK;
+ }
+}
+
+HRESULT Tensor::create( eDataType type, std::initializer_list<uint32_t> sizeElements, iMemoryAllocator* alloc )
+{
+ std::array<uint32_t, 4> arr;
+ CHECK( arrayFromList( arr, sizeElements ) );
+
+ return create( type, arr, alloc );
+}
+
+Tensor::Tensor( void* pointer, eDataType type, std::initializer_list<uint32_t> size )
+{
+ if( nullptr == pointer )
+ throw E_POINTER;
+ check( arrayFromList( ne, size ) );
+ TensorShape::setDenseStrides();
+ m_data = pointer;
+ m_type = type;
+#if TENSOR_INTERNAL_ALLOC
+ ownsMemory = false;
+#endif
+}
+
+Tensor::Tensor( void* pointer, eDataType type, uint32_t length ) noexcept
+{
+ // size = [ length, 1, 1, 1 ]
+ const __m128i one = _mm_set1_epi32( 1 );
+ __m128i v = _mm_insert_epi32( one, (int)length, 0 );
+ store( ne, v );
+ // stride = [ 1, length, length, length ]
+ v = _mm_shuffle_epi32( v, _MM_SHUFFLE( 0, 0, 0, 1 ) );
+ store( nb, v );
+
+ m_data = pointer;
+ m_type = type;
+#if TENSOR_INTERNAL_ALLOC
+ ownsMemory = false;
+#endif
+}
+
+Tensor Tensor::fromData( void* pointer, eDataType type, uint32_t length )
+{
+ HRESULT hr = E_UNEXPECTED;
+ if( nullptr != pointer )
+ {
+ if( 0 != length )
+ return Tensor{ pointer, type, length };
+ else
+ hr = E_INVALIDARG;
+ }
+ else
+ hr = E_POINTER;
+ throw hr;
+}
+
+HRESULT Tensor::attach( void* pointer, eDataType type, std::initializer_list<uint32_t> sizeElements )
+{
+ if( nullptr == pointer )
+ return E_POINTER;
+
+ std::array<uint32_t, 4> arr;
+ CHECK( arrayFromList( arr, sizeElements ) );
+
+#if TENSOR_INTERNAL_ALLOC
+ deallocate();
+#endif
+ store( ne, load( arr ) );
+ TensorShape::setDenseStrides();
+
+ m_data = pointer;
+ this->m_type = type;
+#if TENSOR_INTERNAL_ALLOC
+ ownsMemory = false;
+#endif
+ return S_OK;
+}
+
+Tensor Tensor::reshape3d( uint32_t ne0, uint32_t ne1, uint32_t ne2 ) const
+{
+ if( !isContinuous() )
+ throw E_NOTIMPL;
+ if( countElements() != ne0 * ne1 * ne2 )
+ throw E_INVALIDARG;
+
+ Tensor res = *this;
+ res.ne = { ne0, ne1, ne2, 1 };
+ res.setDenseStrides();
+ return res;
+}
+
+#if TENSOR_GGML_COMPAT
+static const __m128i s_maskAlignment16 = _mm_set1_epi64x( 1 );
+static const __m128i s_maskAlignment32 = _mm_set1_epi64x( 3 );
+
+bool isAlignedProperly( __m128i r0, __m128i r1, __m128i mask )
+{
+ __m128i test = _mm_or_si128( r0, r1 );
+ return (bool)_mm_testz_si128( test, mask );
+}
+
+Tensor::Tensor( const ggml_tensor* ggml )
+{
+ store( ne, load16( ggml->ne ) );
+
+ __m128i r0 = load16( (const int*)&ggml->nb[ 0 ] );
+ __m128i r1 = load16( (const int*)&ggml->nb[ 2 ] );
+ // Divide from bytes into elements by right-shifting the 64-bit integers in these vectors
+ switch( ggml->type )
+ {
+ case GGML_TYPE_F16:
+ assert( isAlignedProperly( r0, r1, s_maskAlignment16 ) );
+ r0 = _mm_srli_epi64( r0, 1 );
+ r1 = _mm_srli_epi64( r1, 1 );
+ m_type = eDataType::FP16;
+ break;
+
+ case GGML_TYPE_F32:
+ assert( isAlignedProperly( r0, r1, s_maskAlignment32 ) );
+ r0 = _mm_srli_epi64( r0, 2 );
+ r1 = _mm_srli_epi64( r1, 2 );
+ m_type = eDataType::FP32;
+ break;
+
+ case GGML_TYPE_I32:
+ assert( isAlignedProperly( r0, r1, s_maskAlignment32 ) );
+ r0 = _mm_srli_epi64( r0, 2 );
+ r1 = _mm_srli_epi64( r1, 2 );
+ m_type = eDataType::U32;
+ break;
+
+ default:
+ throw E_INVALIDARG;
+ }
+ // downcast uint64_t into uint32_t in a single vector
+ r0 = _mm_shuffle_epi32( r0, _MM_SHUFFLE( 3, 3, 2, 0 ) );
+ r1 = _mm_shuffle_epi32( r1, _MM_SHUFFLE( 2, 0, 3, 3 ) );
+ store( nb, _mm_blend_epi16( r0, r1, 0b11110000 ) );
+
+ m_data = ggml->data;
+}
+
+ggml_tensor Tensor::ggml() const
+{
+ ggml_tensor res;
+ memset( &res, 0, sizeof( ggml_tensor ) );
+
+ const __m128i size = sizeVec();
+ store16( res.ne, size );
+
+ const __m128i one = _mm_set1_epi32( 1 );
+ const uint32_t maskOnes = (uint32_t)_mm_movemask_ps( _mm_castsi128_ps( _mm_cmpeq_epi32( size, one ) ) );
+ const uint32_t maskNotOnes = maskOnes ^ 0b1111;
+ unsigned long idx;
+ if( _BitScanReverse( &idx, maskNotOnes ) )
+ res.n_dims = (int)idx + 1;
+ else
+ res.n_dims = 0;
+
+ const __m128i strides = stridesVec();
+ // Upcast strides from u32 to u64
+ const __m128i zero = _mm_setzero_si128();
+ __m128i r0 = _mm_unpacklo_epi32( strides, zero );
+ __m128i r1 = _mm_unpackhi_epi32( strides, zero );
+ // Scale from elements into bytes with left shift vector instructions
+ switch( m_type )
+ {
+ case eDataType::FP16:
+ r0 = _mm_slli_epi64( r0, 1 );
+ r1 = _mm_slli_epi64( r1, 1 );
+ res.type = GGML_TYPE_F16;
+ break;
+ case eDataType::FP32:
+ r0 = _mm_slli_epi64( r0, 2 );
+ r1 = _mm_slli_epi64( r1, 2 );
+ res.type = GGML_TYPE_F32;
+ break;
+ case eDataType::U32:
+ r0 = _mm_slli_epi64( r0, 2 );
+ r1 = _mm_slli_epi64( r1, 2 );
+ res.type = GGML_TYPE_I32;
+ break;
+ default:
+ throw OLE_E_BLANK;
+ }
+
+ store16( &res.nb[ 0 ], r0 );
+ store16( &res.nb[ 2 ], r1 );
+
+ res.data = m_data;
+ return res;
+}
+
+GgmlTensorView::GgmlTensorView( const Tensor& t ) : tensor( t.ggml() ) {}
+#endif \ No newline at end of file
diff --git a/Whisper/CPU/mulMat.cpp b/Whisper/CPU/mulMat.cpp
new file mode 100644
index 0000000..1b6ed25
--- /dev/null
+++ b/Whisper/CPU/mulMat.cpp
@@ -0,0 +1,54 @@
+#include "stdafx.h"
+#include "mulMat.h"
+#include "mulMatImpl.h"
+using namespace CpuCompute;
+
+namespace
+{
+ template<uint8_t panelHeightRegs, uint8_t tileWidthFloats>
+ static HRESULT mulMatImpl( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor )
+ {
+ MulMatImpl<panelHeightRegs, tileWidthFloats> impl{ result, a, b, pfor };
+ return impl.run( pfor );
+ }
+}
+
+HRESULT CpuCompute::mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor )
+{
+ if( a.type() != eDataType::FP16 )
+ return E_NOTIMPL;
+ if( b.type() != eDataType::FP32 )
+ return E_NOTIMPL;
+
+ // return mulMatImpl<1, 1>( result, a, b, pfor );
+
+ if( b.ne[ 1 ] == 1 )
+ {
+ // Multiplying by a single row
+ if( a.ne[ 1 ] >= 32 )
+ return mulMatImpl<4, 1>( result, a, b, pfor );
+ else
+ return mulMatImpl<1, 1>( result, a, b, pfor );
+ }
+ else if( b.ne[ 1 ] == 2 )
+ {
+ if( a.ne[ 1 ] >= 32 )
+ return mulMatImpl<4, 2>( result, a, b, pfor );
+ else
+ return mulMatImpl<1, 2>( result, a, b, pfor );
+ }
+ else if( b.ne[ 1 ] == 3 )
+ {
+ if( a.ne[ 1 ] >= 16 )
+ return mulMatImpl<2, 3>( result, a, b, pfor );
+ else
+ return mulMatImpl<1, 3>( result, a, b, pfor );
+ }
+ else
+ {
+ if( a.ne[ 1 ] >= 16 )
+ return mulMatImpl<2, 4>( result, a, b, pfor );
+ else
+ return mulMatImpl<1, 4>( result, a, b, pfor );
+ }
+} \ No newline at end of file
diff --git a/Whisper/CPU/mulMat.h b/Whisper/CPU/mulMat.h
new file mode 100644
index 0000000..36e56fe
--- /dev/null
+++ b/Whisper/CPU/mulMat.h
@@ -0,0 +1,17 @@
+#pragma once
+#include "ParallelForRunner.h"
+#include "Tensor.h"
+
+namespace CpuCompute
+{
+ HRESULT mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor );
+}
+
+#if TENSOR_GGML_COMPAT
+#include "../source/ggml.h"
+inline HRESULT mulMat( ggml_tensor* result, const ggml_tensor* a, const ggml_tensor* b, CpuCompute::ParallelForRunner& pfor )
+{
+ CpuCompute::Tensor r{ result }, lhs{ a }, rhs{ b };
+ return CpuCompute::mulMat( r, lhs, rhs, pfor );
+}
+#endif \ No newline at end of file
diff --git a/Whisper/CPU/mulMat.kernel.hpp b/Whisper/CPU/mulMat.kernel.hpp
new file mode 100644
index 0000000..80b51a0
--- /dev/null
+++ b/Whisper/CPU/mulMat.kernel.hpp
@@ -0,0 +1,742 @@
+#pragma once
+#include <stdint.h>
+#include <immintrin.h>
+#include "simdUtils.h"
+
+template<uint8_t panelHeightRegs, uint8_t tileWidthFloats>
+struct ResultTile
+{
+ static constexpr size_t totalRegs = (size_t)(tileWidthFloats)*panelHeightRegs;
+ std::array<__m256, totalRegs> arr;
+
+ template<size_t idx>
+ __forceinline void fmadd( __m256 a, __m256 b )
+ {
+ arr[ idx ] = _mm256_fmadd_ps( a, b, arr[ idx ] );
+ }
+ __forceinline void kernel( const std::array<__m256, panelHeightRegs>& panel, const float* rsi, size_t stride );
+ __forceinline void kernelPartial( const std::array<__m256, panelHeightRegs>& panel, const float* rsi, size_t stride, size_t rem )
+ {
+ throw E_UNEXPECTED;
+ }
+ __forceinline void store( float* rdi, size_t w, size_t h, size_t stride ) const;
+};
+
+#pragma region setZero functions
+__forceinline void setZero( std::array<__m256, 1>& dest )
+{
+ dest[ 0 ] = _mm256_setzero_ps();
+}
+__forceinline void setZero( std::array<__m256, 2>& dest )
+{
+ dest[ 0 ] = _mm256_setzero_ps();
+ dest[ 1 ] = _mm256_setzero_ps();
+}
+__forceinline void setZero( std::array<__m256, 3>& dest )
+{
+ dest[ 0 ] = _mm256_setzero_ps();
+ dest[ 1 ] = _mm256_setzero_ps();
+ dest[ 2 ] = _mm256_setzero_ps();
+}
+__forceinline void setZero( std::array<__m256, 4>& dest )
+{
+ dest[ 0 ] = _mm256_setzero_ps();
+ dest[ 1 ] = _mm256_setzero_ps();
+ dest[ 2 ] = _mm256_setzero_ps();
+ dest[ 3 ] = _mm256_setzero_ps();
+}
+__forceinline void setZero( std::array<__m256, 6>& dest )
+{
+ dest[ 0 ] = _mm256_setzero_ps();
+ dest[ 1 ] = _mm256_setzero_ps();
+ dest[ 2 ] = _mm256_setzero_ps();
+ dest[ 3 ] = _mm256_setzero_ps();
+ dest[ 4 ] = _mm256_setzero_ps();
+ dest[ 5 ] = _mm256_setzero_ps();
+}
+__forceinline void setZero( std::array<__m256, 8>& dest )
+{
+ dest[ 0 ] = _mm256_setzero_ps();
+ dest[ 1 ] = _mm256_setzero_ps();
+ dest[ 2 ] = _mm256_setzero_ps();
+ dest[ 3 ] = _mm256_setzero_ps();
+ dest[ 4 ] = _mm256_setzero_ps();
+ dest[ 5 ] = _mm256_setzero_ps();
+ dest[ 6 ] = _mm256_setzero_ps();
+ dest[ 7 ] = _mm256_setzero_ps();
+}
+#pragma endregion
+
+#pragma region Micro-kernels
+__forceinline void ResultTile<1, 1>::kernel( const std::array<__m256, 1>& panel, const float* rsi, size_t stride )
+{
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+}
+__forceinline void ResultTile<1, 2>::kernel( const std::array<__m256, 1>& panel, const float* rsi, size_t stride )
+{
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ b = _mm256_broadcast_ss( rsi + stride );
+ fmadd<1>( panel[ 0 ], b );
+}
+__forceinline void ResultTile<1, 2>::kernelPartial( const std::array<__m256, 1>& panel, const float* rsi, size_t stride, size_t rem )
+{
+ assert( 1 == rem );
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+}
+__forceinline void ResultTile<1, 3>::kernel( const std::array<__m256, 1>& panel, const float* rsi, size_t stride )
+{
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ b = _mm256_broadcast_ss( rsi + stride );
+ fmadd<1>( panel[ 0 ], b );
+ b = _mm256_broadcast_ss( rsi + stride * 2 );
+ fmadd<2>( panel[ 0 ], b );
+}
+__forceinline void ResultTile<1, 3>::kernelPartial( const std::array<__m256, 1>& panel, const float* rsi, size_t stride, size_t rem )
+{
+ assert( rem > 0 && rem < 3 );
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ if( rem > 1 )
+ {
+ b = _mm256_broadcast_ss( rsi + stride );
+ fmadd<1>( panel[ 0 ], b );
+ }
+}
+
+__forceinline void ResultTile<1, 4>::kernel( const std::array<__m256, 1>& panel, const float* rsi, size_t stride )
+{
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ b = _mm256_broadcast_ss( rsi + stride );
+ fmadd<1>( panel[ 0 ], b );
+ b = _mm256_broadcast_ss( rsi + stride * 2 );
+ fmadd<2>( panel[ 0 ], b );
+ b = _mm256_broadcast_ss( rsi + stride * 3 );
+ fmadd<3>( panel[ 0 ], b );
+}
+__forceinline void ResultTile<1, 4>::kernelPartial( const std::array<__m256, 1>& panel, const float* rsi, size_t stride, size_t rem )
+{
+ assert( rem > 0 && rem < 4 );
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+
+ switch( rem )
+ {
+ case 3:
+ b = _mm256_broadcast_ss( rsi + stride * 2 );
+ fmadd<2>( panel[ 0 ], b );
+ case 2:
+ b = _mm256_broadcast_ss( rsi + stride );
+ fmadd<1>( panel[ 0 ], b );
+ }
+}
+__forceinline void ResultTile<4, 1>::kernel( const std::array<__m256, 4>& panel, const float* rsi, size_t stride )
+{
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ fmadd<1>( panel[ 1 ], b );
+ fmadd<2>( panel[ 2 ], b );
+ fmadd<3>( panel[ 3 ], b );
+}
+__forceinline void ResultTile<2, 4>::kernel( const std::array<__m256, 2>& panel, const float* rsi, size_t stride )
+{
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ fmadd<1>( panel[ 1 ], b );
+
+ b = _mm256_broadcast_ss( rsi + stride );
+ fmadd<2>( panel[ 0 ], b );
+ fmadd<3>( panel[ 1 ], b );
+
+ b = _mm256_broadcast_ss( rsi + stride * 2 );
+ fmadd<4>( panel[ 0 ], b );
+ fmadd<5>( panel[ 1 ], b );
+
+ b = _mm256_broadcast_ss( rsi + stride * 3 );
+ fmadd<6>( panel[ 0 ], b );
+ fmadd<7>( panel[ 1 ], b );
+}
+
+__forceinline void ResultTile<2, 4>::kernelPartial( const std::array<__m256, 2>& panel, const float* rsi, size_t stride, size_t rem )
+{
+ assert( rem > 0 && rem < 4 );
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ fmadd<1>( panel[ 1 ], b );
+
+ switch( rem )
+ {
+ case 3:
+ b = _mm256_broadcast_ss( rsi + stride * 2 );
+ fmadd<4>( panel[ 0 ], b );
+ fmadd<5>( panel[ 1 ], b );
+ case 2:
+ b = _mm256_broadcast_ss( rsi + stride );
+ fmadd<2>( panel[ 0 ], b );
+ fmadd<3>( panel[ 1 ], b );
+ }
+}
+
+__forceinline void ResultTile<2, 3>::kernel( const std::array<__m256, 2>& panel, const float* rsi, size_t stride )
+{
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ fmadd<1>( panel[ 1 ], b );
+
+ b = _mm256_broadcast_ss( rsi + stride );
+ fmadd<2>( panel[ 0 ], b );
+ fmadd<3>( panel[ 1 ], b );
+
+ b = _mm256_broadcast_ss( rsi + stride * 2 );
+ fmadd<4>( panel[ 0 ], b );
+ fmadd<5>( panel[ 1 ], b );
+}
+__forceinline void ResultTile<2, 3>::kernelPartial( const std::array<__m256, 2>& panel, const float* rsi, size_t stride, size_t rem )
+{
+ assert( rem > 0 && rem < 3 );
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ fmadd<1>( panel[ 1 ], b );
+ if( rem > 1 )
+ {
+ b = _mm256_broadcast_ss( rsi + stride );
+ fmadd<2>( panel[ 0 ], b );
+ fmadd<3>( panel[ 1 ], b );
+ }
+}
+
+__forceinline void ResultTile<4, 2>::kernel( const std::array<__m256, 4>& panel, const float* rsi, size_t stride )
+{
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ fmadd<1>( panel[ 1 ], b );
+ fmadd<2>( panel[ 2 ], b );
+ fmadd<3>( panel[ 3 ], b );
+
+ b = _mm256_broadcast_ss( rsi + stride );
+ fmadd<4>( panel[ 0 ], b );
+ fmadd<5>( panel[ 1 ], b );
+ fmadd<6>( panel[ 2 ], b );
+ fmadd<7>( panel[ 3 ], b );
+}
+__forceinline void ResultTile<4, 2>::kernelPartial( const std::array<__m256, 4>& panel, const float* rsi, size_t stride, size_t rem )
+{
+ assert( 1 == rem );
+ __m256 b = _mm256_broadcast_ss( rsi );
+ fmadd<0>( panel[ 0 ], b );
+ fmadd<1>( panel[ 1 ], b );
+ fmadd<2>( panel[ 2 ], b );
+ fmadd<3>( panel[ 3 ], b );
+}
+#pragma endregion
+
+#pragma region Loads
+// This function should compile into a single `vcvtph2ps` instruction, with memory operand
+__forceinline __m256 loadUpcasted( const uint16_t* rsi )
+{
+ __m128i i = _mm_load_si128( ( const __m128i* )rsi );
+ return _mm256_cvtph_ps( i );
+}
+
+// We loading the panel from the temporary buffer.
+// For this reason, we don't need to handle remainders, the code which made the buffer wrote zeros into the remainder elements
+// We can even use aligned load instructions.
+__forceinline void loadPanel( const uint16_t* rsi, std::array<__m256, 1>& dest )
+{
+ dest[ 0 ] = loadUpcasted( rsi );
+}
+__forceinline void loadPanel( const uint16_t* rsi, std::array<__m256, 2>& dest )
+{
+ dest[ 0 ] = loadUpcasted( rsi );
+ dest[ 1 ] = loadUpcasted( rsi + 8 );
+}
+__forceinline void loadPanel( const uint16_t* rsi, std::array<__m256, 3>& dest )
+{
+ dest[ 0 ] = loadUpcasted( rsi );
+ dest[ 1 ] = loadUpcasted( rsi + 8 );
+ dest[ 2 ] = loadUpcasted( rsi + 8 * 2 );
+}
+__forceinline void loadPanel( const uint16_t* rsi, std::array<__m256, 4>& dest )
+{
+ dest[ 0 ] = loadUpcasted( rsi );
+ dest[ 1 ] = loadUpcasted( rsi + 8 );
+ dest[ 2 ] = loadUpcasted( rsi + 8 * 2 );
+ dest[ 3 ] = loadUpcasted( rsi + 8 * 3 );
+}
+#pragma endregion
+
+#pragma region Stores
+__forceinline void ResultTile<1, 1>::store( float* rdi, size_t w, size_t h, size_t stride ) const
+{
+ assert( h == 1 && w > 0 && w <= 8 );
+ if( w == 8 )
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ else
+ {
+ const __m256i mask = loadTailMaskInt( w );
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ }
+}
+
+__forceinline void ResultTile<1, 2>::store( float* rdi, size_t w, size_t h, size_t stride ) const
+{
+ assert( h > 0 && w > 0 && h <= 2 && w <= 8 );
+ if( w == 8 )
+ {
+ switch( h )
+ {
+ case 2:
+ _mm256_storeu_ps( rdi + stride, arr[ 1 ] );
+ case 1:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ }
+ }
+ else
+ {
+ const __m256i mask = loadTailMaskInt( w );
+ switch( h )
+ {
+ case 2:
+ _mm256_maskstore_ps( rdi + stride, mask, arr[ 1 ] );
+ case 1:
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ }
+ }
+}
+
+__forceinline void ResultTile<1, 3>::store( float* rdi, size_t w, size_t h, size_t stride ) const
+{
+ assert( h > 0 && w > 0 && h <= 3 && w <= 8 );
+ if( w == 8 )
+ {
+ switch( h )
+ {
+ case 3:
+ _mm256_storeu_ps( rdi + stride * 2, arr[ 2 ] );
+ case 2:
+ _mm256_storeu_ps( rdi + stride, arr[ 1 ] );
+ case 1:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ }
+ }
+ else
+ {
+ const __m256i mask = loadTailMaskInt( w );
+ switch( h )
+ {
+ case 3:
+ _mm256_maskstore_ps( rdi + stride * 2, mask, arr[ 2 ] );
+ case 2:
+ _mm256_maskstore_ps( rdi + stride, mask, arr[ 1 ] );
+ case 1:
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ }
+ }
+}
+
+__forceinline void ResultTile<1, 4>::store( float* rdi, size_t w, size_t h, size_t stride ) const
+{
+ assert( h > 0 && w > 0 && h <= 4 && w <= 8 );
+
+ if( w == 8 )
+ {
+ switch( h )
+ {
+ case 4:
+ _mm256_storeu_ps( rdi + stride * 3, arr[ 3 ] );
+ case 3:
+ _mm256_storeu_ps( rdi + stride * 2, arr[ 2 ] );
+ case 2:
+ _mm256_storeu_ps( rdi + stride, arr[ 1 ] );
+ case 1:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ }
+ }
+ else
+ {
+ const __m256i mask = loadTailMaskInt( w );
+ switch( h )
+ {
+ case 4:
+ _mm256_maskstore_ps( rdi + stride * 3, mask, arr[ 3 ] );
+ case 3:
+ _mm256_maskstore_ps( rdi + stride * 2, mask, arr[ 2 ] );
+ case 2:
+ _mm256_maskstore_ps( rdi + stride, mask, arr[ 1 ] );
+ case 1:
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ }
+ }
+}
+
+__forceinline void ResultTile<4, 1>::store( float* rdi, size_t w, size_t h, size_t stride ) const
+{
+ assert( h == 1 && w > 0 && w <= 32 );
+ if( w == 32 )
+ {
+ // 4 complete vectors, this branch is very likely to be taken
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
+ _mm256_storeu_ps( rdi + 8 * 3, arr[ 3 ] );
+ }
+ else
+ {
+ const size_t rem = w % 8;
+ const __m256i mask = loadTailMaskInt<false>( rem );
+ const size_t completeVectors = w / 8;
+ const size_t key = ( completeVectors << 1 ) | ( ( 0 == rem ) ? 0 : 1 );
+ switch( key )
+ {
+ case 1: // 0 complete vectors + remainder
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ break;
+ case 2: // 1 complete vector
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ break;
+ case 3: // 1 complete vector + remainder
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
+ break;
+ case 4: // 2 complete vectors
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ break;
+ case 5: // 2 complete vectors + remainder
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_maskstore_ps( rdi + 8 * 2, mask, arr[ 2 ] );
+ break;
+ case 6: // 3 complete vectors
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
+ break;
+ case 7: // 3 complete vectors + remainder
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi + 8 * 3, mask, arr[ 3 ] );
+ break;
+ default:
+ throw E_UNEXPECTED;
+ }
+ }
+}
+__forceinline void ResultTile<4, 2>::store( float* rdi, size_t w, size_t h, size_t stride ) const
+{
+ assert( h > 0 && w > 0 && h <= 2 && w <= 32 );
+ const bool twoRows = h == 2;
+ float* const rdi1 = rdi + stride;
+ if( w == 32 )
+ {
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
+ _mm256_storeu_ps( rdi + 8 * 3, arr[ 3 ] );
+
+ if( twoRows )
+ {
+ _mm256_storeu_ps( rdi1, arr[ 4 ] );
+ _mm256_storeu_ps( rdi1 + 8, arr[ 5 ] );
+ _mm256_storeu_ps( rdi1 + 8 * 2, arr[ 6 ] );
+ _mm256_storeu_ps( rdi1 + 8 * 3, arr[ 7 ] );
+ }
+ }
+ else
+ {
+ const size_t rem = w % 8;
+ const __m256i mask = loadTailMaskInt<false>( rem );
+ const size_t completeVectors = w / 8;
+ // Lowest bit: remainder
+ // Next bit: set when storing 2 rows
+ // Next 2 bits: count of complete vectors in X direction, [ 0..3 ]
+ const size_t key = ( completeVectors << 2 ) | ( ( 0 == rem ) ? 0 : 1 ) | ( twoRows ? 2 : 0 );
+ switch( key )
+ {
+ case 1: // 0 complete vectors + remainder, 1 row
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ break;
+ case 3: // 0 complete vectors + remainder, 2 rows
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi1, mask, arr[ 4 ] );
+ break;
+ case 4: // 1 complete vector, 1 row
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ break;
+ case 5: // 1 complete vector + remainder, 1 row
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
+ break;
+ case 6: // 1 complete vector, 2 rows
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi1, arr[ 4 ] );
+ break;
+ case 7: // 1 complete vector + remainder, 2 rows
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
+
+ _mm256_storeu_ps( rdi1, arr[ 4 ] );
+ _mm256_maskstore_ps( rdi1 + 8, mask, arr[ 5 ] );
+ break;
+ case 8: // 2 complete vectors, 1 row
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ break;
+ case 9: // 2 complete vectors + remainder, 1 row
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_maskstore_ps( rdi + 8 * 2, mask, arr[ 2 ] );
+ break;
+ case 10: // 2 complete vectors, 2 rows
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+
+ _mm256_storeu_ps( rdi1, arr[ 4 ] );
+ _mm256_storeu_ps( rdi1 + 8, arr[ 5 ] );
+ break;
+ case 11: // 2 complete vectors + remainder, 2 rows
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_maskstore_ps( rdi + 8 * 2, mask, arr[ 2 ] );
+
+ _mm256_storeu_ps( rdi1, arr[ 4 ] );
+ _mm256_storeu_ps( rdi1 + 8, arr[ 5 ] );
+ _mm256_maskstore_ps( rdi1 + 8 * 2, mask, arr[ 6 ] );
+ break;
+ case 12: // 3 complete vectors, 1 row
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
+ break;
+ case 13: // 3 complete vectors + remainder, 1 row
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi + 8 * 3, mask, arr[ 3 ] );
+ break;
+ case 14: // 3 complete vectors, 2 rows
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
+
+ _mm256_storeu_ps( rdi1, arr[ 4 ] );
+ _mm256_storeu_ps( rdi1 + 8, arr[ 5 ] );
+ _mm256_storeu_ps( rdi1 + 8 * 2, arr[ 6 ] );
+ break;
+ case 15: // 3 complete vectors + remainder, 2 rows
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ _mm256_storeu_ps( rdi + 8 * 2, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi + 8 * 3, mask, arr[ 3 ] );
+
+ _mm256_storeu_ps( rdi1, arr[ 4 ] );
+ _mm256_storeu_ps( rdi1 + 8, arr[ 5 ] );
+ _mm256_storeu_ps( rdi1 + 8 * 2, arr[ 6 ] );
+ _mm256_maskstore_ps( rdi1 + 8 * 3, mask, arr[ 7 ] );
+ break;
+ default:
+ throw E_UNEXPECTED;
+ }
+ }
+}
+
+__forceinline void ResultTile<2, 4>::store( float* rdi, size_t w, size_t h, size_t stride ) const
+{
+ assert( h > 0 && w > 0 && h <= 4 && w <= 16 );
+ h--;
+ float* const rdi1 = rdi + stride;
+ float* const rdi2 = rdi + stride * 2;
+ float* const rdi3 = rdi + stride * 3;
+
+ if( w == 16 )
+ {
+ switch( h )
+ {
+ case 3:
+ _mm256_storeu_ps( rdi3, arr[ 6 ] );
+ _mm256_storeu_ps( rdi3 + 8, arr[ 7 ] );
+ case 2:
+ _mm256_storeu_ps( rdi2, arr[ 4 ] );
+ _mm256_storeu_ps( rdi2 + 8, arr[ 5 ] );
+ case 1:
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ _mm256_storeu_ps( rdi1 + 8, arr[ 3 ] );
+ case 0:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ }
+ }
+ else
+ {
+ const size_t rem = w % 8;
+ const __m256i mask = loadTailMaskInt<false>( rem );
+ // 0 for partial first vector, 1 for exactly 1 complete vector, 2 for 1 complete vector with remainder
+ const size_t partialCase = ( w < 8 ) ? 0 : ( ( w == 8 ) ? 1 : 2 );
+ // Merge into a single integer for the switch statement
+ const size_t key = partialCase + h * 3;
+
+ switch( key )
+ {
+ // h = 1
+ case 0:
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ break;
+ case 1:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ break;
+ case 2:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
+ break;
+ // h = 2
+ case 3:
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi1, mask, arr[ 2 ] );
+ break;
+ case 4:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ break;
+ case 5:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi1 + 8, mask, arr[ 3 ] );
+ break;
+ // h = 3
+ case 6:
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi1, mask, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi2, mask, arr[ 4 ] );
+ break;
+ case 7:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ _mm256_storeu_ps( rdi2, arr[ 4 ] );
+ break;
+ case 8:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi1 + 8, mask, arr[ 3 ] );
+ _mm256_storeu_ps( rdi2, arr[ 4 ] );
+ _mm256_maskstore_ps( rdi2 + 8, mask, arr[ 5 ] );
+ break;
+ // h = 4
+ case 9:
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi1, mask, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi2, mask, arr[ 4 ] );
+ _mm256_maskstore_ps( rdi3, mask, arr[ 6 ] );
+ break;
+ case 10:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ _mm256_storeu_ps( rdi2, arr[ 4 ] );
+ _mm256_storeu_ps( rdi3, arr[ 6 ] );
+ break;
+ case 11:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi1 + 8, mask, arr[ 3 ] );
+ _mm256_storeu_ps( rdi2, arr[ 4 ] );
+ _mm256_maskstore_ps( rdi2 + 8, mask, arr[ 5 ] );
+ _mm256_storeu_ps( rdi3, arr[ 6 ] );
+ _mm256_maskstore_ps( rdi3 + 8, mask, arr[ 7 ] );
+ break;
+ default:
+ throw E_UNEXPECTED;
+ }
+ }
+}
+
+__forceinline void ResultTile<2, 3>::store( float* rdi, size_t w, size_t h, size_t stride ) const
+{
+ assert( h > 0 && w > 0 && h <= 3 && w <= 16 );
+ float* const rdi1 = rdi + stride;
+ float* const rdi2 = rdi + stride * 2;
+ h--;
+
+ if( w == 16 )
+ {
+ switch( h )
+ {
+ case 2:
+ _mm256_storeu_ps( rdi2, arr[ 4 ] );
+ _mm256_storeu_ps( rdi2 + 8, arr[ 5 ] );
+ case 1:
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ _mm256_storeu_ps( rdi1 + 8, arr[ 3 ] );
+ case 0:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi + 8, arr[ 1 ] );
+ }
+ }
+ else
+ {
+ const size_t rem = w % 8;
+ const __m256i mask = loadTailMaskInt<false>( rem );
+ // 0 for partial first vector, 1 for exactly 1 complete vector, 2 for 1 complete vector with remainder
+ const size_t partialCase = ( w < 8 ) ? 0 : ( ( w == 8 ) ? 1 : 2 );
+ // Merge into a single integer for the switch statement
+ const size_t key = partialCase + h * 3;
+
+ switch( key )
+ {
+ // h = 1
+ case 0:
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ break;
+ case 1:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ break;
+ case 2:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
+ break;
+ // h = 2
+ case 3:
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi1, mask, arr[ 2 ] );
+ break;
+ case 4:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ break;
+ case 5:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi1 + 8, mask, arr[ 3 ] );
+ break;
+ // h = 3
+ case 6:
+ _mm256_maskstore_ps( rdi, mask, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi1, mask, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi2, mask, arr[ 4 ] );
+ break;
+ case 7:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ _mm256_storeu_ps( rdi2, arr[ 4 ] );
+ break;
+ case 8:
+ _mm256_storeu_ps( rdi, arr[ 0 ] );
+ _mm256_maskstore_ps( rdi + 8, mask, arr[ 1 ] );
+ _mm256_storeu_ps( rdi1, arr[ 2 ] );
+ _mm256_maskstore_ps( rdi1 + 8, mask, arr[ 3 ] );
+ _mm256_storeu_ps( rdi2, arr[ 4 ] );
+ _mm256_maskstore_ps( rdi2 + 8, mask, arr[ 5 ] );
+ break;
+ default:
+ throw E_UNEXPECTED;
+ }
+ }
+}
+#pragma endregion \ No newline at end of file
diff --git a/Whisper/CPU/mulMatImpl.avx2.cpp b/Whisper/CPU/mulMatImpl.avx2.cpp
new file mode 100644
index 0000000..b15ae63
--- /dev/null
+++ b/Whisper/CPU/mulMatImpl.avx2.cpp
@@ -0,0 +1,362 @@
+#include "stdafx.h"
+#include "mulMatImpl.h"
+#include <immintrin.h>
+#include "mulMatUtils.hpp"
+using namespace CpuCompute;
+
+namespace
+{
+ constexpr size_t prefetchBytes = 96;
+ constexpr int prefetchHint = _MM_HINT_T0;
+
+ constexpr size_t maskAlign16 = ~(size_t)15;
+
+ __forceinline __m256i load( const void* rsi )
+ {
+ return _mm256_loadu_si256( ( const __m256i* )rsi );
+ }
+
+#define TRANSPOSE_8X16() \
+ \
+ __m256i t0 = _mm256_unpacklo_epi16( r0, r1 ); \
+ __m256i t1 = _mm256_unpackhi_epi16( r0, r1 ); \
+ __m256i t2 = _mm256_unpacklo_epi16( r2, r3 ); \
+ __m256i t3 = _mm256_unpackhi_epi16( r2, r3 ); \
+ __m256i t4 = _mm256_unpacklo_epi16( r4, r5 ); \
+ __m256i t5 = _mm256_unpackhi_epi16( r4, r5 ); \
+ __m256i t6 = _mm256_unpacklo_epi16( r6, r7 ); \
+ __m256i t7 = _mm256_unpackhi_epi16( r6, r7 ); \
+ \
+ r0 = _mm256_unpacklo_epi32( t0, t2 ); \
+ r1 = _mm256_unpackhi_epi32( t0, t2 ); \
+ r2 = _mm256_unpacklo_epi32( t1, t3 ); \
+ r3 = _mm256_unpackhi_epi32( t1, t3 ); \
+ r4 = _mm256_unpacklo_epi32( t4, t6 ); \
+ r5 = _mm256_unpackhi_epi32( t4, t6 ); \
+ r6 = _mm256_unpacklo_epi32( t5, t7 ); \
+ r7 = _mm256_unpackhi_epi32( t5, t7 ); \
+ \
+ t0 = _mm256_unpacklo_epi64( r0, r4 ); \
+ t1 = _mm256_unpackhi_epi64( r0, r4 ); \
+ t2 = _mm256_unpacklo_epi64( r1, r5 ); \
+ t3 = _mm256_unpackhi_epi64( r1, r5 ); \
+ t4 = _mm256_unpacklo_epi64( r2, r6 ); \
+ t5 = _mm256_unpackhi_epi64( r2, r6 ); \
+ t6 = _mm256_unpacklo_epi64( r3, r7 ); \
+ t7 = _mm256_unpackhi_epi64( r3, r7 )
+
+ __forceinline void storeLow( void* rdi, __m256i v )
+ {
+ __m128i i = _mm256_castsi256_si128( v );
+ _mm_store_si128( ( __m128i* )rdi, i );
+ }
+
+#define STORE_8X16_LOW() \
+ storeLow( rdi, t0 ); \
+ storeLow( rdi + destStride, t1 ); \
+ storeLow( rdi + destStride * 2, t2 ); \
+ rdi += destStride * 8; \
+ storeLow( rdiMid, t3 ); \
+ storeLow( rdiMid + destStride, t4 ); \
+ storeLow( rdiMid + destStride * 2, t5 ); \
+ rdiMid += destStride * 8; \
+ storeLow( rdiLast, t6 ); \
+ storeLow( rdiLast + destStride, t7 ); \
+ rdiLast += destStride * 8
+
+ __forceinline void storeHigh( void* rdi, __m256i v )
+ {
+ __m128i i = _mm256_extracti128_si256( v, 1 );
+ _mm_store_si128( ( __m128i* )rdi, i );
+ }
+
+#define STORE_8X16_HIGH() \
+ storeHigh( rdi, t0 ); \
+ storeHigh( rdi + destStride, t1 ); \
+ storeHigh( rdi + destStride * 2, t2 ); \
+ rdi += destStride * 8; \
+ storeHigh( rdiMid, t3 ); \
+ storeHigh( rdiMid + destStride, t4 ); \
+ storeHigh( rdiMid + destStride * 2, t5 ); \
+ rdiMid += destStride * 8; \
+ storeHigh( rdiLast, t6 ); \
+ storeHigh( rdiLast + destStride, t7 ); \
+ rdiLast += destStride * 8
+
+ __forceinline void prefetch( const uint8_t* p )
+ {
+ _mm_prefetch( (const char*)p, prefetchHint );
+ }
+
+ __forceinline void transpose8Avx2( uint16_t* rdiWords, size_t w, const uint16_t* rsiWords, size_t sourceStride, size_t destStride )
+ {
+ assert( 0 == ( (size_t)rdiWords ) % 16 );
+ assert( 0 == destStride % 8 );
+ assert( w <= sourceStride );
+
+ // Scale strides to bytes, and cast the pointers
+ sourceStride *= 2;
+ destStride *= 2;
+ uint8_t* rdi = (uint8_t*)rdiWords;
+ const uint8_t* rsi = (const uint8_t*)rsiWords;
+
+ const uint8_t* const rsiEndAligned = rsi + ( w & maskAlign16 ) * 2;
+ const uint8_t* const rsiEnd = rsi + w * 2;
+ const uint8_t* rsiMid = rsi + sourceStride * 3;
+ const uint8_t* rsiLast = rsi + sourceStride * 6;
+ uint8_t* rdiMid = rdi + destStride * 3;
+ uint8_t* rdiLast = rdi + destStride * 6;
+
+ while( rsi < rsiEndAligned )
+ {
+ // Load 16x8 block into 8 registers
+ __m256i r0 = load( rsi );
+ __m256i r1 = load( rsi + sourceStride );
+ __m256i r2 = load( rsi + sourceStride * 2 );
+ rsi += 32;
+ __m256i r3 = load( rsiMid );
+ __m256i r4 = load( rsiMid + sourceStride );
+ __m256i r5 = load( rsiMid + sourceStride * 2 );
+ rsiMid += 32;
+ __m256i r6 = load( rsiLast );
+ __m256i r7 = load( rsiLast + sourceStride );
+ rsiLast += 32;
+
+ // Transpose FP16 values in registers
+ TRANSPOSE_8X16();
+
+ // Store
+ STORE_8X16_LOW();
+ STORE_8X16_HIGH();
+
+ if constexpr( prefetchBytes > 0 )
+ {
+ if( rsi + prefetchBytes < rsiEnd )
+ {
+ prefetch( rsi + prefetchBytes );
+ prefetch( rsi + sourceStride + prefetchBytes );
+ prefetch( rsi + sourceStride * 2 + prefetchBytes );
+ prefetch( rsiMid + prefetchBytes );
+ prefetch( rsiMid + sourceStride + prefetchBytes );
+ prefetch( rsiMid + sourceStride * 2 + prefetchBytes );
+ prefetch( rsiLast + prefetchBytes );
+ prefetch( rsiLast + sourceStride + prefetchBytes );
+ }
+ }
+ }
+
+ if( rsi < rsiEnd )
+ {
+ // Loading 8 elements into corresponding lanes of 8 vectors
+ // This way there's no data dependencies between these load instructions
+ // Out of order execution should hopefully do it's magic in the CPU, running all these loads in parallel.
+ __m128i r0;
+ __m128i r1 = _mm_setzero_si128();
+ __m128i r2 = _mm_setzero_si128();
+ __m128i r3 = _mm_setzero_si128();
+ __m128i r4 = _mm_setzero_si128();
+ __m128i r5 = _mm_setzero_si128();
+ __m128i r6 = _mm_setzero_si128();
+ __m128i r7 = _mm_setzero_si128();
+
+ __m128i t0, t1, t2, t3, t4, t5, t6;
+
+#pragma loop( no_vector )
+ while( rsi < rsiEnd )
+ {
+ r0 = _mm_cvtsi32_si128( *(const uint16_t*)rsi );
+ r1 = _mm_insert_epi16( r1, *(const int16_t*)( rsi + sourceStride ), 1 );
+ r2 = _mm_insert_epi16( r2, *(const int16_t*)( rsi + sourceStride * 2 ), 2 );
+ rsi += 2;
+ r3 = _mm_insert_epi16( r3, *(const int16_t*)( rsiMid ), 3 );
+ r4 = _mm_insert_epi16( r4, *(const int16_t*)( rsiMid + sourceStride ), 4 );
+ r5 = _mm_insert_epi16( r5, *(const int16_t*)( rsiMid + sourceStride * 2 ), 5 );
+ rsiMid += 2;
+ r6 = _mm_insert_epi16( r6, *(const int16_t*)( rsiLast ), 6 );
+ r7 = _mm_insert_epi16( r7, *(const int16_t*)( rsiLast + sourceStride ), 7 );
+ rsiLast += 2;
+
+ // Bitwise operations are pretty fast, AMD Zen3 CPU can run 4 of them every clock cycle
+ // Combine 8 vectors into one
+ t0 = _mm_or_si128( r0, r1 );
+ t1 = _mm_or_si128( r2, r3 );
+ t2 = _mm_or_si128( r4, r5 );
+ t3 = _mm_or_si128( r6, r7 );
+
+ t4 = _mm_or_si128( t0, t1 );
+ t5 = _mm_or_si128( t2, t3 );
+
+ t6 = _mm_or_si128( t4, t5 );
+ // Store 8 FP16 values, the destination is aligned
+ _mm_store_si128( ( __m128i* )rdi, t6 );
+ rdi += destStride;
+ }
+ }
+ }
+
+ __forceinline void transpose8PartialAvx2( uint16_t* rdiWords, size_t w, size_t h, const uint16_t* rsiWords, size_t sourceStride, size_t destStride )
+ {
+ assert( 0 == ( (size_t)rdiWords ) % 16 );
+ assert( 0 == destStride % 8 );
+ assert( w <= sourceStride );
+ assert( h > 0 && h < 8 );
+
+ // Scale strides to bytes, and cast the pointers
+ sourceStride *= 2;
+ destStride *= 2;
+ uint8_t* rdi = (uint8_t*)rdiWords;
+ const uint8_t* rsi = (const uint8_t*)rsiWords;
+
+ const uint8_t* const rsiEndAligned = rsi + ( w & maskAlign16 ) * 2;
+ const uint8_t* const rsiEnd = rsi + w * 2;
+ const uint8_t* rsiMid = rsi + sourceStride * 3;
+ const uint8_t* rsiLast = rsi + sourceStride * 6;
+ uint8_t* rdiMid = rdi + destStride * 3;
+ uint8_t* rdiLast = rdi + destStride * 6;
+
+ while( rsi < rsiEndAligned )
+ {
+ // Load the block into 8 registers, set unused rows to zero
+ __m256i r0 = load( rsi );
+ __m256i r1 = _mm256_setzero_si256();
+ __m256i r2 = _mm256_setzero_si256();
+ __m256i r3 = _mm256_setzero_si256();
+ __m256i r4 = _mm256_setzero_si256();
+ __m256i r5 = _mm256_setzero_si256();
+ __m256i r6 = _mm256_setzero_si256();
+ // These branches, whether direct or indirect, are very predictable: same outcome for all iterations of the outer loop
+ switch( h )
+ {
+ case 7:
+ r6 = load( rsiLast );
+ case 6:
+ r5 = load( rsiMid + sourceStride * 2 );
+ case 5:
+ r4 = load( rsiMid + sourceStride );
+ case 4:
+ r3 = load( rsiMid );
+ case 3:
+ r2 = load( rsi + sourceStride * 2 );
+ case 2:
+ r1 = load( rsi + sourceStride );
+ }
+ rsi += 32;
+ rsiMid += 32;
+ rsiLast += 32;
+
+ __m256i r7 = _mm256_setzero_si256();
+
+ // Transpose FP16 values in registers
+ TRANSPOSE_8X16();
+
+ // Store
+ STORE_8X16_LOW();
+
+ STORE_8X16_HIGH();
+ }
+
+ if( rsi < rsiEnd )
+ {
+ // Loading 8 elements into corresponding lanes of 8 vectors
+ // This way there's no data dependencies between these load instructions
+ // Out of order execution should hopefully do it's magic in the CPU, running all these loads in parallel.
+ __m128i r0;
+ __m128i r1 = _mm_setzero_si128();
+ __m128i r2 = _mm_setzero_si128();
+ __m128i r3 = _mm_setzero_si128();
+ __m128i r4 = _mm_setzero_si128();
+ __m128i r5 = _mm_setzero_si128();
+ __m128i r6 = _mm_setzero_si128();
+
+ __m128i t0, t1, t2, t3, t4, t5;
+
+#pragma loop( no_vector )
+ while( rsi < rsiEnd )
+ {
+ r0 = _mm_cvtsi32_si128( *(const uint16_t*)rsi );
+
+ switch( h )
+ {
+ case 7:
+ r6 = _mm_insert_epi16( r6, *(const int16_t*)( rsiLast ), 6 );
+ case 6:
+ r5 = _mm_insert_epi16( r5, *(const int16_t*)( rsiMid + sourceStride * 2 ), 5 );
+ case 5:
+ r4 = _mm_insert_epi16( r4, *(const int16_t*)( rsiMid + sourceStride ), 4 );
+ case 4:
+ r3 = _mm_insert_epi16( r3, *(const int16_t*)( rsiMid ), 3 );
+ case 3:
+ r2 = _mm_insert_epi16( r2, *(const int16_t*)( rsi + sourceStride * 2 ), 2 );
+ case 2:
+ r1 = _mm_insert_epi16( r1, *(const int16_t*)( rsi + sourceStride ), 1 );
+ }
+ rsi += 2;
+ rsiMid += 2;
+ rsiLast += 2;
+
+ // Bitwise operations are pretty fast, AMD Zen3 CPU can run 4 of them every clock cycle
+ // Combine 7 vectors into one
+ t0 = _mm_or_si128( r0, r1 );
+ t1 = _mm_or_si128( r2, r3 );
+ t2 = _mm_or_si128( r4, r5 );
+
+ t3 = _mm_or_si128( t0, t1 );
+ t4 = _mm_or_si128( t2, r6 );
+
+ t5 = _mm_or_si128( t3, t4 );
+ // Store 8 FP16 values, the destination is aligned
+ _mm_store_si128( ( __m128i* )rdi, t5 );
+ rdi += destStride;
+ }
+ }
+ }
+}
+
+// At least for the hybrid decoder, this method absolutely dominates the CPU time.
+// And not due to the integer shuffles - the bottleneck is loading data from the source matrix.
+HRESULT MulMatBase::transposePanelAvx2( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ assert( stridesA[ 0 ] == 1 );
+
+ const size_t heightFloats = (size_t)panelHeightRegisters * 8;
+ i *= heightFloats;
+
+ const uint16_t* rsi = (const uint16_t*)pa;
+ rsi += m3 * stridesA[ 3 ];
+ rsi += m2 * stridesA[ 2 ];
+ rsi += i * stridesA[ 1 ];
+
+ const size_t resultStride = heightFloats;
+
+ if( i + heightFloats <= resultSize[ 0 ] )
+ {
+ // A complete panel
+ for( size_t i = 0; i < panelHeightRegisters; i++ )
+ {
+ transpose8Avx2( rdi, length, rsi, stridesA[ 1 ], resultStride );
+ // Advance by 8 floats in the output buffer
+ rdi += 8;
+ // Advance by 8 rows in the source matrix
+ rsi += 8 * stridesA[ 1 ];
+ }
+ }
+ else
+ {
+ // A partial panel, at the bottom of the first argument matrix
+ const size_t remainder = resultSize[ 0 ] - i;
+ assert( remainder > 0 && remainder < heightFloats );
+ zeroAlignedMemory( rdi, resultStride * length * sizeof( uint16_t ) );
+
+ const size_t completePanels = remainder / 8;
+ for( size_t i = 0; i < completePanels; i++ )
+ {
+ transpose8Avx2( rdi, length, rsi, stridesA[ 1 ], resultStride );
+ rdi += 8;
+ rsi += 8 * stridesA[ 1 ];
+ }
+ const size_t lastPanel = remainder % 8;
+ if( 0 != lastPanel )
+ transpose8PartialAvx2( rdi, length, lastPanel, rsi, stridesA[ 1 ], resultStride );
+ }
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/CPU/mulMatImpl.cpp b/Whisper/CPU/mulMatImpl.cpp
new file mode 100644
index 0000000..fc50b03
--- /dev/null
+++ b/Whisper/CPU/mulMatImpl.cpp
@@ -0,0 +1,213 @@
+#include "stdafx.h"
+#include <intrin.h>
+#include "mulMatImpl.h"
+#include "mulMat.kernel.hpp"
+
+#define DBG_TRACK_TEMPLATE_INSTANTIATION 0
+
+#if DBG_TRACK_TEMPLATE_INSTANTIATION
+#include <unordered_set>
+static std::unordered_set<uint16_t> g_mulMatTemplates;
+#endif
+
+namespace
+{
+ using namespace CpuCompute;
+
+ bool checkAvx2Support()
+ {
+ int cpuInfo[ 4 ];
+ __cpuid( cpuInfo, 7 );
+ return ( cpuInfo[ 1 ] & ( 1 << 5 ) ) != 0;
+ }
+
+ // a / b, rounded up to the next integer
+ inline uint32_t divRoundUp( uint32_t a, uint32_t b )
+ {
+ assert( b != 0 );
+ return ( a + ( b - 1 ) ) / b;
+ }
+}
+
+const bool MulMatBase::haveAvx2 = checkAvx2Support();
+
+MulMatBase::MulMatBase( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor, uint8_t panelHeightRegs, uint8_t tileWidthFloats ) :
+ resultPointer( result.fp32() ),
+ pa( a.data() ),
+ pb( b.data() ),
+ runner( pfor )
+{
+ length = a.ne[ 0 ];
+ resultStrides[ 0 ] = result.nb[ 1 ];
+ resultStrides[ 1 ] = result.nb[ 2 ];
+ resultStrides[ 2 ] = result.nb[ 3 ];
+ store( resultSize, result.sizeVec() );
+ store( stridesA, a.stridesVec() );
+ store( stridesB, b.stridesVec() );
+
+ countPanels = divRoundUp( resultSize[ 0 ], panelHeightRegs * 8 );
+ completeTilesPerPanel = resultSize[ 1 ] / tileWidthFloats;
+ lastColumnsInPanel = (uint8_t)( resultSize[ 1 ] % tileWidthFloats );
+ this->panelHeightRegisters = panelHeightRegs;
+ this->tileWidth = tileWidthFloats;
+
+ // Pick a method which reshapes a panel of the matrix A into the shape we need to compute the product
+ // Store the pointer to that method in the field of this class
+ if( a.nb[ 0 ] == 1 )
+ {
+ if( haveAvx2 )
+ pfnMakePanel = &MulMatBase::transposePanelAvx2;
+ else
+ pfnMakePanel = &MulMatBase::transposePanel;
+ }
+ else if( a.nb[ 1 ] == 1 )
+ {
+ switch( panelHeightRegs )
+ {
+ case 1:
+ pfnMakePanel = &MulMatBase::copyPanelColumnMajor8;
+ break;
+ case 2:
+ pfnMakePanel = &MulMatBase::copyPanelColumnMajor16;
+ break;
+ case 4:
+ pfnMakePanel = &MulMatBase::copyPanelColumnMajor32;
+ break;
+ default:
+ throw E_NOTIMPL;
+ }
+ }
+ else
+ pfnMakePanel = &MulMatBase::gatherPanel;
+
+ // That last version is generic and very simple, unlikely to have weird bugs
+ // pfnMakePanel = &MulMatBase::gatherPanel;
+
+#if DBG_TRACK_TEMPLATE_INSTANTIATION
+ uint16_t key = panelHeightRegs;
+ key = key << 8;
+ key |= tileWidthFloats;
+ if( !g_mulMatTemplates.emplace( key ).second )
+ return;
+ logDebug( u8"MulMatImpl<panelHeightRegs = %i, tileWidthFloats = %i>", (int)panelHeightRegs, (int)tileWidthFloats );
+#endif
+}
+
+HRESULT MulMatBase::run( ParallelForRunner& pfor )
+{
+ size_t length = (size_t)countPanels * resultSize[ 2 ] * resultSize[ 3 ];
+ return pfor.parallelFor( *this, length );
+}
+
+const float* MulMatBase::getLayerB( size_t m2, size_t m3 ) const
+{
+ const float* rsi = (const float*)this->pb;
+ rsi += m2 * stridesB[ 2 ];
+ rsi += m3 * stridesB[ 3 ];
+ return rsi;
+}
+
+// This method is the main one, it’s called by the thread pool
+template<uint8_t panelHeightRegs, uint8_t tileWidthFloats>
+HRESULT __stdcall MulMatImpl<panelHeightRegs, tileWidthFloats>::compute( size_t i, size_t end ) const noexcept
+{
+ // Allocate a thread-local buffer for the transposed panel
+ constexpr size_t panelHeightFloats = panelHeightRegs * 8;
+ uint16_t* const panel = (uint16_t*)runner.threadLocalBuffer( floatsPerPanel() * 2 );
+ const size_t resultStride = resultStrides[ 0 ];
+
+ // Load a few numbers from this class into local variables, while upcasting from DWORD into size_t
+ const size_t length = this->length;
+ const std::array<size_t, 2> stridesB{ this->stridesB[ 0 ], this->stridesB[ 1 ] };
+
+ // This outer loop iterates over the panels assigned to the current thread
+ // For example, matrix A of size [ 1024, 1024 ] may be split into panels of size [ 1024, 16 ]
+ // Each iteration of that loop computes matrix product of that panel, with the complete matrix B
+ for( ; i < end; i++ )
+ {
+ const size_t iPanel = i % countPanels;
+ size_t j = i / countPanels;
+ const size_t m2 = j % (size_t)resultSize[ 2 ];
+ const size_t m3 = j / (size_t)resultSize[ 2 ];
+
+ CHECK( ( this->*pfnMakePanel )( panel, iPanel, m2, m3 ) );
+ // We got a column-major panel in the thread local buffer, of size [ length, panelHeightRegs * 8 ]
+ // Hopefully, these buffers should all fit at least in L3 cache
+ // The longest matrix I saw in the debugger had 4096 elements, with panelHeightRegs = 4 that's 256 kb of data in the panel
+ const float* pb = getLayerB( m2, m3 );
+ float* rdi = getPanelDest( iPanel, m2, m3 );
+
+ const size_t storeWidth = std::min( panelHeightFloats, (size_t)resultSize[ 0 ] - iPanel * panelHeightFloats );
+ std::array<__m256, panelHeightRegs> vecPanel;
+#if 1
+ ResultTile<panelHeightRegs, tileWidthFloats> tile;
+
+ // This loop iterates over tiles within the panel.
+ // Each iteration of the loop computes an output tile of the result matrix.
+ for( j = 0; j < completeTilesPerPanel; j++, pb += tileWidthFloats * stridesB[ 1 ], rdi += resultStride * tileWidthFloats )
+ {
+ setZero( tile.arr );
+ const uint16_t* rsiA = panel;
+ const uint16_t* const rsiAEnd = panel + length * panelHeightFloats;
+ const float* rsiB = pb;
+ // This loop runs for `length` iterations, iterates over the first dimensions of both matrices, accumulating these dot products we're after
+ for( ; rsiA < rsiAEnd; rsiA += panelHeightFloats, rsiB += stridesB[ 0 ] )
+ {
+ loadPanel( rsiA, vecPanel );
+ tile.kernel( vecPanel, rsiB, stridesB[ 1 ] );
+ }
+ tile.store( rdi, storeWidth, tileWidthFloats, resultStride );
+ }
+
+ if( 0 != lastColumnsInPanel )
+ {
+ setZero( tile.arr );
+ const uint16_t* rsiA = panel;
+ const uint16_t* rsiAEnd = panel + length * panelHeightFloats;
+ const float* rsiB = pb;
+ for( ; rsiA < rsiAEnd; rsiA += panelHeightFloats, rsiB += stridesB[ 0 ] )
+ {
+ loadPanel( rsiA, vecPanel );
+ tile.kernelPartial( vecPanel, rsiB, stridesB[ 1 ], lastColumnsInPanel );
+ }
+ tile.store( rdi, storeWidth, lastColumnsInPanel, resultStride );
+ }
+#else
+ // This version bypasses horizontal tiling, instead implements a brute force algorithm to multiply the current panel by the complete B matrix
+ // Not terribly efficient, only implemented for debugging purposes
+ const size_t resHeight = resultSize[ 1 ];
+ std::array<__m256, panelHeightRegs> tile;
+ for( size_t j = 0; j < resHeight; j++, pb += stridesB[ 1 ], rdi += resultStride )
+ {
+ setZero( tile );
+
+ const uint16_t* rsiA = panel;
+ const uint16_t* const rsiAEnd = panel + length * panelHeightFloats;
+ const float* rsiB = pb;
+ for( size_t k = 0; k < length; k++, rsiA += panelHeightFloats, rsiB += stridesB[ 0 ] )
+ {
+ loadPanel( rsiA, vecPanel );
+ const __m256 b = _mm256_broadcast_ss( rsiB );
+ for( size_t r = 0; r < panelHeightRegs; r++ )
+ tile[ r ] = _mm256_fmadd_ps( vecPanel[ r ], b, tile[ r ] );
+ }
+
+ alignas( 32 ) std::array<float, panelHeightFloats> arr;
+ for( size_t k = 0; k < panelHeightRegs; k++ )
+ _mm256_store_ps( &arr[ k * 8 ], tile[ k ] );
+ memcpy( rdi, arr.data(), storeWidth * 4 );
+ }
+#endif
+ }
+ return S_OK;
+}
+
+// Instantiate the templates we need
+template class MulMatImpl<4, 1>;
+template class MulMatImpl<1, 1>;
+template class MulMatImpl<4, 2>;
+template class MulMatImpl<1, 2>;
+template class MulMatImpl<2, 3>;
+template class MulMatImpl<1, 3>;
+template class MulMatImpl<2, 4>;
+template class MulMatImpl<1, 4>; \ No newline at end of file
diff --git a/Whisper/CPU/mulMatImpl.h b/Whisper/CPU/mulMatImpl.h
new file mode 100644
index 0000000..8e0062f
--- /dev/null
+++ b/Whisper/CPU/mulMatImpl.h
@@ -0,0 +1,106 @@
+#pragma once
+// Matrix*matrix multiplication is the most expensive algorithm in the model, by far.
+// For this reason, the code in this source file, and in the mulMat.kernel.hpp header, is optimized for performance. Readability suffers.
+// The implementation is inspired by following two articles:
+// https://gist.github.com/nadavrot/5b35d44e8ba3dd718e595e40184d03f0
+// https://link.springer.com/article/10.1007/s11227-022-05003-3
+#include "ParallelForRunner.h"
+#include "Tensor.h"
+
+namespace CpuCompute
+{
+ // Abstract base class for all implementations, to reduce binary size
+ class MulMatBase : public iComputeRange
+ {
+ protected:
+ // Pointers to the payload of the output matrix
+ float* const resultPointer;
+
+ // Lengths of the dot products to compute, equal to width of both source matrices
+ uint32_t length;
+
+ // Last 3 strides of the output matrix, expressed as count of elements. The first one is always 1 because the output matrix is continuous.
+ std::array<uint32_t, 3> resultStrides;
+
+ // Size of the output matrix
+ std::array<uint32_t, 4> resultSize;
+
+ // Pointers to the payload of the source matrices
+ const void* const pa;
+ const void* const pb;
+
+ // Matrix strides, expressed as count of elements
+ std::array<uint32_t, 4> stridesA, stridesB;
+
+ // Total count of panels in the layer of the output matrix.
+ // The last panel might be incomplete, with smaller height.
+ // The thread-local buffer however is always complete, unused elements will be zeros.
+ uint32_t countPanels;
+
+ // Complete tiles in the length of the panel
+ uint32_t completeTilesPerPanel;
+
+ // Count of the last remainder columns in the panel, can be 0
+ uint8_t lastColumnsInPanel;
+
+ // Same as panelHeightRegs template argument - height of the panels, in AVX vectors
+ uint8_t panelHeightRegisters;
+
+ // Same as tileWidthFloats template argument - width of the tile, in floats
+ uint8_t tileWidth;
+
+ // Method pointer to reshape a panel from the source matrix into a thread-local buffer
+ using pfnTransposePanel = HRESULT( MulMatBase::* )( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
+ pfnTransposePanel pfnMakePanel;
+ // The object which implements multithreading for this job, and supplies memory for thread-local buffers
+ ParallelForRunner& runner;
+
+ // Count of FP16 values in the thread-local panel buffer
+ uint32_t floatsPerPanel() const
+ {
+ return length * panelHeightRegisters * 8;
+ }
+
+ // Transpose a horizontal panel of the first matrix, when the rows are continuous in that matrix
+ HRESULT transposePanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
+ HRESULT transposePanelAvx2( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
+ // Copy a horizontal panel of the first matrix without transpose, for column major layout of that matrix
+ HRESULT copyPanelColumnMajor8( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
+ HRESULT copyPanelColumnMajor16( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
+ HRESULT copyPanelColumnMajor32( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
+ // Transpose a panel of the first matrix for irregular layout of that matrix, when neither rows nor columns are at sequential addresses.
+ // This one ain't implemented yet.
+ HRESULT gatherPanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
+
+ const uint16_t* getPanelA( size_t i, size_t m2, size_t m3 ) const;
+ // Pointer to the first element of the second source matrix in the specified layer
+ const float* getLayerB( size_t m2, size_t m3 ) const;
+
+ // Pointer to the first element of the output tile of the result matrix
+ float* getPanelDest( size_t i, size_t m2, size_t m3 ) const
+ {
+ float* rdi = resultPointer;
+ rdi += m2 * resultStrides[ 1 ];
+ rdi += m3 * resultStrides[ 2 ];
+ rdi += i * panelHeightRegisters * 8;
+ return rdi;
+ }
+
+ static const bool haveAvx2;
+ public:
+ MulMatBase( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor, uint8_t panelHeightRegs, uint8_t tileWidthFloats );
+ HRESULT run( ParallelForRunner& pfor );
+ };
+
+ // This class actually contains the kernels implementations
+ template<uint8_t panelHeightRegs, uint8_t tileWidthFloats>
+ class MulMatImpl : public MulMatBase
+ {
+ HRESULT __stdcall compute( size_t i, size_t end ) const noexcept override final;
+
+ public:
+ MulMatImpl( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ) :
+ MulMatBase( result, a, b, pfor, panelHeightRegs, tileWidthFloats )
+ { }
+ };
+} \ No newline at end of file
diff --git a/Whisper/CPU/mulMatImpl.panel.cpp b/Whisper/CPU/mulMatImpl.panel.cpp
new file mode 100644
index 0000000..f3baf21
--- /dev/null
+++ b/Whisper/CPU/mulMatImpl.panel.cpp
@@ -0,0 +1,274 @@
+#include "stdafx.h"
+#include <intrin.h>
+#include "mulMatImpl.h"
+#include "mulMatUtils.hpp"
+using namespace CpuCompute;
+
+// We want to keep code size reasonable, that's why these panel reshaping methods are in the base class
+HRESULT MulMatBase::transposePanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ assert( stridesA[ 0 ] == 1 );
+
+ const size_t heightFloats = (size_t)panelHeightRegisters * 8;
+ i *= heightFloats;
+
+ const uint16_t* rsi = (const uint16_t*)pa;
+ rsi += m3 * stridesA[ 3 ];
+ rsi += m2 * stridesA[ 2 ];
+ rsi += i * stridesA[ 1 ];
+
+ const size_t resultStride = heightFloats;
+
+ if( i + heightFloats <= resultSize[ 0 ] )
+ {
+ // A complete panel
+ for( size_t i = 0; i < panelHeightRegisters; i++ )
+ {
+ transpose8( rdi, length, rsi, stridesA[ 1 ], resultStride );
+ // Advance by 8 floats in the output buffer
+ rdi += 8;
+ // Advance by 8 rows in the source matrix
+ rsi += 8 * stridesA[ 1 ];
+ }
+ }
+ else
+ {
+ // A partial panel, at the bottom of the first argument matrix
+ const size_t remainder = resultSize[ 0 ] - i;
+ assert( remainder > 0 && remainder < heightFloats );
+ zeroAlignedMemory( rdi, resultStride * length * sizeof( uint16_t ) );
+
+ const size_t completePanels = remainder / 8;
+ for( size_t i = 0; i < completePanels; i++ )
+ {
+ transpose8( rdi, length, rsi, stridesA[ 1 ], resultStride );
+ rdi += 8;
+ rsi += 8 * stridesA[ 1 ];
+ }
+ const size_t lastPanel = remainder % 8;
+ if( 0 != lastPanel )
+ transpose8Partial( rdi, length, lastPanel, rsi, stridesA[ 1 ], resultStride );
+ }
+ return S_OK;
+}
+
+inline const uint16_t* MulMatBase::getPanelA( size_t i, size_t m2, size_t m3 ) const
+{
+ const uint16_t* rsi = (const uint16_t*)pa;
+ rsi += m3 * stridesA[ 3 ];
+ rsi += m2 * stridesA[ 2 ];
+ rsi += i * stridesA[ 1 ];
+ return rsi;
+}
+
+HRESULT MulMatBase::copyPanelColumnMajor8( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ assert( stridesA[ 1 ] == 1 );
+ assert( panelHeightRegisters == 1 );
+
+ constexpr size_t heightFloats = 8;
+ i *= heightFloats;
+ const uint16_t* rsi = getPanelA( i, m2, m3 );
+
+ constexpr size_t resultStride = heightFloats;
+
+ if( i + heightFloats <= resultSize[ 0 ] )
+ {
+ // A complete panel, height = 8 elements
+ copyColumnMajor( rdi, length, rsi, stridesA[ 0 ], resultStride );
+ }
+ else
+ {
+ // A partial panel, at the bottom of the first argument matrix
+ const size_t remainder = resultSize[ 0 ] - i;
+ assert( remainder > 0 && remainder < heightFloats );
+ copyColumnMajorPartial( rdi, length, remainder, rsi, stridesA[ 0 ], resultStride );
+ }
+ return S_OK;
+}
+
+__forceinline __m128i load8Partial( const uint16_t* x, size_t len )
+{
+ assert( len > 0 && len < 8 );
+ __m128i ix = _mm_setzero_si128();
+ switch( len )
+ {
+ case 1: // load 2 bytes
+ ix = _mm_cvtsi32_si128( *x );
+ break;
+ case 2: // load 4 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ break;
+ case 3: // load 6 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ ix = _mm_insert_epi16( ix, x[ 2 ], 2 );
+ break;
+ case 4: // load 8 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ break;
+ case 5: // load 10 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi16( ix, x[ 4 ], 4 );
+ break;
+ case 6: // load 12 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ break;
+ case 7: // load 14 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ ix = _mm_insert_epi16( ix, x[ 6 ], 6 );
+ break;
+ }
+ return ix;
+}
+
+__forceinline __m256i load16Partial( const uint16_t* rsi, size_t len )
+{
+ assert( len > 0 && len < 16 );
+
+ if( len < 8 )
+ {
+ __m128i low = load8Partial( rsi, len );
+ return _mm256_setr_m128i( low, _mm_setzero_si128() );
+ }
+ else if( len > 8 )
+ {
+ __m128i low = load16( (const int*)rsi );
+ __m128i high = load8Partial( rsi + 8, len - 8 );
+ return _mm256_setr_m128i( low, high );
+ }
+ else
+ {
+ __m128i low = load16( (const int*)rsi );
+ return _mm256_setr_m128i( low, _mm_setzero_si128() );
+ }
+}
+
+HRESULT MulMatBase::copyPanelColumnMajor16( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ assert( stridesA[ 1 ] == 1 );
+ assert( panelHeightRegisters == 2 );
+
+ constexpr size_t heightFloats = 16;
+ i *= heightFloats;
+
+ const uint16_t* rsi = getPanelA( i, m2, m3 );
+ uint16_t* const rdiEnd = rdi + 16 * length;
+
+ if( i + heightFloats <= resultSize[ 0 ] )
+ {
+ // A complete panel, height = 16 elements
+ for( ; rdi < rdiEnd; rdi += 16, rsi += stridesA[ 0 ] )
+ {
+ __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ }
+ }
+ else
+ {
+ // A partial panel, at the bottom of the first argument matrix
+ const size_t remainder = resultSize[ 0 ] - i;
+ assert( remainder > 0 && remainder < heightFloats );
+
+ for( ; rdi < rdiEnd; rdi += 16, rsi += stridesA[ 0 ] )
+ {
+ __m256i v = load16Partial( rsi, remainder );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ }
+ }
+ return S_OK;
+}
+
+HRESULT MulMatBase::copyPanelColumnMajor32( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ assert( stridesA[ 1 ] == 1 );
+ assert( panelHeightRegisters == 4 );
+
+ constexpr size_t heightFloats = 32;
+ i *= heightFloats;
+
+ const uint16_t* rsi = getPanelA( i, m2, m3 );
+ uint16_t* const rdiEnd = rdi + 32 * length;
+
+ if( i + heightFloats <= resultSize[ 0 ] )
+ {
+ // A complete panel, height = 32 elements
+ for( ; rdi < rdiEnd; rdi += 32, rsi += stridesA[ 0 ] )
+ {
+ __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ v = _mm256_loadu_si256( ( const __m256i* )( rsi + 16 ) );
+ _mm256_store_si256( ( __m256i* )( rdi + 16 ), v );
+ }
+ }
+ else
+ {
+ // A partial panel, at the bottom of the first argument matrix
+ const size_t remainder = resultSize[ 0 ] - i;
+ assert( remainder > 0 && remainder < heightFloats );
+
+ // _mm256_setzero_si256 probably compiles into vpxor, that's AVX2, we don't want that here
+ const __m256 zero = _mm256_setzero_ps();
+
+ for( ; rdi < rdiEnd; rdi += 32, rsi += stridesA[ 0 ] )
+ {
+ if( remainder < 16 )
+ {
+ __m256i v = load16Partial( rsi, remainder );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ _mm256_store_ps( (float*)( rdi + 16 ), zero );
+ }
+ else if( remainder > 16 )
+ {
+ __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ v = load16Partial( rsi + 16, remainder - 16 );
+ _mm256_store_si256( ( __m256i* )( rdi + 16 ), v );
+ }
+ else
+ {
+ __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ _mm256_store_ps( (float*)( rdi + 16 ), zero );
+ }
+ }
+ }
+ return S_OK;
+}
+
+HRESULT MulMatBase::gatherPanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ // BTW, I never saw this method called.
+ const size_t heightFloats = (size_t)panelHeightRegisters * 8;
+ const size_t length = this->length;
+
+ zeroAlignedMemory( rdi, length * heightFloats * sizeof( uint16_t ) );
+
+ const size_t height = std::min( heightFloats, resultSize[ 0 ] - i );
+ const size_t strideElement = stridesA[ 0 ];
+ const size_t strideRow = stridesA[ 1 ];
+ const uint16_t* rsi = getPanelA( i * heightFloats, m2, m3 );
+
+ if( strideElement < strideRow )
+ {
+ for( size_t r = 0; r < height; r++, rsi += strideRow, rdi++ )
+ {
+ const uint16_t* sourceRow = rsi;
+ uint16_t* destRow = rdi;
+ for( size_t c = 0; c < length; c++, sourceRow += strideElement, destRow += heightFloats )
+ *destRow = *sourceRow;
+ }
+ }
+ else
+ {
+ for( size_t c = 0; c < length; c++, rsi += strideElement, rdi += heightFloats )
+ {
+ const uint16_t* sourceCol = rsi;
+ uint16_t* destCol = rdi;
+ for( size_t r = 0; r < height; r++, sourceCol += strideRow, destCol++ )
+ *destCol = *sourceCol;
+ }
+ }
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/CPU/mulMatUtils.hpp b/Whisper/CPU/mulMatUtils.hpp
new file mode 100644
index 0000000..1276323
--- /dev/null
+++ b/Whisper/CPU/mulMatUtils.hpp
@@ -0,0 +1,301 @@
+#pragma once
+#include <immintrin.h>
+#include <stdint.h>
+#include <assert.h>
+
+__forceinline __m128i f16Load( const uint16_t* rsi )
+{
+ return _mm_loadu_si128( ( const __m128i* )rsi );
+}
+
+constexpr size_t maskAlign8 = ~(size_t)7;
+
+__forceinline void transpose8( uint16_t* rdi, size_t w, const uint16_t* rsi, size_t sourceStride, size_t destStride )
+{
+ assert( 0 == ( (size_t)rdi ) % 16 );
+ assert( 0 == destStride % 8 );
+ assert( w <= sourceStride );
+
+ const uint16_t* const rsiEndAligned = rsi + ( w & maskAlign8 );
+ const uint16_t* rsi5 = rsi + sourceStride * 5;
+ uint16_t* rdi5 = rdi + destStride * 5;
+ const size_t rem = w % 8;
+ for( ; rsi < rsiEndAligned; rsi += 8, rsi5 += 8, rdi += 8 * destStride, rdi5 += 8 * destStride )
+ {
+ // Load 8x8 block into 8 registers
+ __m128i r0 = f16Load( rsi ); // 00, 01, 02, 03, 04, 05, 06, 07
+ __m128i r1 = f16Load( rsi + sourceStride ); // 10, 11, 12, 13, 14, 15, 16, 17
+ __m128i r2 = f16Load( rsi + sourceStride * 2 ); // 20, 21, 22, 23, 24, 25, 26, 27
+ __m128i r3 = f16Load( rsi5 - sourceStride * 2 ); // 30, 31, 32, 33, 34, 35, 36, 37
+ __m128i r4 = f16Load( rsi5 - sourceStride ); // 40, 41, 42, 43, 44, 45, 46, 47
+ __m128i r5 = f16Load( rsi5 ); // 50, 51, 52, 53, 54, 55, 56, 57
+ __m128i r6 = f16Load( rsi5 + sourceStride ); // 60, 61, 62, 63, 64, 65, 66, 67
+ __m128i r7 = f16Load( rsi5 + sourceStride * 2 ); // 70, 71, 72, 73, 74, 75, 76, 77
+
+ // Transpose FP16 values in registers
+ __m128i t0 = _mm_unpacklo_epi16( r0, r1 ); // 00, 10, 01, 11, 02, 12, 03, 13
+ __m128i t1 = _mm_unpackhi_epi16( r0, r1 ); // 04, 14, 05, 15, 06, 16, 07, 17
+ __m128i t2 = _mm_unpacklo_epi16( r2, r3 ); // 20, 30, 21, 31, 22, 32, 23, 33
+ __m128i t3 = _mm_unpackhi_epi16( r2, r3 ); // 24, 34, 25, 35, 26, 36, 27, 37
+ __m128i t4 = _mm_unpacklo_epi16( r4, r5 ); // 40, 50, 41, 52, 42, 52, 43, 53
+ __m128i t5 = _mm_unpackhi_epi16( r4, r5 ); // 44, 54, 45, 55, 46, 56, 47, 57
+ __m128i t6 = _mm_unpacklo_epi16( r6, r7 ); // 60, 70, 61, 71, 62, 72, 63, 73
+ __m128i t7 = _mm_unpackhi_epi16( r6, r7 ); // 64, 74, 65, 75, 66, 76, 67, 77
+
+ r0 = _mm_unpacklo_epi32( t0, t2 ); // 00, 10, 20, 30, 01, 11, 21, 31
+ r1 = _mm_unpackhi_epi32( t0, t2 ); // 02, 12, 22, 32, 03, 13, 23, 33
+ r2 = _mm_unpacklo_epi32( t1, t3 ); // 04, 14, 24, 34, 05, 15, 25, 35
+ r3 = _mm_unpackhi_epi32( t1, t3 ); // 06, 16, 26, 36, 07, 17, 27, 37
+ r4 = _mm_unpacklo_epi32( t4, t6 ); // 40, 50, 60, 70, 41, 51, 61, 71
+ r5 = _mm_unpackhi_epi32( t4, t6 ); // 42, 52, 62, 72, 43, 53, 63, 73
+ r6 = _mm_unpacklo_epi32( t5, t7 ); // 44, 54, 64, 74, 45, 55, 65, 75
+ r7 = _mm_unpackhi_epi32( t5, t7 ); // 46, 56, 66, 76, 47, 57, 67, 77
+
+ t0 = _mm_unpacklo_epi64( r0, r4 ); // 00, 10, 20, 30, 40, 50, 60, 70
+ t1 = _mm_unpackhi_epi64( r0, r4 ); // 01, 11, 21, 31, 41, 52, 61, 71
+ t2 = _mm_unpacklo_epi64( r1, r5 ); // 02, 12, 22, 32, 42, 52, 62, 72
+ t3 = _mm_unpackhi_epi64( r1, r5 ); // 03, 13, 23, 33, 43, 53, 63, 73
+ t4 = _mm_unpacklo_epi64( r2, r6 );
+ t5 = _mm_unpackhi_epi64( r2, r6 );
+ t6 = _mm_unpacklo_epi64( r3, r7 );
+ t7 = _mm_unpackhi_epi64( r3, r7 );
+
+ // Store
+ store16( rdi, t0 );
+ store16( rdi + destStride, t1 );
+ store16( rdi + destStride * 2, t2 );
+ store16( rdi5 - destStride * 2, t3 );
+ store16( rdi5 - destStride, t4 );
+ store16( rdi5, t5 );
+ store16( rdi5 + destStride, t6 );
+ store16( rdi5 + destStride * 2, t7 );
+ }
+
+#pragma loop( no_vector )
+ for( size_t i = 0; i < rem; rsi++, rsi5++, rdi += destStride )
+ {
+ const int16_t* p0 = (const int16_t*)rsi;
+ const int16_t* p5 = (const int16_t*)rsi5;
+ // Load a complete column into a vector
+ __m128i v = _mm_cvtsi32_si128( *rsi );
+ v = _mm_insert_epi16( v, *( p0 + sourceStride ), 1 );
+ v = _mm_insert_epi16( v, *( p0 + sourceStride * 2 ), 2 );
+ v = _mm_insert_epi16( v, *( p5 - sourceStride * 2 ), 3 );
+ v = _mm_insert_epi16( v, *( p5 - sourceStride ), 4 );
+ v = _mm_insert_epi16( v, *( p5 ), 5 );
+ v = _mm_insert_epi16( v, *( p5 + sourceStride ), 6 );
+ v = _mm_insert_epi16( v, *( p5 + sourceStride * 2 ), 7 );
+ // Store 8 FP16 values
+ store16( rdi, v );
+ }
+}
+
+inline void transpose8Partial( uint16_t* rdi, size_t w, size_t h, const uint16_t* rsi, size_t sourceStride, size_t destStride )
+{
+ assert( 0 == ( (size_t)rdi ) % 16 );
+ assert( 0 == destStride % 8 );
+ assert( w <= sourceStride );
+ assert( h > 0 && h < 8 );
+
+ const uint16_t* const rsiEndAligned = rsi + ( w & maskAlign8 );
+ const uint16_t* rsi5 = rsi + sourceStride * 5;
+ uint16_t* rdi5 = rdi + destStride * 5;
+ const size_t rem = w % 8;
+ for( ; rsi < rsiEndAligned; rsi += 8, rsi5 += 8, rdi += 8 * destStride, rdi5 += 8 * destStride )
+ {
+ // Load the block into 8 registers, set unused rows to zero
+ __m128i r0 = f16Load( rsi );
+ __m128i r1 = _mm_setzero_si128();
+ __m128i r2 = _mm_setzero_si128();
+ __m128i r3 = _mm_setzero_si128();
+ __m128i r4 = _mm_setzero_si128();
+ __m128i r5 = _mm_setzero_si128();
+ __m128i r6 = _mm_setzero_si128();
+ // These branches, whether direct or indirect, are very predictable: same outcome for all iterations of the outer loop
+ switch( h )
+ {
+ case 7:
+ r6 = f16Load( rsi5 + sourceStride );
+ case 6:
+ r5 = f16Load( rsi5 );
+ case 5:
+ r4 = f16Load( rsi5 - sourceStride );
+ case 4:
+ r3 = f16Load( rsi5 - sourceStride * 2 );
+ case 3:
+ r2 = f16Load( rsi + sourceStride * 2 );
+ case 2:
+ r1 = f16Load( rsi + sourceStride );
+ }
+ __m128i r7 = _mm_setzero_si128();
+
+ // Transpose FP16 values in registers
+ __m128i t0 = _mm_unpacklo_epi16( r0, r1 ); // 00, 10, 01, 11, 02, 12, 03, 13
+ __m128i t1 = _mm_unpackhi_epi16( r0, r1 ); // 04, 14, 05, 15, 06, 16, 07, 17
+ __m128i t2 = _mm_unpacklo_epi16( r2, r3 ); // 20, 30, 21, 31, 22, 32, 23, 33
+ __m128i t3 = _mm_unpackhi_epi16( r2, r3 ); // 24, 34, 25, 35, 26, 36, 27, 37
+ __m128i t4 = _mm_unpacklo_epi16( r4, r5 ); // 40, 50, 41, 52, 42, 52, 43, 53
+ __m128i t5 = _mm_unpackhi_epi16( r4, r5 ); // 44, 54, 45, 55, 46, 56, 47, 57
+ __m128i t6 = _mm_unpacklo_epi16( r6, r7 ); // 60, 70, 61, 71, 62, 72, 63, 73
+ __m128i t7 = _mm_unpackhi_epi16( r6, r7 ); // 64, 74, 65, 75, 66, 76, 67, 77
+
+ r0 = _mm_unpacklo_epi32( t0, t2 ); // 00, 10, 20, 30, 01, 11, 21, 31
+ r1 = _mm_unpackhi_epi32( t0, t2 ); // 02, 12, 22, 32, 03, 13, 23, 33
+ r2 = _mm_unpacklo_epi32( t1, t3 ); // 04, 14, 24, 34, 05, 15, 25, 35
+ r3 = _mm_unpackhi_epi32( t1, t3 ); // 06, 16, 26, 36, 07, 17, 27, 37
+ r4 = _mm_unpacklo_epi32( t4, t6 ); // 40, 50, 60, 70, 41, 51, 61, 71
+ r5 = _mm_unpackhi_epi32( t4, t6 ); // 42, 52, 62, 72, 43, 53, 63, 73
+ r6 = _mm_unpacklo_epi32( t5, t7 ); // 44, 54, 64, 74, 45, 55, 65, 75
+ r7 = _mm_unpackhi_epi32( t5, t7 ); // 46, 56, 66, 76, 47, 57, 67, 77
+
+ t0 = _mm_unpacklo_epi64( r0, r4 ); // 00, 10, 20, 30, 40, 50, 60, 70
+ t1 = _mm_unpackhi_epi64( r0, r4 ); // 01, 11, 21, 31, 41, 52, 61, 71
+ t2 = _mm_unpacklo_epi64( r1, r5 ); // 02, 12, 22, 32, 42, 52, 62, 72
+ t3 = _mm_unpackhi_epi64( r1, r5 ); // 03, 13, 23, 33, 43, 53, 63, 73
+ t4 = _mm_unpacklo_epi64( r2, r6 );
+ t5 = _mm_unpackhi_epi64( r2, r6 );
+ t6 = _mm_unpacklo_epi64( r3, r7 );
+ t7 = _mm_unpackhi_epi64( r3, r7 );
+
+ // Store
+ store16( rdi, t0 );
+ store16( rdi + destStride, t1 );
+ store16( rdi + destStride * 2, t2 );
+ store16( rdi5 - destStride * 2, t3 );
+ store16( rdi5 - destStride, t4 );
+ store16( rdi5, t5 );
+ store16( rdi5 + destStride, t6 );
+ store16( rdi5 + destStride * 2, t7 );
+ }
+
+#pragma loop( no_vector )
+ for( size_t i = 0; i < rem; rsi++, rsi5++, rdi += destStride )
+ {
+ const int16_t* p0 = (const int16_t*)rsi;
+ const int16_t* p5 = (const int16_t*)rsi5;
+ // Load a partial column into vector
+ __m128i v = _mm_cvtsi32_si128( *rsi );
+ switch( h )
+ {
+ case 7:
+ v = _mm_insert_epi16( v, *( p5 + sourceStride ), 6 );
+ case 6:
+ v = _mm_insert_epi16( v, *( p5 ), 5 );
+ case 5:
+ v = _mm_insert_epi16( v, *( p5 - sourceStride ), 4 );
+ case 4:
+ v = _mm_insert_epi16( v, *( p5 - sourceStride * 2 ), 3 );
+ case 3:
+ v = _mm_insert_epi16( v, *( p0 + sourceStride * 2 ), 2 );
+ case 2:
+ v = _mm_insert_epi16( v, *( p0 + sourceStride ), 1 );
+ }
+ // Store 8 FP16 values
+ store16( rdi, v );
+ }
+}
+
+// Same as above, but skip the transpose. The source stride is distance between columns of the matrix.
+__forceinline void copyColumnMajor( uint16_t* rdi, size_t w, const uint16_t* rsi, size_t sourceStride, size_t destStride )
+{
+ assert( 0 == ( (size_t)rdi ) % 16 );
+ assert( 0 == destStride % 8 );
+
+ constexpr size_t maskAlign4 = ~(size_t)3;
+
+ const uint16_t* const rsiEndAligned = rsi + sourceStride * ( w & maskAlign4 );
+ const uint16_t* const rsiEnd = rsi + sourceStride * w;
+ for( ; rsi < rsiEndAligned; rsi += sourceStride * 4, rdi += destStride * 4 )
+ {
+ __m128i c = f16Load( rsi );
+ store16( rdi, c );
+
+ c = f16Load( rsi + sourceStride );
+ store16( rdi + destStride, c );
+
+ c = f16Load( rsi + sourceStride * 2 );
+ store16( rdi + destStride * 2, c );
+
+ c = f16Load( rsi + sourceStride * 3 );
+ store16( rdi + destStride * 3, c );
+ }
+
+ for( ; rsi < rsiEnd; rsi += sourceStride, rdi += destStride )
+ {
+ __m128i c = f16Load( rsi );
+ store16( rdi, c );
+ }
+}
+
+__forceinline __m128i loadPartial( const uint16_t* x, size_t count )
+{
+ assert( count < 8 );
+ __m128i ix;
+ switch( count )
+ {
+ case 1: // load 2 bytes
+ ix = _mm_cvtsi32_si128( *x );
+ break;
+ case 2: // load 4 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ break;
+ case 3: // load 6 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ ix = _mm_insert_epi16( ix, x[ 2 ], 2 );
+ break;
+ case 4: // load 8 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ break;
+ case 5: // load 10 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi16( ix, x[ 4 ], 4 );
+ break;
+ case 6: // load 12 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ break;
+ case 7: // load 14 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ ix = _mm_insert_epi16( ix, x[ 6 ], 6 );
+ break;
+ default:
+ return _mm_setzero_si128();
+ }
+ return ix;
+}
+
+inline void copyColumnMajorPartial( uint16_t* rdi, size_t w, size_t h, const uint16_t* rsi, size_t sourceStride, size_t destStride )
+{
+ assert( 0 == ( (size_t)rdi ) % 32 );
+ assert( 0 == destStride % 8 );
+ assert( h > 0 && h < 8 );
+
+ const uint16_t* const rsiEnd = rsi + sourceStride * w;
+ for( ; rsi < rsiEnd; rsi += sourceStride, rdi += destStride )
+ {
+ // Can't use mask loads because loading 2-byte elements
+ // Still, that switch() in loadPartial makes a very predictable branch, same outcome for all iterations of this loop.
+ __m128i c = loadPartial( rsi, h );
+ store16( rdi, c );
+ }
+}
+
+// Store zeros into block of memory, with aligned AVX store instructions
+__forceinline void zeroAlignedMemory( void* pv, size_t cb )
+{
+ assert( 0 == cb % 16 );
+ assert( 0 == ( (size_t)pv % 32 ) );
+
+ uint8_t* rdi = (uint8_t*)pv;
+ constexpr size_t maskAlign32 = ~(size_t)31;
+ uint8_t* const rdiEndAligned = rdi + ( cb & maskAlign32 );
+ uint8_t* const rdiEnd = rdi + cb;
+
+ const __m256 zero = _mm256_setzero_ps();
+ for( ; rdi < rdiEndAligned; rdi += 32 )
+ _mm256_store_ps( (float*)rdi, zero );
+
+ if( rdi < rdiEnd )
+ _mm_store_ps( (float*)rdi, _mm_setzero_ps() );
+} \ No newline at end of file
diff --git a/Whisper/CPU/simdUtils.cpp b/Whisper/CPU/simdUtils.cpp
new file mode 100644
index 0000000..0e5f77d
--- /dev/null
+++ b/Whisper/CPU/simdUtils.cpp
@@ -0,0 +1,738 @@
+#include "stdafx.h"
+#include "simdUtils.h"
+#include "../ML/LookupTablesData.h"
+#include <cmath>
+#include <memory>
+
+namespace
+{
+ constexpr size_t maskAlign8 = ~(size_t)7;
+
+ __forceinline __m256 load8( const uint16_t* rsi )
+ {
+ __m128i i = _mm_loadu_si128( ( const __m128i* )rsi );
+ return _mm256_cvtph_ps( i );
+ }
+
+ __forceinline void loadPartial( const uint16_t* x, const uint16_t* y, size_t count, __m256& fx, __m256& fy )
+ {
+ assert( count < 8 );
+
+ __m128i ix, iy;
+ switch( count )
+ {
+ case 1: // load 2 bytes
+ ix = _mm_cvtsi32_si128( *x );
+ iy = _mm_cvtsi32_si128( *y );
+ break;
+ case 2: // load 4 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ iy = _mm_cvtsi32_si128( *(const int*)y );
+ break;
+ case 3: // load 6 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ iy = _mm_cvtsi32_si128( *(const int*)y );
+ ix = _mm_insert_epi16( ix, x[ 2 ], 2 );
+ iy = _mm_insert_epi16( iy, y[ 2 ], 2 );
+ break;
+ case 4: // load 8 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ break;
+ case 5: // load 10 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ ix = _mm_insert_epi16( ix, x[ 4 ], 4 );
+ iy = _mm_insert_epi16( iy, y[ 4 ], 4 );
+ break;
+ case 6: // load 12 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 );
+ break;
+ case 7: // load 14 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 );
+ ix = _mm_insert_epi16( ix, x[ 6 ], 6 );
+ iy = _mm_insert_epi16( iy, y[ 6 ], 6 );
+ break;
+ default:
+ fx = fy = _mm256_setzero_ps();
+ return;
+ }
+
+ fx = _mm256_cvtph_ps( ix );
+ fy = _mm256_cvtph_ps( iy );
+ }
+
+ __forceinline __m256 loadPartial( const uint16_t* x, size_t count )
+ {
+ assert( count < 8 );
+ __m128i ix;
+ switch( count )
+ {
+ case 1: // load 2 bytes
+ ix = _mm_cvtsi32_si128( *x );
+ break;
+ case 2: // load 4 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ break;
+ case 3: // load 6 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ ix = _mm_insert_epi16( ix, x[ 2 ], 2 );
+ break;
+ case 4: // load 8 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ break;
+ case 5: // load 10 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi16( ix, x[ 4 ], 4 );
+ break;
+ case 6: // load 12 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ break;
+ case 7: // load 14 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ ix = _mm_insert_epi16( ix, x[ 6 ], 6 );
+ break;
+ default:
+ return _mm256_setzero_ps();
+ }
+ return _mm256_cvtph_ps( ix );
+ }
+
+ __forceinline __m128 loadFloat2( const float* rsi )
+ {
+ return _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) );
+ }
+ __forceinline __m128 loadFloat3( const float* rsi )
+ {
+ __m128 f = loadFloat2( rsi );
+ f = _mm_insert_ps( f, _mm_load_ss( rsi + 2 ), 0x20 );
+ return f;
+ }
+
+ __forceinline __m256 loadPartial( const float* rsi, size_t count )
+ {
+ assert( count < 8 );
+ __m128 low = _mm_setzero_ps();
+ __m128 high = _mm_setzero_ps();
+ switch( count )
+ {
+ case 1:
+ low = _mm_load_ss( rsi );
+ break;
+ case 2:
+ low = loadFloat2( rsi );
+ break;
+ case 3:
+ low = loadFloat3( rsi );
+ break;
+ case 4:
+ low = _mm_loadu_ps( rsi );
+ break;
+ case 5:
+ low = _mm_loadu_ps( rsi );
+ high = _mm_load_ss( rsi + 4 );
+ break;
+ case 6:
+ low = _mm_loadu_ps( rsi );
+ high = loadFloat2( rsi + 4 );
+ break;
+ case 7:
+ low = _mm_loadu_ps( rsi );
+ high = loadFloat3( rsi + 4 );
+ break;
+ }
+ return _mm256_setr_m128( low, high );
+ }
+
+ __forceinline void storeFloat2( float* rdi, __m128 vec )
+ {
+ _mm_store_sd( (double*)rdi, _mm_castps_pd( vec ) );
+ }
+
+ __forceinline void storePartial( float* rdi, __m256 vec, size_t count )
+ {
+ assert( count < 8 );
+
+ __m128 tmp = _mm256_castps256_ps128( vec );
+ if( count >= 4 )
+ {
+ _mm_storeu_ps( rdi, tmp );
+ if( count == 4 )
+ return;
+ count -= 4;
+ rdi += 4;
+ tmp = _mm256_extractf128_ps( vec, 1 );
+ }
+
+ switch( count )
+ {
+ case 1:
+ _mm_store_ss( rdi, tmp );
+ return;
+ case 2:
+ storeFloat2( rdi, tmp );
+ return;
+ case 3:
+ storeFloat2( rdi, tmp );
+ ( (int*)rdi )[ 2 ] = _mm_extract_ps( tmp, 2 );
+ return;
+ }
+ }
+}
+
+void addF16to32( float* rdi, const uint16_t* a, const uint16_t* b, size_t length )
+{
+ const uint16_t* const endAligned = a + ( length & maskAlign8 );
+ const size_t rem = length % 8;
+
+ for( ; a < endAligned; a += 8, b += 8, rdi += 8 )
+ {
+ __m256 f1 = load8( a );
+ __m256 f2 = load8( b );
+ __m256 res = _mm256_add_ps( f1, f2 );
+ _mm256_storeu_ps( rdi, res );
+ }
+
+ if( rem != 0 )
+ {
+ __m256 f1, f2;
+ loadPartial( a, b, rem, f1, f2 );
+ __m256 res = _mm256_add_ps( f1, f2 );
+ storePartial( rdi, res, rem );
+ }
+}
+
+void addF16to32( float* rdi, const uint16_t* a, const float* b, size_t length )
+{
+ const uint16_t* const endAligned = a + ( length & maskAlign8 );
+ const size_t rem = length % 8;
+
+ for( ; a < endAligned; a += 8, b += 8, rdi += 8 )
+ {
+ __m256 f1 = load8( a );
+ __m256 f2 = _mm256_loadu_ps( b );
+ __m256 res = _mm256_add_ps( f1, f2 );
+ _mm256_storeu_ps( rdi, res );
+ }
+
+ if( rem != 0 )
+ {
+ __m256 f1 = loadPartial( a, rem );
+ __m256 f2 = loadPartial( b, rem );
+ __m256 res = _mm256_add_ps( f1, f2 );
+ storePartial( rdi, res, rem );
+ }
+}
+
+alignas( 64 ) const std::array<int, 16> s_zeroTailMask =
+{
+ -1,-1,-1,-1,-1,-1,-1,-1,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+};
+
+namespace
+{
+ __forceinline float horizontalSum( __m256 vec )
+ {
+ __m128 v = _mm256_extractf128_ps( vec, 1 );
+ v = _mm_add_ps( v, _mm256_castps256_ps128( vec ) );
+ v = _mm_add_ps( v, _mm_movehl_ps( v, v ) );
+ v = _mm_add_ss( v, _mm_movehdup_ps( v ) );
+ return _mm_cvtss_f32( v );
+ }
+}
+
+void norm( float* rdi, float* temp, const float* rsi, size_t length )
+{
+ assert( (size_t)temp % 32 == 0 );
+ const float* rsiEndAligned = rsi + ( length & maskAlign8 );
+ const size_t rem = length % 8;
+
+ // First pass: copy to temp buffer, and compute the sum; computeVectorSum() in HLSL
+ __m256 sum = _mm256_setzero_ps();
+ float* t;
+ for( t = temp; rsi < rsiEndAligned; rsi += 8, t += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rsi );
+ sum = _mm256_add_ps( sum, v );
+ _mm256_store_ps( t, v );
+ }
+ float* const tEndAligned = t;
+ if( 0 != rem )
+ {
+ __m256 v = loadPartial( rsi, rem );
+ sum = _mm256_add_ps( sum, v );
+ _mm256_store_ps( t, v );
+ t += 8;
+ }
+
+ const float lengthFloat = (float)(int)length;
+ const float meanScalar = horizontalSum( sum ) / lengthFloat;
+ const __m256 mean = _mm256_set1_ps( meanScalar );
+
+ // Second pass, offsetAndComputeSumSquares() in HLSL
+ sum = _mm256_setzero_ps();
+ for( t = temp; t < tEndAligned; t += 8 )
+ {
+ __m256 v = _mm256_load_ps( t );
+ v = _mm256_sub_ps( v, mean );
+ _mm256_store_ps( t, v );
+ sum = _mm256_fmadd_ps( v, v, sum );
+ }
+ if( 0 != rem )
+ {
+ __m256 v = _mm256_load_ps( t );
+ v = _mm256_sub_ps( v, mean );
+ v = _mm256_and_ps( v, loadTailMaskFloats( rem ) );
+ _mm256_store_ps( t, v );
+ sum = _mm256_fmadd_ps( v, v, sum );
+ }
+
+ // Final pass: scale, and copy from temporary buffer into the destination row
+
+ constexpr float eps = 1e-5f; // TODO: make this a parameter
+ const float scaleScalar = 1.0f / std::sqrtf( horizontalSum( sum ) / lengthFloat + eps );
+ const __m256 scale = _mm256_set1_ps( scaleScalar );
+
+ for( t = temp; t < tEndAligned; t += 8, rdi += 8 )
+ {
+ __m256 v = _mm256_load_ps( t );
+ v = _mm256_mul_ps( v, scale );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != rem )
+ {
+ __m256 v = _mm256_load_ps( t );
+ v = _mm256_mul_ps( v, scale );
+ storePartial( rdi, v, rem );
+ }
+}
+
+void fmaRepeatRow( float* rdi, size_t len, const float* w, const float* b, size_t lenPattern )
+{
+ float* rdiEndAligned = rdi + ( len & maskAlign8 );
+ const size_t rem = len % 8;
+
+ if( 1 == lenPattern )
+ {
+ const __m256 v1 = _mm256_broadcast_ss( w );
+ const __m256 v2 = _mm256_broadcast_ss( b );
+ for( ; rdi < rdiEndAligned; rdi += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ v = _mm256_fmadd_ps( v, v1, v2 );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 v = _mm256_maskload_ps( rdi, mask );
+ v = _mm256_fmadd_ps( v, v1, v2 );
+ _mm256_maskstore_ps( rdi, mask, v );
+ }
+ }
+ else if( len == lenPattern )
+ {
+ for( ; rdi < rdiEndAligned; rdi += 8, w += 8, b += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ __m256 v1 = _mm256_loadu_ps( w );
+ __m256 v2 = _mm256_loadu_ps( b );
+ v = _mm256_fmadd_ps( v, v1, v2 );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 v = _mm256_maskload_ps( rdi, mask );
+ __m256 v1 = _mm256_maskload_ps( w, mask );
+ __m256 v2 = _mm256_maskload_ps( b, mask );
+ v = _mm256_fmadd_ps( v, v1, v2 );
+ _mm256_maskstore_ps( rdi, mask, v );
+ }
+ }
+ else
+ {
+ // TODO: implement if this actually happens
+ throw E_NOTIMPL;
+ }
+}
+
+void __vectorcall addRepeatScaleRow( float* rdi, size_t len, const float* b, size_t lenPattern, const __m256 scale )
+{
+ float* rdiEndAligned = rdi + ( len & maskAlign8 );
+ const size_t rem = len % 8;
+
+ if( 1 == lenPattern )
+ {
+ const __m256 v2 = _mm256_broadcast_ss( b );
+ for( ; rdi < rdiEndAligned; rdi += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ v = _mm256_add_ps( v, v2 );
+ v = _mm256_mul_ps( v, scale );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 v = _mm256_maskload_ps( rdi, mask );
+ v = _mm256_add_ps( v, v2 );
+ v = _mm256_mul_ps( v, scale );
+ _mm256_maskstore_ps( rdi, mask, v );
+ }
+ return;
+ }
+ else if( len == lenPattern )
+ {
+ for( ; rdi < rdiEndAligned; rdi += 8, b += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ __m256 v2 = _mm256_loadu_ps( b );
+ v = _mm256_add_ps( v, v2 );
+ v = _mm256_mul_ps( v, scale );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 v = _mm256_maskload_ps( rdi, mask );
+ __m256 v2 = _mm256_maskload_ps( b, mask );
+ v = _mm256_add_ps( v, v2 );
+ v = _mm256_mul_ps( v, scale );
+ _mm256_maskstore_ps( rdi, mask, v );
+ }
+ return;
+ }
+ else
+ {
+ // TODO: implement if this actually happens
+ throw E_NOTIMPL;
+ }
+}
+
+void addRepeatRow( float* rdi, size_t len, const float* b, size_t lenPattern )
+{
+ float* rdiEndAligned = rdi + ( len & maskAlign8 );
+ const size_t rem = len % 8;
+
+ if( 1 == lenPattern )
+ {
+ const __m256 v2 = _mm256_broadcast_ss( b );
+ for( ; rdi < rdiEndAligned; rdi += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ v = _mm256_add_ps( v, v2 );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 v = _mm256_maskload_ps( rdi, mask );
+ v = _mm256_add_ps( v, v2 );
+ _mm256_maskstore_ps( rdi, mask, v );
+ }
+ return;
+ }
+ else if( len == lenPattern )
+ {
+ for( ; rdi < rdiEndAligned; rdi += 8, b += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ __m256 v2 = _mm256_loadu_ps( b );
+ v = _mm256_add_ps( v, v2 );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 v = _mm256_maskload_ps( rdi, mask );
+ __m256 v2 = _mm256_maskload_ps( b, mask );
+ v = _mm256_add_ps( v, v2 );
+ _mm256_maskstore_ps( rdi, mask, v );
+ }
+ return;
+ }
+ else
+ {
+ // TODO: implement if this actually happens
+ throw E_NOTIMPL;
+ }
+}
+
+namespace
+{
+ __forceinline __m256 gelu( __m256 x, const DirectCompute::LookupTablesData& lookup )
+ {
+ __m128i iv = _mm256_cvtps_ph( x, 0 );
+ alignas( 16 ) std::array<uint16_t, 8> arr;
+ _mm_store_si128( ( __m128i* )arr.data(), iv );
+ for( uint16_t& a : arr )
+ a = lookup.gelu[ a ];
+ iv = _mm_load_si128( ( __m128i* )arr.data() );
+ return _mm256_cvtph_ps( iv );
+ }
+}
+
+void addRepeatGeluRow( float* rdi, size_t len, const float* b, size_t lenPattern, const DirectCompute::LookupTablesData& lookup )
+{
+ float* rdiEndAligned = rdi + ( len & maskAlign8 );
+ const size_t rem = len % 8;
+
+ if( 1 == lenPattern )
+ {
+ const __m256 v2 = _mm256_broadcast_ss( b );
+ for( ; rdi < rdiEndAligned; rdi += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ v = _mm256_add_ps( v, v2 );
+ v = gelu( v, lookup );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 v = _mm256_maskload_ps( rdi, mask );
+ v = _mm256_add_ps( v, v2 );
+ v = gelu( v, lookup );
+ _mm256_maskstore_ps( rdi, mask, v );
+ }
+ return;
+ }
+ else if( len == lenPattern )
+ {
+ for( ; rdi < rdiEndAligned; rdi += 8, b += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ __m256 v2 = _mm256_loadu_ps( b );
+ v = _mm256_add_ps( v, v2 );
+ v = gelu( v, lookup );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 v = _mm256_maskload_ps( rdi, mask );
+ __m256 v2 = _mm256_maskload_ps( b, mask );
+ v = _mm256_add_ps( v, v2 );
+ v = gelu( v, lookup );
+ _mm256_maskstore_ps( rdi, mask, v );
+ }
+ return;
+ }
+ else
+ {
+ // TODO: implement if this actually happens
+ throw E_NOTIMPL;
+ }
+}
+
+void __vectorcall scaleRow( float* rdi, size_t len, const __m256 scale )
+{
+ float* rdiEndAligned = rdi + ( len & maskAlign8 );
+ const size_t rem = len % 8;
+ for( ; rdi < rdiEndAligned; rdi += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ v = _mm256_mul_ps( v, scale );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 v = _mm256_maskload_ps( rdi, mask );
+ v = _mm256_mul_ps( v, scale );
+ _mm256_maskstore_ps( rdi, mask, v );
+ }
+}
+
+namespace
+{
+ using DirectCompute::LookupTablesData;
+
+ __forceinline float horizontalMax( __m256 vec )
+ {
+ __m128 v = _mm256_extractf128_ps( vec, 1 );
+ v = _mm_max_ps( v, _mm256_castps256_ps128( vec ) );
+ v = _mm_max_ps( v, _mm_movehl_ps( v, v ) );
+ v = _mm_max_ss( v, _mm_movehdup_ps( v ) );
+ return _mm_cvtss_f32( v );
+ }
+
+ __forceinline float _cvtsh_ss( uint16_t f16 )
+ {
+ __m128i i = _mm_cvtsi32_si128( f16 );
+ __m128 f = _mm_cvtph_ps( i );
+ return _mm_cvtss_f32( f );
+ }
+
+ __forceinline uint16_t _cvtss_sh( float f, int rounding )
+ {
+ assert( 0 == rounding );
+ __m128 v = _mm_set_ss( f );
+ __m128i i = _mm_cvtps_ph( v, 0 );
+ return (uint16_t)(uint32_t)_mm_cvtsi128_si32( i );
+ }
+}
+
+const LookupTablesData& getLookupTables()
+{
+ static const std::unique_ptr<LookupTablesData> res = std::make_unique<LookupTablesData>();
+ return *res;
+}
+
+void softMax( float* rdi, size_t length, const float inputScale )
+{
+ float* const rdiBegin = rdi;
+ float* const rdiEndAligned = rdi + ( length & maskAlign8 );
+ const size_t remainder = length % 8;
+ // First pass, compute maximum
+ __m256 max = _mm256_set1_ps( -INFINITY );
+ for( rdi = rdiBegin; rdi < rdiEndAligned; rdi += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ max = _mm256_max_ps( max, v );
+ }
+ __m256i tailMask;
+ if( 0 != remainder )
+ {
+ tailMask = loadTailMaskInt( remainder );
+ __m256 v = _mm256_maskload_ps( rdi, tailMask );
+ v = _mm256_max_ps( max, v );
+ max = _mm256_blendv_ps( max, v, _mm256_castsi256_ps( tailMask ) );
+ }
+
+ // Second pass: apply initial scale, compute the exponent, and compute total sum over the row
+ const LookupTablesData& lookup = getLookupTables();
+ const float maxScalar = horizontalMax( max );
+
+ float* const rdiEnd = rdiBegin + length;
+ double sum = 0;
+ for( rdi = rdiBegin; rdi < rdiEnd; rdi++ )
+ {
+ // Possible to vectorize, but relatively hard
+ // An easy way is upcast the complete lookup table to FP32 and then use two _mm256_i32gather_ps instructions per iteration
+ // However, that instruction is from AVX2 set. Let's hope this loop won't be a bottleneck.
+ float f = *rdi;
+ if( f != -INFINITY )
+ {
+ f = ( f - maxScalar ) * inputScale;
+ uint16_t f16 = _cvtss_sh( f, 0 );
+ f16 = lookup.exponent[ f16 ];
+ f = _cvtsh_ss( f16 );
+ sum += f;
+ }
+ else
+ f = 0;
+
+ *rdi = f;
+ }
+
+ // Final pass: apply the final scale
+ const __m256 finalScale = _mm256_set1_ps( (float)( 1.0 / sum ) );
+ for( rdi = rdiBegin; rdi < rdiEndAligned; rdi += 8 )
+ {
+ __m256 v = _mm256_loadu_ps( rdi );
+ v = _mm256_mul_ps( v, finalScale );
+ _mm256_storeu_ps( rdi, v );
+ }
+ if( 0 != remainder )
+ {
+ __m256 v = _mm256_maskload_ps( rdi, tailMask );
+ v = _mm256_mul_ps( v, finalScale );
+ _mm256_maskstore_ps( rdi, tailMask, v );
+ }
+}
+
+void floatsUpcast( float* rdi, const uint16_t* rsi, size_t length )
+{
+ const uint16_t* rsiEndAligned = rsi + ( length & maskAlign8 );
+ const size_t rem = length % 8;
+
+ for( ; rsi < rsiEndAligned; rsi += 8, rdi += 8 )
+ _mm256_storeu_ps( rdi, load8( rsi ) );
+
+ if( 0 != rem )
+ {
+ __m256 v = loadPartial( rsi, rem );
+ _mm256_maskstore_ps( rdi, loadTailMaskInt( rem ), v );
+ }
+}
+
+void floatsDowncast( uint16_t* rdi, const float* rsi, size_t length )
+{
+ const float* rsiEndAligned = rsi + ( length & maskAlign8 );
+ size_t rem = length % 8;
+
+ for( ; rsi < rsiEndAligned; rsi += 8, rdi += 8 )
+ {
+ __m256 vf = _mm256_loadu_ps( rsi );
+ __m128i vi = _mm256_cvtps_ph( vf, 0 );
+ store16( rdi, vi );
+ }
+
+ if( 0 != rem )
+ {
+ __m256 vf = _mm256_maskload_ps( rsi, loadTailMaskInt( rem ) );
+ __m128i vi = _mm256_cvtps_ph( vf, 0 );
+ for( size_t i = 0; i < rem; i++, rdi++ )
+ {
+ *rdi = (uint16_t)(uint32_t)_mm_cvtsi128_si32( vi );
+ vi = _mm_srli_si128( vi, 2 );
+ }
+ }
+}
+
+void addRowInPlace( float* rdi, const float* rsi, size_t length )
+{
+ const float* rdiEndAligned = rdi + ( length & maskAlign8 );
+ size_t rem = length % 8;
+
+ for( ; rdi < rdiEndAligned; rdi += 8, rsi += 8 )
+ {
+ __m256 a = _mm256_loadu_ps( rdi );
+ __m256 b = _mm256_loadu_ps( rsi );
+ a = _mm256_add_ps( a, b );
+ _mm256_storeu_ps( rdi, a );
+ }
+
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 a = _mm256_maskload_ps( rdi, mask );
+ __m256 b = _mm256_maskload_ps( rsi, mask );
+ a = _mm256_add_ps( a, b );
+ _mm256_maskstore_ps( rdi, mask, a );
+ }
+}
+
+void addRow( float* rdi, const float* a, const float* b, size_t length )
+{
+ const float* aEndAligned = a + ( length & maskAlign8 );
+ size_t rem = length % 8;
+
+ for( ; a < aEndAligned; a += 8, b += 8, rdi += 8 )
+ {
+ __m256 x = _mm256_loadu_ps( a );
+ __m256 y = _mm256_loadu_ps( b );
+ x = _mm256_add_ps( x, y );
+ _mm256_storeu_ps( rdi, x );
+ }
+
+ if( 0 != rem )
+ {
+ const __m256i mask = loadTailMaskInt( rem );
+ __m256 x = _mm256_maskload_ps( a, mask );
+ __m256 y = _mm256_maskload_ps( b, mask );
+ x = _mm256_add_ps( x, y );
+ _mm256_maskstore_ps( rdi, mask, x );
+ }
+} \ No newline at end of file
diff --git a/Whisper/CPU/simdUtils.h b/Whisper/CPU/simdUtils.h
new file mode 100644
index 0000000..a7a4bac
--- /dev/null
+++ b/Whisper/CPU/simdUtils.h
@@ -0,0 +1,82 @@
+#pragma once
+#include <immintrin.h>
+
+void addF16to32( float* rdi, const uint16_t* a, const uint16_t* b, size_t length );
+void addF16to32( float* rdi, const uint16_t* a, const float* b, size_t length );
+
+class AlignedSpan
+{
+ float* pointer;
+
+public:
+ AlignedSpan( void* data )
+ {
+ size_t i = (size_t)data;
+ constexpr size_t mask32 = ~(size_t)31;
+ i = ( i + 31 ) & mask32;
+ pointer = (float*)i;
+ }
+
+ operator float* ( ) { return pointer; }
+};
+
+inline size_t tempBufferForFloats( size_t count )
+{
+ // Round up by 8 to be able to use full-vector loads and stores
+ constexpr size_t mask8 = ~(size_t)7;
+ count = ( count + 7 ) & mask8;
+
+ // Add 32 more bytes to align the temporary buffer
+ return ( count * 4 ) + 32;
+}
+
+#define ALIGNED_SPAN( name, countFloats ) AlignedSpan name{ _alloca( tempBufferForFloats( countFloats ) ) }
+
+void norm( float* rdi, float* temp, const float* rsi, size_t length );
+
+void fmaRepeatRow( float* rdi, size_t len, const float* w, const float* b, size_t lenPattern );
+void __vectorcall addRepeatScaleRow( float* rdi, size_t len, const float* b, size_t lenPattern, const __m256 scale );
+void addRepeatRow( float* rdi, size_t len, const float* b, size_t lenPattern );
+void __vectorcall scaleRow( float* rdi, size_t len, const __m256 scale );
+
+namespace DirectCompute
+{
+ struct LookupTablesData;
+}
+const DirectCompute::LookupTablesData& getLookupTables();
+void addRepeatGeluRow( float* rdi, size_t len, const float* b, size_t lenPattern, const DirectCompute::LookupTablesData& lookup );
+
+void softMax( float* rdi, size_t length, const float inputScale );
+
+// A cache line-aligned array where first 8 elements have all bits set, last 8 elements are zeros
+extern const std::array<int, 16> s_zeroTailMask;
+
+// Load a tail mask as FP32 vector, for use with _mm256_and_ps or _mm256_blendv_ps instructions
+__forceinline __m256 loadTailMaskFloats( size_t remainder )
+{
+ assert( remainder > 0 && remainder < 8 );
+ const float* rsi = (const float*)&s_zeroTailMask;
+ rsi += 8;
+ return _mm256_loadu_ps( rsi - remainder );
+}
+
+// Load a tail mask as int32 vector, for use with _mm256_maskstore_ps instruction
+template<bool assertIncomplete = true>
+__forceinline __m256i loadTailMaskInt( size_t remainder )
+{
+ if constexpr( assertIncomplete )
+ assert( remainder > 0 && remainder < 8 );
+ else
+ assert( remainder >= 0 && remainder <= 8 );
+
+ const int* rsi = (const int*)&s_zeroTailMask;
+ rsi += 8;
+ return _mm256_loadu_si256( ( const __m256i* )( rsi - remainder ) );
+}
+
+void floatsUpcast( float* rdi, const uint16_t* rsi, size_t length );
+
+void floatsDowncast( uint16_t* rdi, const float* rsi, size_t length );
+
+void addRowInPlace( float* rdi, const float* rsi, size_t length );
+void addRow( float* rdi, const float* a, const float* b, size_t length ); \ No newline at end of file
diff --git a/Whisper/D3D/Binder.cpp b/Whisper/D3D/Binder.cpp
new file mode 100644
index 0000000..9caf7aa
--- /dev/null
+++ b/Whisper/D3D/Binder.cpp
@@ -0,0 +1,63 @@
+#include "stdafx.h"
+#include "Binder.h"
+#include <algorithm>
+using namespace DirectCompute;
+
+void Binder::bind( ID3D11ShaderResourceView* srv0, ID3D11UnorderedAccessView* uav0 )
+{
+ ID3D11DeviceContext* const ctx = context();
+ ctx->CSSetUnorderedAccessViews( 0, 1, &uav0, nullptr );
+
+ ctx->CSSetShaderResources( 0, 1, &srv0 );
+
+ maxSrv = std::max( maxSrv, (uint8_t)1 );
+ maxUav = std::max( maxUav, (uint8_t)1 );
+}
+
+void Binder::bind( ID3D11UnorderedAccessView* uav0 )
+{
+ context()->CSSetUnorderedAccessViews( 0, 1, &uav0, nullptr );
+ maxUav = std::max( maxUav, (uint8_t)1 );
+}
+
+void Binder::bind( ID3D11ShaderResourceView* srv0, ID3D11ShaderResourceView* srv1, ID3D11UnorderedAccessView* uav0 )
+{
+ ID3D11DeviceContext* const ctx = context();
+ ctx->CSSetUnorderedAccessViews( 0, 1, &uav0, nullptr );
+
+ std::array< ID3D11ShaderResourceView*, 2> arr = { srv0, srv1 };
+ ctx->CSSetShaderResources( 0, 2, arr.data() );
+
+ maxSrv = std::max( maxSrv, (uint8_t)2 );
+ maxUav = std::max( maxUav, (uint8_t)1 );
+}
+
+void Binder::bind( std::initializer_list<ID3D11ShaderResourceView*> srvs, std::initializer_list<ID3D11UnorderedAccessView*> uavs )
+{
+ ID3D11DeviceContext* const ctx = context();
+
+ const size_t lengthResources = srvs.size();
+ const size_t lengthUnordered = uavs.size();
+ assert( lengthResources > 0 && lengthResources < D3D11_COMMONSHADER_INPUT_RESOURCE_REGISTER_COUNT );
+ assert( lengthUnordered > 0 && lengthUnordered < D3D11_PS_CS_UAV_REGISTER_COUNT );
+
+ ctx->CSSetUnorderedAccessViews( 0, (UINT)lengthUnordered, uavs.begin(), nullptr );
+ ctx->CSSetShaderResources( 0, (UINT)lengthResources, srvs.begin() );
+
+ maxSrv = std::max( maxSrv, (uint8_t)lengthResources );
+ maxUav = std::max( maxUav, (uint8_t)lengthUnordered );
+}
+
+Binder::~Binder()
+{
+ uint8_t count = std::max( maxSrv, maxUav );
+ if( 0 == count )
+ return;
+#pragma warning (disable: 6255) // Compiler doesn't know we have very few of these things
+ size_t* arr = (size_t*)_alloca( count * sizeof( size_t ) );
+ memset( arr, 0, count * sizeof( size_t ) );
+
+ ID3D11DeviceContext* const ctx = context();
+ ctx->CSSetShaderResources( 0, maxSrv, (ID3D11ShaderResourceView**)arr );
+ ctx->CSSetUnorderedAccessViews( 0, maxUav, (ID3D11UnorderedAccessView**)arr, nullptr );
+} \ No newline at end of file
diff --git a/Whisper/D3D/Binder.h b/Whisper/D3D/Binder.h
new file mode 100644
index 0000000..bf7ffb2
--- /dev/null
+++ b/Whisper/D3D/Binder.h
@@ -0,0 +1,21 @@
+#pragma once
+#include "device.h"
+
+namespace DirectCompute
+{
+ class Binder
+ {
+ uint8_t maxSrv = 0;
+ uint8_t maxUav = 0;
+
+ public:
+ Binder() = default;
+ Binder( const Binder& ) = delete;
+
+ void bind( ID3D11ShaderResourceView* srv0, ID3D11UnorderedAccessView* uav0 );
+ void bind( ID3D11ShaderResourceView* srv0, ID3D11ShaderResourceView* srv1, ID3D11UnorderedAccessView* uav0 );
+ void bind( std::initializer_list<ID3D11ShaderResourceView*> srvs, std::initializer_list<ID3D11UnorderedAccessView*> uavs );
+ void bind( ID3D11UnorderedAccessView* uav0 );
+ ~Binder();
+ };
+} \ No newline at end of file
diff --git a/Whisper/D3D/MappedResource.cpp b/Whisper/D3D/MappedResource.cpp
new file mode 100644
index 0000000..d6e8119
--- /dev/null
+++ b/Whisper/D3D/MappedResource.cpp
@@ -0,0 +1,33 @@
+#include "stdafx.h"
+#include "MappedResource.h"
+using namespace DirectCompute;
+#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; }
+
+MappedResource::MappedResource()
+{
+ mapped.pData = nullptr;
+ mapped.RowPitch = mapped.DepthPitch = 0;
+ resource = nullptr;
+}
+
+HRESULT MappedResource::map( ID3D11Resource* res, bool reading )
+{
+ if( nullptr == resource )
+ {
+ D3D11_MAP mt = reading ? D3D11_MAP_READ : D3D11_MAP_WRITE_DISCARD;
+ CHECK( context()->Map( res, 0, mt, 0, &mapped ) );
+ resource = res;
+ return S_OK;
+ }
+ return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
+}
+
+MappedResource::~MappedResource()
+{
+ if( nullptr != resource )
+ {
+ context()->Unmap( resource, 0 );
+ resource = nullptr;
+ mapped.pData = nullptr;
+ }
+} \ No newline at end of file
diff --git a/Whisper/D3D/MappedResource.h b/Whisper/D3D/MappedResource.h
new file mode 100644
index 0000000..a6b046b
--- /dev/null
+++ b/Whisper/D3D/MappedResource.h
@@ -0,0 +1,22 @@
+#pragma once
+#include "device.h"
+#include <assert.h>
+
+namespace DirectCompute
+{
+ class MappedResource
+ {
+ D3D11_MAPPED_SUBRESOURCE mapped;
+ ID3D11Resource* resource;
+ public:
+ MappedResource();
+ HRESULT map( ID3D11Resource* res, bool reading );
+ ~MappedResource();
+
+ void* data() const
+ {
+ assert( nullptr != mapped.pData );
+ return mapped.pData;
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/D3D/RenderDoc/renderDoc.cpp b/Whisper/D3D/RenderDoc/renderDoc.cpp
new file mode 100644
index 0000000..e811c1e
--- /dev/null
+++ b/Whisper/D3D/RenderDoc/renderDoc.cpp
@@ -0,0 +1,72 @@
+#include "stdafx.h"
+#include "renderDoc.h"
+#include "renderdoc_app.h"
+#include "../device.h"
+
+#define ENABLE_RENDERDOC_DEBUGGER 1
+
+#if ENABLE_RENDERDOC_DEBUGGER
+namespace
+{
+ static HMODULE hmRenderDoc = nullptr;
+ static RENDERDOC_API_1_6_0* api = nullptr;
+}
+
+bool DirectCompute::initializeRenderDoc()
+{
+ hmRenderDoc = GetModuleHandleW( L"renderdoc.dll" );
+ if( nullptr == hmRenderDoc )
+ return false;
+
+ pRENDERDOC_GetAPI getApi = (pRENDERDOC_GetAPI)GetProcAddress( hmRenderDoc, "RENDERDOC_GetAPI" );
+ if( nullptr == getApi )
+ return false;
+ if( 1 != getApi( eRENDERDOC_API_Version_1_6_0, (void**)&api ) )
+ return false;
+ if( nullptr == api )
+ return false;
+
+ return true;
+}
+
+namespace
+{
+ using namespace DirectCompute;
+ inline bool isKeyPressed( int vKey )
+ {
+ return 0 != ( GetAsyncKeyState( vKey ) & 0x8000 );
+ }
+}
+
+CaptureRaii::CaptureRaii() : capturing( false )
+{
+ if( nullptr == api )
+ return;
+ if( !isKeyPressed( VK_F12 ) )
+ return;
+ ID3D11Device* const dev = device();
+ if( nullptr == dev )
+ return;
+
+ api->StartFrameCapture( dev, nullptr );
+ capturing = true;
+}
+
+CaptureRaii::~CaptureRaii()
+{
+ if( !capturing )
+ return;
+ api->EndFrameCapture( device(), nullptr );
+}
+#else // !ENABLE_RENDERDOC_DEBUGGER
+bool DirectCompute::initializeRenderDoc()
+{
+ return false;
+}
+DirectCompute::CaptureRaii::CaptureRaii() : capturing( false )
+{
+}
+DirectCompute::CaptureRaii::~CaptureRaii()
+{
+}
+#endif \ No newline at end of file
diff --git a/Whisper/D3D/RenderDoc/renderDoc.h b/Whisper/D3D/RenderDoc/renderDoc.h
new file mode 100644
index 0000000..791052d
--- /dev/null
+++ b/Whisper/D3D/RenderDoc/renderDoc.h
@@ -0,0 +1,15 @@
+#pragma once
+
+namespace DirectCompute
+{
+ bool initializeRenderDoc();
+
+ class CaptureRaii
+ {
+ bool capturing;
+ public:
+ CaptureRaii();
+ CaptureRaii( const CaptureRaii& ) = delete;
+ ~CaptureRaii();
+ };
+} \ No newline at end of file
diff --git a/Whisper/D3D/RenderDoc/renderdoc_app.h b/Whisper/D3D/RenderDoc/renderdoc_app.h
new file mode 100644
index 0000000..402dd3d
--- /dev/null
+++ b/Whisper/D3D/RenderDoc/renderdoc_app.h
@@ -0,0 +1,724 @@
+/******************************************************************************
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2019-2022 Baldur Karlsson
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ ******************************************************************************/
+
+#pragma once
+
+//////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Documentation for the API is available at https://renderdoc.org/docs/in_application_api.html
+//
+
+#if !defined(RENDERDOC_NO_STDINT)
+#include <stdint.h>
+#endif
+
+#if defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER)
+#define RENDERDOC_CC __cdecl
+#elif defined(__linux__)
+#define RENDERDOC_CC
+#elif defined(__APPLE__)
+#define RENDERDOC_CC
+#else
+#error "Unknown platform"
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//////////////////////////////////////////////////////////////////////////////////////////////////
+// Constants not used directly in below API
+
+// This is a GUID/magic value used for when applications pass a path where shader debug
+// information can be found to match up with a stripped shader.
+// the define can be used like so: const GUID RENDERDOC_ShaderDebugMagicValue =
+// RENDERDOC_ShaderDebugMagicValue_value
+#define RENDERDOC_ShaderDebugMagicValue_struct \
+ { \
+ 0xeab25520, 0x6670, 0x4865, 0x84, 0x29, 0x6c, 0x8, 0x51, 0x54, 0x00, 0xff \
+ }
+
+// as an alternative when you want a byte array (assuming x86 endianness):
+#define RENDERDOC_ShaderDebugMagicValue_bytearray \
+ { \
+ 0x20, 0x55, 0xb2, 0xea, 0x70, 0x66, 0x65, 0x48, 0x84, 0x29, 0x6c, 0x8, 0x51, 0x54, 0x00, 0xff \
+ }
+
+// truncated version when only a uint64_t is available (e.g. Vulkan tags):
+#define RENDERDOC_ShaderDebugMagicValue_truncated 0x48656670eab25520ULL
+
+//////////////////////////////////////////////////////////////////////////////////////////////////
+// RenderDoc capture options
+//
+
+typedef enum RENDERDOC_CaptureOption {
+ // Allow the application to enable vsync
+ //
+ // Default - enabled
+ //
+ // 1 - The application can enable or disable vsync at will
+ // 0 - vsync is force disabled
+ eRENDERDOC_Option_AllowVSync = 0,
+
+ // Allow the application to enable fullscreen
+ //
+ // Default - enabled
+ //
+ // 1 - The application can enable or disable fullscreen at will
+ // 0 - fullscreen is force disabled
+ eRENDERDOC_Option_AllowFullscreen = 1,
+
+ // Record API debugging events and messages
+ //
+ // Default - disabled
+ //
+ // 1 - Enable built-in API debugging features and records the results into
+ // the capture, which is matched up with events on replay
+ // 0 - no API debugging is forcibly enabled
+ eRENDERDOC_Option_APIValidation = 2,
+ eRENDERDOC_Option_DebugDeviceMode = 2, // deprecated name of this enum
+
+ // Capture CPU callstacks for API events
+ //
+ // Default - disabled
+ //
+ // 1 - Enables capturing of callstacks
+ // 0 - no callstacks are captured
+ eRENDERDOC_Option_CaptureCallstacks = 3,
+
+ // When capturing CPU callstacks, only capture them from actions.
+ // This option does nothing without the above option being enabled
+ //
+ // Default - disabled
+ //
+ // 1 - Only captures callstacks for actions.
+ // Ignored if CaptureCallstacks is disabled
+ // 0 - Callstacks, if enabled, are captured for every event.
+ eRENDERDOC_Option_CaptureCallstacksOnlyDraws = 4,
+ eRENDERDOC_Option_CaptureCallstacksOnlyActions = 4,
+
+ // Specify a delay in seconds to wait for a debugger to attach, after
+ // creating or injecting into a process, before continuing to allow it to run.
+ //
+ // 0 indicates no delay, and the process will run immediately after injection
+ //
+ // Default - 0 seconds
+ //
+ eRENDERDOC_Option_DelayForDebugger = 5,
+
+ // Verify buffer access. This includes checking the memory returned by a Map() call to
+ // detect any out-of-bounds modification, as well as initialising buffers with undefined contents
+ // to a marker value to catch use of uninitialised memory.
+ //
+ // NOTE: This option is only valid for OpenGL and D3D11. Explicit APIs such as D3D12 and Vulkan do
+ // not do the same kind of interception & checking and undefined contents are really undefined.
+ //
+ // Default - disabled
+ //
+ // 1 - Verify buffer access
+ // 0 - No verification is performed, and overwriting bounds may cause crashes or corruption in
+ // RenderDoc.
+ eRENDERDOC_Option_VerifyBufferAccess = 6,
+
+ // The old name for eRENDERDOC_Option_VerifyBufferAccess was eRENDERDOC_Option_VerifyMapWrites.
+ // This option now controls the filling of uninitialised buffers with 0xdddddddd which was
+ // previously always enabled
+ eRENDERDOC_Option_VerifyMapWrites = eRENDERDOC_Option_VerifyBufferAccess,
+
+ // Hooks any system API calls that create child processes, and injects
+ // RenderDoc into them recursively with the same options.
+ //
+ // Default - disabled
+ //
+ // 1 - Hooks into spawned child processes
+ // 0 - Child processes are not hooked by RenderDoc
+ eRENDERDOC_Option_HookIntoChildren = 7,
+
+ // By default RenderDoc only includes resources in the final capture necessary
+ // for that frame, this allows you to override that behaviour.
+ //
+ // Default - disabled
+ //
+ // 1 - all live resources at the time of capture are included in the capture
+ // and available for inspection
+ // 0 - only the resources referenced by the captured frame are included
+ eRENDERDOC_Option_RefAllResources = 8,
+
+ // **NOTE**: As of RenderDoc v1.1 this option has been deprecated. Setting or
+ // getting it will be ignored, to allow compatibility with older versions.
+ // In v1.1 the option acts as if it's always enabled.
+ //
+ // By default RenderDoc skips saving initial states for resources where the
+ // previous contents don't appear to be used, assuming that writes before
+ // reads indicate previous contents aren't used.
+ //
+ // Default - disabled
+ //
+ // 1 - initial contents at the start of each captured frame are saved, even if
+ // they are later overwritten or cleared before being used.
+ // 0 - unless a read is detected, initial contents will not be saved and will
+ // appear as black or empty data.
+ eRENDERDOC_Option_SaveAllInitials = 9,
+
+ // In APIs that allow for the recording of command lists to be replayed later,
+ // RenderDoc may choose to not capture command lists before a frame capture is
+ // triggered, to reduce overheads. This means any command lists recorded once
+ // and replayed many times will not be available and may cause a failure to
+ // capture.
+ //
+ // NOTE: This is only true for APIs where multithreading is difficult or
+ // discouraged. Newer APIs like Vulkan and D3D12 will ignore this option
+ // and always capture all command lists since the API is heavily oriented
+ // around it and the overheads have been reduced by API design.
+ //
+ // 1 - All command lists are captured from the start of the application
+ // 0 - Command lists are only captured if their recording begins during
+ // the period when a frame capture is in progress.
+ eRENDERDOC_Option_CaptureAllCmdLists = 10,
+
+ // Mute API debugging output when the API validation mode option is enabled
+ //
+ // Default - enabled
+ //
+ // 1 - Mute any API debug messages from being displayed or passed through
+ // 0 - API debugging is displayed as normal
+ eRENDERDOC_Option_DebugOutputMute = 11,
+
+ // Option to allow vendor extensions to be used even when they may be
+ // incompatible with RenderDoc and cause corrupted replays or crashes.
+ //
+ // Default - inactive
+ //
+ // No values are documented, this option should only be used when absolutely
+ // necessary as directed by a RenderDoc developer.
+ eRENDERDOC_Option_AllowUnsupportedVendorExtensions = 12,
+
+} RENDERDOC_CaptureOption;
+
+// Sets an option that controls how RenderDoc behaves on capture.
+//
+// Returns 1 if the option and value are valid
+// Returns 0 if either is invalid and the option is unchanged
+typedef int(RENDERDOC_CC *pRENDERDOC_SetCaptureOptionU32)(RENDERDOC_CaptureOption opt, uint32_t val);
+typedef int(RENDERDOC_CC *pRENDERDOC_SetCaptureOptionF32)(RENDERDOC_CaptureOption opt, float val);
+
+// Gets the current value of an option as a uint32_t
+//
+// If the option is invalid, 0xffffffff is returned
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_GetCaptureOptionU32)(RENDERDOC_CaptureOption opt);
+
+// Gets the current value of an option as a float
+//
+// If the option is invalid, -FLT_MAX is returned
+typedef float(RENDERDOC_CC *pRENDERDOC_GetCaptureOptionF32)(RENDERDOC_CaptureOption opt);
+
+typedef enum RENDERDOC_InputButton {
+ // '0' - '9' matches ASCII values
+ eRENDERDOC_Key_0 = 0x30,
+ eRENDERDOC_Key_1 = 0x31,
+ eRENDERDOC_Key_2 = 0x32,
+ eRENDERDOC_Key_3 = 0x33,
+ eRENDERDOC_Key_4 = 0x34,
+ eRENDERDOC_Key_5 = 0x35,
+ eRENDERDOC_Key_6 = 0x36,
+ eRENDERDOC_Key_7 = 0x37,
+ eRENDERDOC_Key_8 = 0x38,
+ eRENDERDOC_Key_9 = 0x39,
+
+ // 'A' - 'Z' matches ASCII values
+ eRENDERDOC_Key_A = 0x41,
+ eRENDERDOC_Key_B = 0x42,
+ eRENDERDOC_Key_C = 0x43,
+ eRENDERDOC_Key_D = 0x44,
+ eRENDERDOC_Key_E = 0x45,
+ eRENDERDOC_Key_F = 0x46,
+ eRENDERDOC_Key_G = 0x47,
+ eRENDERDOC_Key_H = 0x48,
+ eRENDERDOC_Key_I = 0x49,
+ eRENDERDOC_Key_J = 0x4A,
+ eRENDERDOC_Key_K = 0x4B,
+ eRENDERDOC_Key_L = 0x4C,
+ eRENDERDOC_Key_M = 0x4D,
+ eRENDERDOC_Key_N = 0x4E,
+ eRENDERDOC_Key_O = 0x4F,
+ eRENDERDOC_Key_P = 0x50,
+ eRENDERDOC_Key_Q = 0x51,
+ eRENDERDOC_Key_R = 0x52,
+ eRENDERDOC_Key_S = 0x53,
+ eRENDERDOC_Key_T = 0x54,
+ eRENDERDOC_Key_U = 0x55,
+ eRENDERDOC_Key_V = 0x56,
+ eRENDERDOC_Key_W = 0x57,
+ eRENDERDOC_Key_X = 0x58,
+ eRENDERDOC_Key_Y = 0x59,
+ eRENDERDOC_Key_Z = 0x5A,
+
+ // leave the rest of the ASCII range free
+ // in case we want to use it later
+ eRENDERDOC_Key_NonPrintable = 0x100,
+
+ eRENDERDOC_Key_Divide,
+ eRENDERDOC_Key_Multiply,
+ eRENDERDOC_Key_Subtract,
+ eRENDERDOC_Key_Plus,
+
+ eRENDERDOC_Key_F1,
+ eRENDERDOC_Key_F2,
+ eRENDERDOC_Key_F3,
+ eRENDERDOC_Key_F4,
+ eRENDERDOC_Key_F5,
+ eRENDERDOC_Key_F6,
+ eRENDERDOC_Key_F7,
+ eRENDERDOC_Key_F8,
+ eRENDERDOC_Key_F9,
+ eRENDERDOC_Key_F10,
+ eRENDERDOC_Key_F11,
+ eRENDERDOC_Key_F12,
+
+ eRENDERDOC_Key_Home,
+ eRENDERDOC_Key_End,
+ eRENDERDOC_Key_Insert,
+ eRENDERDOC_Key_Delete,
+ eRENDERDOC_Key_PageUp,
+ eRENDERDOC_Key_PageDn,
+
+ eRENDERDOC_Key_Backspace,
+ eRENDERDOC_Key_Tab,
+ eRENDERDOC_Key_PrtScrn,
+ eRENDERDOC_Key_Pause,
+
+ eRENDERDOC_Key_Max,
+} RENDERDOC_InputButton;
+
+// Sets which key or keys can be used to toggle focus between multiple windows
+//
+// If keys is NULL or num is 0, toggle keys will be disabled
+typedef void(RENDERDOC_CC *pRENDERDOC_SetFocusToggleKeys)(RENDERDOC_InputButton *keys, int num);
+
+// Sets which key or keys can be used to capture the next frame
+//
+// If keys is NULL or num is 0, captures keys will be disabled
+typedef void(RENDERDOC_CC *pRENDERDOC_SetCaptureKeys)(RENDERDOC_InputButton *keys, int num);
+
+typedef enum RENDERDOC_OverlayBits {
+ // This single bit controls whether the overlay is enabled or disabled globally
+ eRENDERDOC_Overlay_Enabled = 0x1,
+
+ // Show the average framerate over several seconds as well as min/max
+ eRENDERDOC_Overlay_FrameRate = 0x2,
+
+ // Show the current frame number
+ eRENDERDOC_Overlay_FrameNumber = 0x4,
+
+ // Show a list of recent captures, and how many captures have been made
+ eRENDERDOC_Overlay_CaptureList = 0x8,
+
+ // Default values for the overlay mask
+ eRENDERDOC_Overlay_Default = (eRENDERDOC_Overlay_Enabled | eRENDERDOC_Overlay_FrameRate |
+ eRENDERDOC_Overlay_FrameNumber | eRENDERDOC_Overlay_CaptureList),
+
+ // Enable all bits
+ eRENDERDOC_Overlay_All = ~0U,
+
+ // Disable all bits
+ eRENDERDOC_Overlay_None = 0,
+} RENDERDOC_OverlayBits;
+
+// returns the overlay bits that have been set
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_GetOverlayBits)();
+// sets the overlay bits with an and & or mask
+typedef void(RENDERDOC_CC *pRENDERDOC_MaskOverlayBits)(uint32_t And, uint32_t Or);
+
+// this function will attempt to remove RenderDoc's hooks in the application.
+//
+// Note: that this can only work correctly if done immediately after
+// the module is loaded, before any API work happens. RenderDoc will remove its
+// injected hooks and shut down. Behaviour is undefined if this is called
+// after any API functions have been called, and there is still no guarantee of
+// success.
+typedef void(RENDERDOC_CC *pRENDERDOC_RemoveHooks)();
+
+// DEPRECATED: compatibility for code compiled against pre-1.4.1 headers.
+typedef pRENDERDOC_RemoveHooks pRENDERDOC_Shutdown;
+
+// This function will unload RenderDoc's crash handler.
+//
+// If you use your own crash handler and don't want RenderDoc's handler to
+// intercede, you can call this function to unload it and any unhandled
+// exceptions will pass to the next handler.
+typedef void(RENDERDOC_CC *pRENDERDOC_UnloadCrashHandler)();
+
+// Sets the capture file path template
+//
+// pathtemplate is a UTF-8 string that gives a template for how captures will be named
+// and where they will be saved.
+//
+// Any extension is stripped off the path, and captures are saved in the directory
+// specified, and named with the filename and the frame number appended. If the
+// directory does not exist it will be created, including any parent directories.
+//
+// If pathtemplate is NULL, the template will remain unchanged
+//
+// Example:
+//
+// SetCaptureFilePathTemplate("my_captures/example");
+//
+// Capture #1 -> my_captures/example_frame123.rdc
+// Capture #2 -> my_captures/example_frame456.rdc
+typedef void(RENDERDOC_CC *pRENDERDOC_SetCaptureFilePathTemplate)(const char *pathtemplate);
+
+// returns the current capture path template, see SetCaptureFileTemplate above, as a UTF-8 string
+typedef const char *(RENDERDOC_CC *pRENDERDOC_GetCaptureFilePathTemplate)();
+
+// DEPRECATED: compatibility for code compiled against pre-1.1.2 headers.
+typedef pRENDERDOC_SetCaptureFilePathTemplate pRENDERDOC_SetLogFilePathTemplate;
+typedef pRENDERDOC_GetCaptureFilePathTemplate pRENDERDOC_GetLogFilePathTemplate;
+
+// returns the number of captures that have been made
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_GetNumCaptures)();
+
+// This function returns the details of a capture, by index. New captures are added
+// to the end of the list.
+//
+// filename will be filled with the absolute path to the capture file, as a UTF-8 string
+// pathlength will be written with the length in bytes of the filename string
+// timestamp will be written with the time of the capture, in seconds since the Unix epoch
+//
+// Any of the parameters can be NULL and they'll be skipped.
+//
+// The function will return 1 if the capture index is valid, or 0 if the index is invalid
+// If the index is invalid, the values will be unchanged
+//
+// Note: when captures are deleted in the UI they will remain in this list, so the
+// capture path may not exist anymore.
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_GetCapture)(uint32_t idx, char *filename,
+ uint32_t *pathlength, uint64_t *timestamp);
+
+// Sets the comments associated with a capture file. These comments are displayed in the
+// UI program when opening.
+//
+// filePath should be a path to the capture file to add comments to. If set to NULL or ""
+// the most recent capture file created made will be used instead.
+// comments should be a NULL-terminated UTF-8 string to add as comments.
+//
+// Any existing comments will be overwritten.
+typedef void(RENDERDOC_CC *pRENDERDOC_SetCaptureFileComments)(const char *filePath,
+ const char *comments);
+
+// returns 1 if the RenderDoc UI is connected to this application, 0 otherwise
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_IsTargetControlConnected)();
+
+// DEPRECATED: compatibility for code compiled against pre-1.1.1 headers.
+// This was renamed to IsTargetControlConnected in API 1.1.1, the old typedef is kept here for
+// backwards compatibility with old code, it is castable either way since it's ABI compatible
+// as the same function pointer type.
+typedef pRENDERDOC_IsTargetControlConnected pRENDERDOC_IsRemoteAccessConnected;
+
+// This function will launch the Replay UI associated with the RenderDoc library injected
+// into the running application.
+//
+// if connectTargetControl is 1, the Replay UI will be launched with a command line parameter
+// to connect to this application
+// cmdline is the rest of the command line, as a UTF-8 string. E.g. a captures to open
+// if cmdline is NULL, the command line will be empty.
+//
+// returns the PID of the replay UI if successful, 0 if not successful.
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_LaunchReplayUI)(uint32_t connectTargetControl,
+ const char *cmdline);
+
+// RenderDoc can return a higher version than requested if it's backwards compatible,
+// this function returns the actual version returned. If a parameter is NULL, it will be
+// ignored and the others will be filled out.
+typedef void(RENDERDOC_CC *pRENDERDOC_GetAPIVersion)(int *major, int *minor, int *patch);
+
+// Requests that the replay UI show itself (if hidden or not the current top window). This can be
+// used in conjunction with IsTargetControlConnected and LaunchReplayUI to intelligently handle
+// showing the UI after making a capture.
+//
+// This will return 1 if the request was successfully passed on, though it's not guaranteed that
+// the UI will be on top in all cases depending on OS rules. It will return 0 if there is no current
+// target control connection to make such a request, or if there was another error
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_ShowReplayUI)();
+
+//////////////////////////////////////////////////////////////////////////
+// Capturing functions
+//
+
+// A device pointer is a pointer to the API's root handle.
+//
+// This would be an ID3D11Device, HGLRC/GLXContext, ID3D12Device, etc
+typedef void *RENDERDOC_DevicePointer;
+
+// A window handle is the OS's native window handle
+//
+// This would be an HWND, GLXDrawable, etc
+typedef void *RENDERDOC_WindowHandle;
+
+// A helper macro for Vulkan, where the device handle cannot be used directly.
+//
+// Passing the VkInstance to this macro will return the RENDERDOC_DevicePointer to use.
+//
+// Specifically, the value needed is the dispatch table pointer, which sits as the first
+// pointer-sized object in the memory pointed to by the VkInstance. Thus we cast to a void** and
+// indirect once.
+#define RENDERDOC_DEVICEPOINTER_FROM_VKINSTANCE(inst) (*((void **)(inst)))
+
+// This sets the RenderDoc in-app overlay in the API/window pair as 'active' and it will
+// respond to keypresses. Neither parameter can be NULL
+typedef void(RENDERDOC_CC *pRENDERDOC_SetActiveWindow)(RENDERDOC_DevicePointer device,
+ RENDERDOC_WindowHandle wndHandle);
+
+// capture the next frame on whichever window and API is currently considered active
+typedef void(RENDERDOC_CC *pRENDERDOC_TriggerCapture)();
+
+// capture the next N frames on whichever window and API is currently considered active
+typedef void(RENDERDOC_CC *pRENDERDOC_TriggerMultiFrameCapture)(uint32_t numFrames);
+
+// When choosing either a device pointer or a window handle to capture, you can pass NULL.
+// Passing NULL specifies a 'wildcard' match against anything. This allows you to specify
+// any API rendering to a specific window, or a specific API instance rendering to any window,
+// or in the simplest case of one window and one API, you can just pass NULL for both.
+//
+// In either case, if there are two or more possible matching (device,window) pairs it
+// is undefined which one will be captured.
+//
+// Note: for headless rendering you can pass NULL for the window handle and either specify
+// a device pointer or leave it NULL as above.
+
+// Immediately starts capturing API calls on the specified device pointer and window handle.
+//
+// If there is no matching thing to capture (e.g. no supported API has been initialised),
+// this will do nothing.
+//
+// The results are undefined (including crashes) if two captures are started overlapping,
+// even on separate devices and/oror windows.
+typedef void(RENDERDOC_CC *pRENDERDOC_StartFrameCapture)(RENDERDOC_DevicePointer device,
+ RENDERDOC_WindowHandle wndHandle);
+
+// Returns whether or not a frame capture is currently ongoing anywhere.
+//
+// This will return 1 if a capture is ongoing, and 0 if there is no capture running
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_IsFrameCapturing)();
+
+// Ends capturing immediately.
+//
+// This will return 1 if the capture succeeded, and 0 if there was an error capturing.
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_EndFrameCapture)(RENDERDOC_DevicePointer device,
+ RENDERDOC_WindowHandle wndHandle);
+
+// Ends capturing immediately and discard any data stored without saving to disk.
+//
+// This will return 1 if the capture was discarded, and 0 if there was an error or no capture
+// was in progress
+typedef uint32_t(RENDERDOC_CC *pRENDERDOC_DiscardFrameCapture)(RENDERDOC_DevicePointer device,
+ RENDERDOC_WindowHandle wndHandle);
+
+// Only valid to be called between a call to StartFrameCapture and EndFrameCapture. Gives a custom
+// title to the capture produced which will be displayed in the UI.
+//
+// If multiple captures are ongoing, this title will be applied to the first capture to end after
+// this call. The second capture to end will have no title, unless this function is called again.
+//
+// Calling this function has no effect if no capture is currently running, and if it is called
+// multiple times only the last title will be used.
+typedef void(RENDERDOC_CC *pRENDERDOC_SetCaptureTitle)(const char *title);
+
+//////////////////////////////////////////////////////////////////////////////////////////////////
+// RenderDoc API versions
+//
+
+// RenderDoc uses semantic versioning (http://semver.org/).
+//
+// MAJOR version is incremented when incompatible API changes happen.
+// MINOR version is incremented when functionality is added in a backwards-compatible manner.
+// PATCH version is incremented when backwards-compatible bug fixes happen.
+//
+// Note that this means the API returned can be higher than the one you might have requested.
+// e.g. if you are running against a newer RenderDoc that supports 1.0.1, it will be returned
+// instead of 1.0.0. You can check this with the GetAPIVersion entry point
+typedef enum RENDERDOC_Version {
+ eRENDERDOC_API_Version_1_0_0 = 10000, // RENDERDOC_API_1_0_0 = 1 00 00
+ eRENDERDOC_API_Version_1_0_1 = 10001, // RENDERDOC_API_1_0_1 = 1 00 01
+ eRENDERDOC_API_Version_1_0_2 = 10002, // RENDERDOC_API_1_0_2 = 1 00 02
+ eRENDERDOC_API_Version_1_1_0 = 10100, // RENDERDOC_API_1_1_0 = 1 01 00
+ eRENDERDOC_API_Version_1_1_1 = 10101, // RENDERDOC_API_1_1_1 = 1 01 01
+ eRENDERDOC_API_Version_1_1_2 = 10102, // RENDERDOC_API_1_1_2 = 1 01 02
+ eRENDERDOC_API_Version_1_2_0 = 10200, // RENDERDOC_API_1_2_0 = 1 02 00
+ eRENDERDOC_API_Version_1_3_0 = 10300, // RENDERDOC_API_1_3_0 = 1 03 00
+ eRENDERDOC_API_Version_1_4_0 = 10400, // RENDERDOC_API_1_4_0 = 1 04 00
+ eRENDERDOC_API_Version_1_4_1 = 10401, // RENDERDOC_API_1_4_1 = 1 04 01
+ eRENDERDOC_API_Version_1_4_2 = 10402, // RENDERDOC_API_1_4_2 = 1 04 02
+ eRENDERDOC_API_Version_1_5_0 = 10500, // RENDERDOC_API_1_5_0 = 1 05 00
+ eRENDERDOC_API_Version_1_6_0 = 10600, // RENDERDOC_API_1_6_0 = 1 06 00
+} RENDERDOC_Version;
+
+// API version changelog:
+//
+// 1.0.0 - initial release
+// 1.0.1 - Bugfix: IsFrameCapturing() was returning false for captures that were triggered
+// by keypress or TriggerCapture, instead of Start/EndFrameCapture.
+// 1.0.2 - Refactor: Renamed eRENDERDOC_Option_DebugDeviceMode to eRENDERDOC_Option_APIValidation
+// 1.1.0 - Add feature: TriggerMultiFrameCapture(). Backwards compatible with 1.0.x since the new
+// function pointer is added to the end of the struct, the original layout is identical
+// 1.1.1 - Refactor: Renamed remote access to target control (to better disambiguate from remote
+// replay/remote server concept in replay UI)
+// 1.1.2 - Refactor: Renamed "log file" in function names to just capture, to clarify that these
+// are captures and not debug logging files. This is the first API version in the v1.0
+// branch.
+// 1.2.0 - Added feature: SetCaptureFileComments() to add comments to a capture file that will be
+// displayed in the UI program on load.
+// 1.3.0 - Added feature: New capture option eRENDERDOC_Option_AllowUnsupportedVendorExtensions
+// which allows users to opt-in to allowing unsupported vendor extensions to function.
+// Should be used at the user's own risk.
+// Refactor: Renamed eRENDERDOC_Option_VerifyMapWrites to
+// eRENDERDOC_Option_VerifyBufferAccess, which now also controls initialisation to
+// 0xdddddddd of uninitialised buffer contents.
+// 1.4.0 - Added feature: DiscardFrameCapture() to discard a frame capture in progress and stop
+// capturing without saving anything to disk.
+// 1.4.1 - Refactor: Renamed Shutdown to RemoveHooks to better clarify what is happening
+// 1.4.2 - Refactor: Renamed 'draws' to 'actions' in callstack capture option.
+// 1.5.0 - Added feature: ShowReplayUI() to request that the replay UI show itself if connected
+// 1.6.0 - Added feature: SetCaptureTitle() which can be used to set a title for a
+// capture made with StartFrameCapture() or EndFrameCapture()
+
+typedef struct RENDERDOC_API_1_6_0
+{
+ pRENDERDOC_GetAPIVersion GetAPIVersion;
+
+ pRENDERDOC_SetCaptureOptionU32 SetCaptureOptionU32;
+ pRENDERDOC_SetCaptureOptionF32 SetCaptureOptionF32;
+
+ pRENDERDOC_GetCaptureOptionU32 GetCaptureOptionU32;
+ pRENDERDOC_GetCaptureOptionF32 GetCaptureOptionF32;
+
+ pRENDERDOC_SetFocusToggleKeys SetFocusToggleKeys;
+ pRENDERDOC_SetCaptureKeys SetCaptureKeys;
+
+ pRENDERDOC_GetOverlayBits GetOverlayBits;
+ pRENDERDOC_MaskOverlayBits MaskOverlayBits;
+
+ // Shutdown was renamed to RemoveHooks in 1.4.1.
+ // These unions allow old code to continue compiling without changes
+ union
+ {
+ pRENDERDOC_Shutdown Shutdown;
+ pRENDERDOC_RemoveHooks RemoveHooks;
+ };
+ pRENDERDOC_UnloadCrashHandler UnloadCrashHandler;
+
+ // Get/SetLogFilePathTemplate was renamed to Get/SetCaptureFilePathTemplate in 1.1.2.
+ // These unions allow old code to continue compiling without changes
+ union
+ {
+ // deprecated name
+ pRENDERDOC_SetLogFilePathTemplate SetLogFilePathTemplate;
+ // current name
+ pRENDERDOC_SetCaptureFilePathTemplate SetCaptureFilePathTemplate;
+ };
+ union
+ {
+ // deprecated name
+ pRENDERDOC_GetLogFilePathTemplate GetLogFilePathTemplate;
+ // current name
+ pRENDERDOC_GetCaptureFilePathTemplate GetCaptureFilePathTemplate;
+ };
+
+ pRENDERDOC_GetNumCaptures GetNumCaptures;
+ pRENDERDOC_GetCapture GetCapture;
+
+ pRENDERDOC_TriggerCapture TriggerCapture;
+
+ // IsRemoteAccessConnected was renamed to IsTargetControlConnected in 1.1.1.
+ // This union allows old code to continue compiling without changes
+ union
+ {
+ // deprecated name
+ pRENDERDOC_IsRemoteAccessConnected IsRemoteAccessConnected;
+ // current name
+ pRENDERDOC_IsTargetControlConnected IsTargetControlConnected;
+ };
+ pRENDERDOC_LaunchReplayUI LaunchReplayUI;
+
+ pRENDERDOC_SetActiveWindow SetActiveWindow;
+
+ pRENDERDOC_StartFrameCapture StartFrameCapture;
+ pRENDERDOC_IsFrameCapturing IsFrameCapturing;
+ pRENDERDOC_EndFrameCapture EndFrameCapture;
+
+ // new function in 1.1.0
+ pRENDERDOC_TriggerMultiFrameCapture TriggerMultiFrameCapture;
+
+ // new function in 1.2.0
+ pRENDERDOC_SetCaptureFileComments SetCaptureFileComments;
+
+ // new function in 1.4.0
+ pRENDERDOC_DiscardFrameCapture DiscardFrameCapture;
+
+ // new function in 1.5.0
+ pRENDERDOC_ShowReplayUI ShowReplayUI;
+
+ // new function in 1.6.0
+ pRENDERDOC_SetCaptureTitle SetCaptureTitle;
+} RENDERDOC_API_1_6_0;
+
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_0_0;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_0_1;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_0_2;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_1_0;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_1_1;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_1_2;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_2_0;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_3_0;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_4_0;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_4_1;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_4_2;
+typedef RENDERDOC_API_1_6_0 RENDERDOC_API_1_5_0;
+
+//////////////////////////////////////////////////////////////////////////////////////////////////
+// RenderDoc API entry point
+//
+// This entry point can be obtained via GetProcAddress/dlsym if RenderDoc is available.
+//
+// The name is the same as the typedef - "RENDERDOC_GetAPI"
+//
+// This function is not thread safe, and should not be called on multiple threads at once.
+// Ideally, call this once as early as possible in your application's startup, before doing
+// any API work, since some configuration functionality etc has to be done also before
+// initialising any APIs.
+//
+// Parameters:
+// version is a single value from the RENDERDOC_Version above.
+//
+// outAPIPointers will be filled out with a pointer to the corresponding struct of function
+// pointers.
+//
+// Returns:
+// 1 - if the outAPIPointers has been filled with a pointer to the API struct requested
+// 0 - if the requested version is not supported or the arguments are invalid.
+//
+typedef int(RENDERDOC_CC *pRENDERDOC_GetAPI)(RENDERDOC_Version version, void **outAPIPointers);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
diff --git a/Whisper/D3D/createBuffer.cpp b/Whisper/D3D/createBuffer.cpp
new file mode 100644
index 0000000..3fbb13e
--- /dev/null
+++ b/Whisper/D3D/createBuffer.cpp
@@ -0,0 +1,51 @@
+#include "stdafx.h"
+#include "createBuffer.h"
+
+#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; }
+
+HRESULT DirectCompute::createBuffer( eBufferUse use, size_t totalBytes, ID3D11Buffer** ppGpuBuffer, const void* rsi, ID3D11Buffer** ppStagingBuffer )
+{
+ if( totalBytes > INT_MAX )
+ return DISP_E_OVERFLOW;
+ if( nullptr == ppGpuBuffer )
+ return E_POINTER;
+
+ CD3D11_BUFFER_DESC bufferDesc{ (UINT)totalBytes, D3D11_BIND_SHADER_RESOURCE };
+ switch( use )
+ {
+ case eBufferUse::Immutable:
+ if( nullptr == rsi )
+ return E_INVALIDARG;
+ bufferDesc.Usage = D3D11_USAGE_IMMUTABLE;
+ break;
+ case eBufferUse::ReadWrite:
+ case eBufferUse::ReadWriteDownload:
+ bufferDesc.BindFlags |= D3D11_BIND_UNORDERED_ACCESS;
+ break;
+ case eBufferUse::Dynamic:
+ bufferDesc.Usage = D3D11_USAGE_DYNAMIC;
+ bufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
+ break;
+ }
+
+ D3D11_SUBRESOURCE_DATA srd;
+ D3D11_SUBRESOURCE_DATA* pSrd = nullptr;
+ if( nullptr != rsi )
+ {
+ srd.pSysMem = rsi;
+ srd.SysMemPitch = srd.SysMemSlicePitch = 0;
+ pSrd = &srd;
+ }
+
+ CHECK( device()->CreateBuffer( &bufferDesc, pSrd, ppGpuBuffer ) );
+
+ if( nullptr != ppStagingBuffer && use == eBufferUse::ReadWriteDownload )
+ {
+ bufferDesc.Usage = D3D11_USAGE_STAGING;
+ bufferDesc.BindFlags = 0;
+ bufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_READ;
+ CHECK( device()->CreateBuffer( &bufferDesc, nullptr, ppStagingBuffer ) );
+ }
+
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/D3D/createBuffer.h b/Whisper/D3D/createBuffer.h
new file mode 100644
index 0000000..df0c2ab
--- /dev/null
+++ b/Whisper/D3D/createBuffer.h
@@ -0,0 +1,8 @@
+#pragma once
+#include "enums.h"
+#include "device.h"
+
+namespace DirectCompute
+{
+ HRESULT createBuffer( eBufferUse use, size_t totalBytes, ID3D11Buffer** ppGpuBuffer, const void* rsi, ID3D11Buffer** ppStagingBuffer );
+} \ No newline at end of file
diff --git a/Whisper/D3D/device.cpp b/Whisper/D3D/device.cpp
new file mode 100644
index 0000000..4eb5a60
--- /dev/null
+++ b/Whisper/D3D/device.cpp
@@ -0,0 +1,120 @@
+#include "stdafx.h"
+#include "device.h"
+#include <immintrin.h>
+#include <ammintrin.h>
+#pragma comment(lib, "D3D11.lib")
+#include "RenderDoc/renderDoc.h"
+
+namespace DirectCompute
+{
+ CComPtr<ID3D11Device> g_device;
+ CComPtr<ID3D11DeviceContext> g_context;
+ D3D_FEATURE_LEVEL g_featureLevel = (D3D_FEATURE_LEVEL)0;
+
+ ID3D11Device* device() { return g_device; }
+ ID3D11DeviceContext* context() { return g_context; }
+ D3D_FEATURE_LEVEL featureLevel() { return g_featureLevel; }
+
+ void terminate()
+ {
+ g_context = nullptr;
+ g_device = nullptr;
+ }
+
+ static HRESULT createDevice()
+ {
+ if( g_device )
+ return S_FALSE;
+
+ const std::array<D3D_FEATURE_LEVEL, 4> levels = { D3D_FEATURE_LEVEL_12_1 , D3D_FEATURE_LEVEL_12_0 , D3D_FEATURE_LEVEL_11_1 , D3D_FEATURE_LEVEL_11_0 };
+ UINT flags = D3D11_CREATE_DEVICE_DISABLE_GPU_TIMEOUT | D3D11_CREATE_DEVICE_SINGLETHREADED;
+ bool renderDoc = initializeRenderDoc();
+#ifdef _DEBUG
+ if( !renderDoc )
+ {
+ // Last time I checked, RenderDoc crashed with debug version of D3D11 runtime
+ // Only setting this flag unless renderdoc.dll is loaded to the current process
+ flags |= D3D11_CREATE_DEVICE_DEBUG;
+ }
+#endif
+ constexpr UINT levelsCount = (UINT)levels.size();
+ HRESULT hr = D3D11CreateDevice( nullptr, D3D_DRIVER_TYPE_HARDWARE, nullptr, flags, levels.data(), levelsCount, D3D11_SDK_VERSION, &g_device, &g_featureLevel, &g_context );
+ if( SUCCEEDED( hr ) )
+ return S_OK;
+ // D3D11_CREATE_DEVICE_DISABLE_GPU_TIMEOUT: This value is not supported until Direct3D 11.1
+ // https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ne-d3d11-d3d11_create_device_flag
+ flags = _andn_u32( D3D11_CREATE_DEVICE_DISABLE_GPU_TIMEOUT, flags );
+
+ hr = D3D11CreateDevice( nullptr, D3D_DRIVER_TYPE_HARDWARE, nullptr, flags, levels.data(), levelsCount, D3D11_SDK_VERSION, &g_device, &g_featureLevel, &g_context );
+ if( SUCCEEDED( hr ) )
+ return S_OK;
+ return hr;
+ }
+
+ sGpuInfo s_gpuInfo = {};
+ const sGpuInfo& gpuInfo = s_gpuInfo;
+
+ static HRESULT queryDeviceInfo()
+ {
+ if( nullptr == g_device )
+ return OLE_E_BLANK;
+ CComPtr<IDXGIDevice> dd;
+ CHECK( g_device.QueryInterface( &dd ) );
+
+ CComPtr<IDXGIAdapter> adapter;
+ CHECK( dd->GetAdapter( &adapter ) );
+
+ DXGI_ADAPTER_DESC desc;
+ adapter->GetDesc( &desc );
+
+ const size_t descLen = wcsnlen_s( desc.Description, 128 );
+ const wchar_t* rsi = &desc.Description[ 0 ];
+ s_gpuInfo.description.assign( rsi, rsi + descLen );
+ s_gpuInfo.vendor = (eGpuVendor)desc.VendorId;
+ s_gpuInfo.device = (uint16_t)desc.DeviceId;
+ s_gpuInfo.revision = (uint16_t)desc.Revision;
+ s_gpuInfo.subsystem = desc.SubSysId;
+ s_gpuInfo.vramDedicated = desc.DedicatedVideoMemory;
+ s_gpuInfo.ramDedicated = desc.DedicatedSystemMemory;
+ s_gpuInfo.ramShared = desc.SharedSystemMemory;
+ return S_OK;
+ }
+
+ HRESULT initialize()
+ {
+ HRESULT hr = createDevice();
+ if( hr != S_OK )
+ return hr;
+ queryDeviceInfo();
+ return S_OK;
+ }
+
+ __m128i __declspec( noinline ) bufferMemoryUsage( ID3D11Buffer* buffer )
+ {
+ if( nullptr != buffer )
+ {
+ D3D11_BUFFER_DESC desc;
+ buffer->GetDesc( &desc );
+
+ if( desc.Usage != D3D11_USAGE_STAGING )
+ return setHigh_size( desc.ByteWidth );
+ else
+ return setLow_size( desc.ByteWidth );
+ }
+ return _mm_setzero_si128();
+ }
+
+ __m128i __declspec( noinline ) resourceMemoryUsage( ID3D11ShaderResourceView* srv )
+ {
+ if( nullptr != srv )
+ {
+ CComPtr<ID3D11Resource> res;
+ srv->GetResource( &res );
+ CComPtr<ID3D11Buffer> buff;
+ if( SUCCEEDED( res.QueryInterface( &buff ) ) )
+ return bufferMemoryUsage( buff );
+ assert( false ); // We don't use textures in this project
+ }
+ return _mm_setzero_si128();
+ }
+} \ No newline at end of file
diff --git a/Whisper/D3D/device.h b/Whisper/D3D/device.h
new file mode 100644
index 0000000..dfcb766
--- /dev/null
+++ b/Whisper/D3D/device.h
@@ -0,0 +1,66 @@
+#pragma once
+#include <atlcomcli.h>
+#include <string>
+
+namespace DirectCompute
+{
+ ID3D11Device* device();
+ ID3D11DeviceContext* context();
+ D3D_FEATURE_LEVEL featureLevel();
+
+ HRESULT initialize();
+ void terminate();
+
+ // DXGI_ADAPTER_DESC.VendorId magic numbers; they come from that database: https://pcisig.com/membership/member-companies
+ enum struct eGpuVendor : uint16_t
+ {
+ AMD = 0x1002,
+ NVidia = 0x10de,
+ Intel = 0x8086,
+ VMWare = 0x15ad,
+ };
+
+ struct sGpuInfo
+ {
+ std::wstring description;
+ eGpuVendor vendor;
+ uint16_t device, revision;
+ uint32_t subsystem;
+ size_t vramDedicated, ramDedicated, ramShared;
+
+ inline bool wave64() const
+ {
+ return vendor == eGpuVendor::AMD;
+ }
+
+ // On nVidia 1080Ti that approach is much slower, by a factor of 2.4
+ // On AMD Cezanne that approach is faster by a factor of 0.69, i.e. 30% faster.
+ // Dunno why that is, maybe 'coz on that AMD complete panels fit in L3 cache.
+ // Anyway, we do want extra 30% perf on AMD Cezanne, so only using that code on AMD GPUs.
+ // Dunno how it gonna behave on other GPUs, need to test.
+#if RESHAPED_MATRIX_MULTIPLY
+ inline bool useReshapedMatMul() const
+ {
+ // return true;
+ return vendor == eGpuVendor::AMD;
+ }
+#else
+ constexpr bool useReshapedMatMul() const { return false; }
+#endif
+ };
+ extern const sGpuInfo& gpuInfo;
+
+ inline bool available()
+ {
+ return nullptr != device();
+ }
+
+ inline void csSetCB( ID3D11Buffer* cb )
+ {
+ context()->CSSetConstantBuffers( 0, 1, &cb );
+ }
+
+ __m128i bufferMemoryUsage( ID3D11Buffer* buffer );
+
+ __m128i resourceMemoryUsage( ID3D11ShaderResourceView* srv );
+} \ No newline at end of file
diff --git a/Whisper/D3D/downloadBuffer.cpp b/Whisper/D3D/downloadBuffer.cpp
new file mode 100644
index 0000000..9790ab7
--- /dev/null
+++ b/Whisper/D3D/downloadBuffer.cpp
@@ -0,0 +1,72 @@
+#include "stdafx.h"
+#include "downloadBuffer.h"
+#include "device.h"
+#include "MappedResource.h"
+
+namespace
+{
+ struct BufferInfo
+ {
+ D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc;
+ D3D11_BUFFER_DESC bufferDesc;
+ CComPtr<ID3D11Buffer> source;
+
+ HRESULT create( ID3D11ShaderResourceView* srv )
+ {
+ srv->GetDesc( &viewDesc );
+ if( viewDesc.ViewDimension != D3D_SRV_DIMENSION_BUFFER )
+ return E_INVALIDARG;
+
+ CComPtr<ID3D11Resource> res;
+ srv->GetResource( &res );
+ CHECK( res.QueryInterface( &source ) );
+
+ source->GetDesc( &bufferDesc );
+ return S_OK;
+ }
+
+ HRESULT download( void* rdi )
+ {
+ bufferDesc.BindFlags = 0;
+ bufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_READ;
+ bufferDesc.Usage = D3D11_USAGE_STAGING;
+ CComPtr<ID3D11Buffer> staging;
+ using namespace DirectCompute;
+ CHECK( device()->CreateBuffer( &bufferDesc, nullptr, &staging ) );
+
+ context()->CopyResource( staging, source );
+
+ MappedResource mapped;
+ mapped.map( staging, true );
+ memcpy( rdi, mapped.data(), bufferDesc.ByteWidth );
+ return S_OK;
+ }
+ };
+
+ size_t dxgiSizeof( DXGI_FORMAT fmt )
+ {
+ switch( fmt )
+ {
+ case DXGI_FORMAT_R16_FLOAT: return 2;
+ case DXGI_FORMAT_R32_FLOAT: return 4;
+ }
+ return 0;
+ }
+}
+
+template<class E>
+HRESULT DirectCompute::downloadBuffer( ID3D11ShaderResourceView* srv, std::vector<E>& vec )
+{
+ BufferInfo bi;
+ CHECK( bi.create( srv ) );
+
+ const size_t cb = dxgiSizeof( bi.viewDesc.Format );
+ if( cb != sizeof( E ) )
+ return E_INVALIDARG;
+
+ vec.resize( bi.bufferDesc.ByteWidth / cb );
+ return bi.download( vec.data() );
+}
+
+template HRESULT DirectCompute::downloadBuffer( ID3D11ShaderResourceView* srv, std::vector<uint16_t>& vec );
+template HRESULT DirectCompute::downloadBuffer( ID3D11ShaderResourceView* srv, std::vector<float>& vec ); \ No newline at end of file
diff --git a/Whisper/D3D/downloadBuffer.h b/Whisper/D3D/downloadBuffer.h
new file mode 100644
index 0000000..0885007
--- /dev/null
+++ b/Whisper/D3D/downloadBuffer.h
@@ -0,0 +1,9 @@
+#pragma once
+
+namespace DirectCompute
+{
+ // Download a buffer from VRAM into std::vector
+ // The function is relatively expensive, creates a temporary staging buffer on each call, and only used to test things.
+ template<class E>
+ HRESULT downloadBuffer( ID3D11ShaderResourceView* srv, std::vector<E>& vec );
+} \ No newline at end of file
diff --git a/Whisper/D3D/enums.cpp b/Whisper/D3D/enums.cpp
new file mode 100644
index 0000000..7c31648
--- /dev/null
+++ b/Whisper/D3D/enums.cpp
@@ -0,0 +1,9 @@
+#include "stdafx.h"
+#include "enums.h"
+
+static const alignas( 16 ) std::array<DXGI_FORMAT, 3> s_tensorViewFormats = { DXGI_FORMAT_R16_FLOAT, DXGI_FORMAT_R32_FLOAT, DXGI_FORMAT_R32_UINT };
+
+DXGI_FORMAT DirectCompute::viewFormat( eDataType dt )
+{
+ return s_tensorViewFormats[ (uint8_t)dt ];
+} \ No newline at end of file
diff --git a/Whisper/D3D/enums.h b/Whisper/D3D/enums.h
new file mode 100644
index 0000000..c5d5350
--- /dev/null
+++ b/Whisper/D3D/enums.h
@@ -0,0 +1,34 @@
+#pragma once
+#include <stdint.h>
+#include <assert.h>
+
+namespace DirectCompute
+{
+ enum struct eDataType : uint8_t
+ {
+ FP16,
+ FP32,
+ U32,
+ };
+
+ inline size_t elementSize( eDataType dt )
+ {
+ assert( dt == eDataType::FP16 || dt == eDataType::FP32 || dt == eDataType::U32 );
+
+ return ( dt == eDataType::FP16 ) ? 2 : 4;
+ }
+
+ DXGI_FORMAT viewFormat( eDataType dt );
+
+ enum struct eBufferUse : uint8_t
+ {
+ // Immutable tensor, readable from GPU
+ Immutable,
+ // Read+write tensor, readable and writable on GPU
+ ReadWrite,
+ // Read+write tensor, readable and writable on GPU, which supports downloads from GPU
+ ReadWriteDownload,
+ // The tensor is accessible by both GPU (read only) and CPU (write only). Optimized for resources frequently updated from CPU.
+ Dynamic,
+ };
+} \ No newline at end of file
diff --git a/Whisper/D3D/shaderNames.cpp b/Whisper/D3D/shaderNames.cpp
new file mode 100644
index 0000000..b52f5db
--- /dev/null
+++ b/Whisper/D3D/shaderNames.cpp
@@ -0,0 +1,53 @@
+// This source file is generated by a tool
+#include "stdafx.h"
+#include "shaderNames.h"
+
+static const std::array<const char*, 38> s_shaderNames =
+{
+ "add",
+ "addInPlace",
+ "addRepeat",
+ "addRepeatGelu",
+ "addRepeatScale",
+ "addRows",
+ "convolutionMain",
+ "convolutionMain2",
+ "convolutionMain2Fixed",
+ "convolutionPrep1",
+ "convolutionPrep2",
+ "copyConvert",
+ "copyTranspose",
+ "diagMaskInf",
+ "flashAttention",
+ "flashAttentionCompat1",
+ "flashAttentionCompat2",
+ "flashAttentionCompat3",
+ "fmaRepeat1",
+ "fmaRepeat2",
+ "matReshapePanels",
+ "mulMatByRow",
+ "mulMatByRowTiled",
+ "mulMatByRowTiledEx",
+ "mulMatByScalar",
+ "mulMatDotMain",
+ "mulMatDotReshape",
+ "mulMatMadMain",
+ "mulMatTiled",
+ "mulMatTiledEx",
+ "norm",
+ "normCompat",
+ "normFixed",
+ "scaleInPlace",
+ "softMax",
+ "softMaxCompat",
+ "softMaxFixed",
+ "zeroMemory",
+};
+
+const char* DirectCompute::computeShaderName( eComputeShader cs )
+{
+ const uint16_t i = (uint16_t)cs;
+ if( i < s_shaderNames.size() )
+ return s_shaderNames[ i ];
+ return nullptr;
+} \ No newline at end of file
diff --git a/Whisper/D3D/shaderNames.h b/Whisper/D3D/shaderNames.h
new file mode 100644
index 0000000..ccfab86
--- /dev/null
+++ b/Whisper/D3D/shaderNames.h
@@ -0,0 +1,50 @@
+// This header is generated by a tool
+#pragma once
+#include <stdint.h>
+
+namespace DirectCompute
+{
+ enum struct eComputeShader: uint16_t
+ {
+ add = 0,
+ addInPlace = 1,
+ addRepeat = 2,
+ addRepeatGelu = 3,
+ addRepeatScale = 4,
+ addRows = 5,
+ convolutionMain = 6,
+ convolutionMain2 = 7,
+ convolutionMain2Fixed = 8,
+ convolutionPrep1 = 9,
+ convolutionPrep2 = 10,
+ copyConvert = 11,
+ copyTranspose = 12,
+ diagMaskInf = 13,
+ flashAttention = 14,
+ flashAttentionCompat1 = 15,
+ flashAttentionCompat2 = 16,
+ flashAttentionCompat3 = 17,
+ fmaRepeat1 = 18,
+ fmaRepeat2 = 19,
+ matReshapePanels = 20,
+ mulMatByRow = 21,
+ mulMatByRowTiled = 22,
+ mulMatByRowTiledEx = 23,
+ mulMatByScalar = 24,
+ mulMatDotMain = 25,
+ mulMatDotReshape = 26,
+ mulMatMadMain = 27,
+ mulMatTiled = 28,
+ mulMatTiledEx = 29,
+ norm = 30,
+ normCompat = 31,
+ normFixed = 32,
+ scaleInPlace = 33,
+ softMax = 34,
+ softMaxCompat = 35,
+ softMaxFixed = 36,
+ zeroMemory = 37,
+ };
+
+ const char* computeShaderName( eComputeShader cs );
+} \ No newline at end of file
diff --git a/Whisper/D3D/shaders.cpp b/Whisper/D3D/shaders.cpp
new file mode 100644
index 0000000..f7d9ce4
--- /dev/null
+++ b/Whisper/D3D/shaders.cpp
@@ -0,0 +1,104 @@
+#include "stdafx.h"
+#include "shaders.h"
+#include "startup.h"
+#include "device.h"
+#include <compressapi.h>
+#pragma comment( lib, "Cabinet.lib" )
+
+namespace
+{
+#ifdef _DEBUG
+#include "shaderData-Debug.inl"
+#else
+#include "shaderData-Release.inl"
+#endif
+
+ constexpr DWORD compressionAlgorithm = COMPRESS_ALGORITHM_MSZIP;
+
+ class Decompressor
+ {
+ DECOMPRESSOR_HANDLE handle = nullptr;
+
+ public:
+
+ HRESULT create()
+ {
+ if( CreateDecompressor( compressionAlgorithm, nullptr, &handle ) )
+ return S_OK;
+ return HRESULT_FROM_WIN32( GetLastError() );
+ }
+
+ HRESULT decompress( const uint8_t* src, size_t compressedLength, void* dest, size_t origLength ) const
+ {
+ if( Decompress( handle, src, compressedLength, dest, origLength, nullptr ) )
+ return S_OK;
+ return HRESULT_FROM_WIN32( GetLastError() );
+ }
+
+ ~Decompressor()
+ {
+ if( nullptr != handle )
+ {
+ CloseDecompressor( handle );
+ handle = nullptr;
+ }
+ }
+ };
+
+ static std::vector<CComPtr<ID3D11ComputeShader>> s_shaders;
+}
+
+HRESULT DirectCompute::createComputeShaders()
+{
+ constexpr size_t countBinaries = s_shaderOffsets.size() - 1;
+ const size_t cbDecompressedLength = s_shaderOffsets[ countBinaries ];
+ constexpr size_t countShaders = s_shaderBlobs32.size();
+
+ std::vector<uint8_t> dxbc;
+ try
+ {
+ s_shaders.resize( countShaders );
+ dxbc.resize( cbDecompressedLength );
+ }
+ catch( const std::bad_alloc& )
+ {
+ return E_OUTOFMEMORY;
+ }
+
+ Decompressor decomp;
+ CHECK( decomp.create() );
+
+ decomp.decompress( s_compressedShaders.data(), s_compressedShaders.size(), dxbc.data(), cbDecompressedLength );
+ ID3D11Device* const dev = device();
+
+ const auto& blobs = gpuInfo.wave64() ? s_shaderBlobs64 : s_shaderBlobs32;
+
+ for( size_t i = 0; i < countShaders; i++ )
+ {
+ const size_t idxBinary = blobs[ i ];
+ const uint32_t offThis = s_shaderOffsets[ idxBinary ];
+ const uint8_t* rsi = &dxbc[ offThis ];
+ const size_t len = s_shaderOffsets[ idxBinary + 1 ] - offThis;
+ const HRESULT hr = dev->CreateComputeShader( rsi, len, nullptr, &s_shaders[ i ] );
+ if( SUCCEEDED( hr ) )
+ continue;
+
+ const uint64_t binaryBit = ( 1ull << idxBinary );
+ if( 0 != ( binaryBit & fp64ShadersBitmap ) )
+ continue; // This shader uses FP64 math, the support for that is optional. When not supported, CreateComputeShader method is expected to fail.
+ // TODO [low]: ideally, query for the support when creating the device, and don't even try creating these compute shaders
+ return hr;
+ }
+
+ return S_OK;
+}
+
+void DirectCompute::destroyComputeShaders()
+{
+ s_shaders.clear();
+}
+
+void DirectCompute::bindShader( eComputeShader shader )
+{
+ context()->CSSetShader( s_shaders[ (uint16_t)shader ], nullptr, 0 );
+} \ No newline at end of file
diff --git a/Whisper/D3D/shaders.h b/Whisper/D3D/shaders.h
new file mode 100644
index 0000000..1188988
--- /dev/null
+++ b/Whisper/D3D/shaders.h
@@ -0,0 +1,7 @@
+#pragma once
+#include "shaderNames.h"
+
+namespace DirectCompute
+{
+ void bindShader( eComputeShader shader );
+} \ No newline at end of file
diff --git a/Whisper/D3D/startup.cpp b/Whisper/D3D/startup.cpp
new file mode 100644
index 0000000..2ff0b0c
--- /dev/null
+++ b/Whisper/D3D/startup.cpp
@@ -0,0 +1,17 @@
+#include "stdafx.h"
+#include "startup.h"
+#include "device.h"
+
+HRESULT DirectCompute::d3dStartup()
+{
+ HRESULT hr = DirectCompute::initialize();
+ if( SUCCEEDED( hr ) )
+ hr = createComputeShaders();
+ return hr;
+}
+
+void DirectCompute::d3dShutdown()
+{
+ destroyComputeShaders();
+ terminate();
+} \ No newline at end of file
diff --git a/Whisper/D3D/startup.h b/Whisper/D3D/startup.h
new file mode 100644
index 0000000..de42fae
--- /dev/null
+++ b/Whisper/D3D/startup.h
@@ -0,0 +1,11 @@
+#pragma once
+using HRESULT = long;
+
+namespace DirectCompute
+{
+ HRESULT d3dStartup();
+ void d3dShutdown();
+
+ HRESULT createComputeShaders();
+ void destroyComputeShaders();
+} \ No newline at end of file
diff --git a/Whisper/DllMain.cpp b/Whisper/DllMain.cpp
new file mode 100644
index 0000000..4746d98
--- /dev/null
+++ b/Whisper/DllMain.cpp
@@ -0,0 +1,27 @@
+#include "stdafx.h"
+
+BOOL __stdcall DllMain( HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved )
+{
+ // Perform actions based on the reason for calling.
+ switch( fdwReason )
+ {
+ case DLL_PROCESS_ATTACH:
+ // Initialize once for each new process. Return FALSE to fail DLL load.
+ DisableThreadLibraryCalls( (HMODULE)hinstDLL );
+ break;
+ case DLL_THREAD_ATTACH:
+ // Do thread-specific initialization.
+ break;
+ case DLL_THREAD_DETACH:
+ // Do thread-specific cleanup.
+ break;
+ case DLL_PROCESS_DETACH:
+ if( lpvReserved != nullptr )
+ {
+ break; // do not do cleanup if process termination scenario
+ }
+ // Perform any necessary cleanup
+ break;
+ }
+ return TRUE; // Successful DLL_PROCESS_ATTACH.
+} \ No newline at end of file
diff --git a/Whisper/Hybrid/HybridContext.cpp b/Whisper/Hybrid/HybridContext.cpp
new file mode 100644
index 0000000..64ed2e6
--- /dev/null
+++ b/Whisper/Hybrid/HybridContext.cpp
@@ -0,0 +1,349 @@
+#include "stdafx.h"
+#include <immintrin.h>
+#include <optional>
+#include "HybridContext.h"
+#include "../Utils/Trace/tracing.h"
+
+#if BUILD_HYBRID_VERSION
+namespace
+{
+ int threadsCount( int t )
+ {
+#ifdef NDEBUG
+ if( t == 0 )
+ {
+ SYSTEM_INFO si;
+ GetSystemInfo( &si );
+ return (int)si.dwNumberOfProcessors;
+ }
+ if( t <= 1 )
+ return 1;
+ return t;
+#else
+ return 1;
+#endif
+ }
+
+ constexpr size_t MB = 1u << 20;
+}
+
+HybridContext::HybridContext( const Whisper::WhisperModel& wm ) :
+ ml( threadsCount( 0 ) ),
+ model( wm.hybridTensors ),
+ whisperModel( wm )
+{ }
+
+namespace
+{
+ enum struct eModelType : uint8_t
+ {
+ Tiny = 0,
+ Base = 1,
+ Small = 2,
+ Medium = 3,
+ Large = 4,
+ };
+
+ static HRESULT detectModelType( const Whisper::sModelParams& modelParams, eModelType& mt )
+ {
+ switch( modelParams.n_audio_layer )
+ {
+ case 4:
+ mt = eModelType::Tiny;
+ return S_OK;
+ case 6:
+ mt = eModelType::Base;
+ return S_OK;
+ case 12:
+ mt = eModelType::Small;
+ return S_OK;
+ case 24:
+ mt = eModelType::Medium;
+ return S_OK;
+ case 32:
+ mt = eModelType::Large;
+ return S_OK;
+ }
+ logError( u8"Unrecognized model" );
+ return E_INVALIDARG;
+ }
+
+ struct alignas( 2 ) RamMB
+ {
+ uint8_t dec, decLayer;
+ constexpr RamMB( uint8_t d, uint8_t dl ) : dec( d ), decLayer( dl ) { }
+
+ __m128i loadBytes() const
+ {
+ __m128i v = _mm_loadu_si16( this );
+ // Upcast bytes to int64_t. That instruction can load directly from memory, too bad VC++ optimized doesn't care
+ v = _mm_cvtepu8_epi64( v );
+ // Scale from megabytes into bytes, the multiplier is obviously 2^20
+ v = _mm_slli_epi64( v, 20 );
+ return v;
+ }
+ };
+
+ // The magic numbers are from MEM_REQ_DECODE and MEM_REQ_DECODE_LAYER red/black maps in the reference version,
+ // near the top of whisper.cpp source file
+ static const std::array<RamMB, 5> s_memRequirements =
+ {
+ RamMB{ 200, 32 }, // Tiny
+ RamMB{ 202, 44 }, // Base
+ RamMB{ 204, 64 }, // Small
+ RamMB{ 206, 84 }, // Medium
+ RamMB{ 208, 110 }, // Large
+ };
+}
+
+HRESULT HybridContext::create()
+{
+ // Allocate buffers for compute
+ // We know they're large, so bypassing the heap
+ eModelType modelType;
+ CHECK( detectModelType( whisperModel.parameters, modelType ) );
+
+ const __m128i bytes = s_memRequirements.at( (uint8_t)modelType ).loadBytes();
+ CHECK( allocCompute.create( _mm_cvtsi128_si64( bytes ) ) );
+ CHECK( allocComputeLayer.create( _mm_extract_epi64( bytes, 1 ) ) );
+
+ // Create staging buffers to download output from encoder stage,
+ // in the reference version they're named memory_cross_k / memory_cross_v
+ CHECK( kvCross.create( whisperModel.parameters ) );
+
+ // Create RAM buffers for memory_k / memory_v
+ CHECK( kv.create( whisperModel.parameters ) );
+
+ return S_OK;
+}
+
+class HybridContext::SetAllocatorRaii
+{
+ HybridContext& context;
+ CpuCompute::iMemoryAllocator* prevAlloc;
+ CpuCompute::iArenaAllocator* newAlloc;
+public:
+
+ SetAllocatorRaii( HybridContext* owner, CpuCompute::iArenaAllocator& a ) :
+ context( *owner )
+ {
+ prevAlloc = context.ml.setAllocator( &a );
+ newAlloc = &a;
+ }
+ ~SetAllocatorRaii()
+ {
+ context.ml.setAllocator( prevAlloc );
+ newAlloc->resetArena();
+ }
+};
+
+HRESULT HybridContext::decode( const int* tokens, const int n_tokens, const int n_past, const sDecParams& dp, std::vector<float>& probs )
+{
+ CHECK( ml.setThreadsCount( dp.n_threads ) );
+
+ // whisper_decode
+ const auto& hparams = whisperModel.parameters;
+ const uint32_t n_vocab = hparams.n_vocab;
+
+ const uint32_t n_ctx = hparams.n_text_ctx;
+ const uint32_t n_state = hparams.n_text_state;
+ const uint32_t n_head = hparams.n_text_head;
+ const uint32_t n_layer = hparams.n_text_layer;
+
+ const uint32_t N = n_tokens;
+ const uint32_t M = dp.M;
+
+ SetAllocatorRaii ac{ this, allocCompute };
+ using namespace CpuCompute;
+ Tensor cur = ml.addRows( model.tokenEmbedding, model.positionalEmbedding, tokens, n_tokens, n_past );
+ Tracing::tensor( "dec-rows", cur );
+
+ Tensor inpL = cur;
+ auto kvCross = this->kvCross.map();
+
+ for( uint32_t il = 0; il < n_layer; il++ )
+ {
+ if( 0 == il ) Tracing::tensor( "dec-inpL", inpL );
+ const auto& layer = model.layers[ il ];
+ SetAllocatorRaii acLayer{ this, allocComputeLayer };
+
+ // norm
+ Tensor cur = ml.norm( inpL );
+ ml.fmaRepeat( cur, layer.attnLn0 );
+ if( 0 == il ) Tracing::tensor( "dec-norm", cur );
+
+ // self-attention
+ {
+ Tensor Qcur = ml.mulMat( layer.attnQuery.w, cur );
+ if( 0 == il ) Tracing::tensor( "dec-Qcur-0", Qcur );
+ const float scaling = (float)pow( float( (int)n_state ) / (int)n_head, -0.25 );
+ ml.addRepeatScale( Qcur, layer.attnQuery.b, scaling );
+ if( 0 == il ) Tracing::tensor( "dec-Qcur-1", Qcur );
+
+ // note: no bias for Key
+ Tensor Kcur = ml.mulMat( layer.attnKey, cur );
+ ml.scale( Kcur, scaling );
+ if( 0 == il ) Tracing::tensor( "dec-Kcur", Kcur );
+
+ Tensor Vcur = ml.mulMat( layer.attnValue.w, cur );
+ ml.addRepeat( Vcur, layer.attnValue.b );
+ if( 0 == il ) Tracing::tensor( "dec-Vcur", Vcur );
+
+ // store key and value to memory
+ {
+ const uint32_t len = N * n_state;
+ const uint32_t off = n_state * ( (uint32_t)il * n_ctx + n_past );
+ Tensor k = kv.keysView( len, off );
+ Tensor v = kv.valuesView( len, off );
+
+ CHECK( ml.copyImpl( k, Kcur ) );
+ CHECK( ml.copyImpl( v, Vcur ) );
+ }
+
+ // ------
+ Tensor Q = ml.permute( ml.copy( Qcur, eDataType::FP32, { n_state / n_head, n_head, N } ), 0, 2, 1, 3 );
+ Tensor K = ml.permute( kv.keysView( ( n_past + N ) * n_state, (uint32_t)il * n_ctx * n_state )
+ .reshape3d( n_state / n_head, n_head, n_past + N ),
+ 0, 2, 1, 3 );
+ Tensor KQ = ml.mulMat( K, Q );
+ if( 0 == il ) Tracing::tensor( "dec-KQ-0", KQ );
+ ml.diagMaskInf( KQ, n_past );
+ if( 0 == il ) Tracing::tensor( "dec-KQ-1", KQ );
+ ml.softMax( KQ );
+ if( 0 == il ) Tracing::tensor( "dec-KQ-2", KQ );
+
+ Tensor V_trans = ml.permute(
+ kv.valuesView( ( n_past + N ) * n_state, (uint32_t)il * n_ctx * n_state )
+ .reshape3d( n_state / n_head, n_head, n_past + N ),
+ 1, 2, 0, 3 );
+
+ Tensor KQV = ml.mulMat( V_trans, KQ );
+ if( 0 == il ) Tracing::tensor( "dec-KQV", KQV );
+
+ Tensor KQV_merged = ml.permute( KQV, 0, 2, 1, 3 );
+ ml.copyInPlace( cur, KQV_merged, eDataType::FP32, { n_state, N } );
+ }
+
+ {
+ cur = ml.mulMat( layer.attnLn1.w, cur );
+ ml.addRepeat( cur, layer.attnLn1.b );
+ }
+
+ // add the input
+ Tensor inpCA = ml.add( cur, inpL );
+
+ // norm
+ {
+ cur = ml.norm( inpCA );
+ ml.fmaRepeat( cur, layer.crossAttnLn0 );
+ }
+
+ // cross-attention
+ {
+ Tensor Qcur = ml.mulMat( layer.crossAttnQuery.w, cur );
+ ml.addRepeatScale( Qcur, layer.crossAttnQuery.b, (float)pow( float( (int)n_state ) / (int)n_head, -0.25 ) );
+
+ // Kcross is already scaled
+ const uint32_t len = M * n_state;
+ const uint32_t off = (uint32_t)il * len;
+ Tensor Kcross = kvCross.keysView( len, off ).reshape3d( n_state / n_head, n_head, M );
+ Tensor Vcross = kvCross.valuesView( len, off ).reshape3d( n_state / n_head, n_head, M );
+
+ // ------
+ Tensor Q = ml.permute( ml.copy( Qcur, eDataType::FP32, { n_state / n_head, n_head, N } ), 0, 2, 1, 3 );
+ Tensor K = ml.permute( Kcross, 0, 2, 1, 3 );
+ Tensor KQ = ml.mulMat( K, Q );
+ ml.softMax( KQ );
+ Tensor V_trans = ml.permute( Vcross, 1, 2, 0, 3 );
+ Tensor KQV = ml.mulMat( V_trans, KQ );
+ if( 0 == il ) Tracing::tensor( "dec-KQV", KQV );
+ Tensor KQV_merged = ml.permute( KQV, 0, 2, 1, 3 );
+
+ ml.copyInPlace( cur, KQV_merged, eDataType::FP32, { n_state, N } );
+ }
+
+ // projection
+ {
+ cur = ml.mulMat( layer.crossAttnLn1.w, cur );
+ ml.addRepeat( cur, layer.crossAttnLn1.b );
+ }
+ // add the input
+ ml.addInPlace( cur, inpCA );
+ Tensor inpFF = cur;
+
+ // feed-forward network
+ {
+ // norm
+ cur = ml.norm( inpFF );
+ ml.fmaRepeat( cur, layer.mlpLn );
+
+ cur = ml.mulMat( layer.mlp0.w, cur );
+ ml.addRepeatGelu( cur, layer.mlp0.b );
+
+ // The mulMat() below creates a tensor for the output of this layer.
+ // We have a special memory storage for these tensors, that's how they survive resets of per-layer arenas
+ allocLayerOutput.resetArena();
+ ml.setAllocator( &allocLayerOutput );
+
+ // projection
+ cur = ml.mulMat( layer.mlp1.w, cur );
+ ml.addRepeat( cur, layer.mlp1.b );
+ }
+
+ // output from this layer
+ ml.addInPlace( cur, inpFF );
+ inpL = cur;
+ }
+
+ // norm
+ cur = ml.norm( inpL );
+ ml.fmaRepeat( cur, model.ln );
+
+ cur = ml.mulMat( model.tokenEmbedding, cur );
+
+ // logits -> probs
+ ml.softMax( cur );
+
+ const float* rsi = cur.fp32();
+ probs.assign( rsi, rsi + cur.countElements() );
+ Tracing::vector( "probs", probs );
+ return S_OK;
+}
+
+void* HybridContext::AllocSingle::allocate( size_t cb, size_t align )
+{
+ if( !allocated )
+ {
+ allocated = true;
+ if( cb <= capacity )
+ {
+ CpuCompute::dbgMarkUninitializedMemory( buffer.pointer(), capacity );
+ return buffer.pointer();
+ }
+ else
+ {
+ HRESULT hr = buffer.allocate( cb );
+ if( SUCCEEDED( hr ) )
+ {
+ capacity = cb;
+ CpuCompute::dbgMarkUninitializedMemory( buffer.pointer(), capacity );
+ return buffer.pointer();
+ }
+ logErrorHr( hr, u8"HybridContext.AllocSingle.allocate" );
+ throw hr;
+ }
+ }
+ else
+ {
+ logError( u8"HybridContext.AllocSingle only supports 1 tensor" );
+ throw E_UNEXPECTED;
+ }
+}
+
+void HybridContext::AllocSingle::resetArena()
+{
+ allocated = false;
+ if( capacity > 0 )
+ CpuCompute::dbgMarkFreedMemory( buffer.pointer(), capacity );
+}
+#endif \ No newline at end of file
diff --git a/Whisper/Hybrid/HybridContext.h b/Whisper/Hybrid/HybridContext.h
new file mode 100644
index 0000000..52039eb
--- /dev/null
+++ b/Whisper/Hybrid/HybridContext.h
@@ -0,0 +1,52 @@
+#pragma once
+#include "../Whisper/WhisperModel.h"
+#include "../CPU/MlContext.h"
+#include "../CPU/BufferAllocator.h"
+#include "KeyValueDownloader.h"
+#include "../CPU/KvTensors.h"
+
+// This version of the hybrid context uses the new, custom-built kernels
+class HybridContext
+{
+ CpuCompute::MlContext ml;
+ CpuCompute::VirtualAllocator allocCompute, allocComputeLayer;
+
+ class AllocSingle : public CpuCompute::iArenaAllocator
+ {
+ CpuCompute::LargeBuffer buffer;
+ size_t capacity = 0;
+ bool allocated = false;
+ // Inherited via iArenaAllocator
+ virtual void* allocate( size_t cb, size_t align ) override final;
+
+ public:
+ virtual void resetArena() override final;
+ };
+ AllocSingle allocLayerOutput;
+
+ const CpuCompute::DecoderTensors& model;
+ const Whisper::WhisperModel& whisperModel;
+ KeyValueDownloader kvCross;
+ CpuCompute::KvTensors kv;
+
+ class SetAllocatorRaii;
+
+public:
+
+ HybridContext( const Whisper::WhisperModel& wm );
+
+ HRESULT create();
+
+ HRESULT downloadKeyValues( const DirectCompute::KeyValueBuffers& source )
+ {
+ return kvCross.download( source );
+ }
+
+ struct sDecParams
+ {
+ int n_threads;
+ int M;
+ };
+
+ HRESULT decode( const int* tokens, const int n_tokens, const int n_past, const sDecParams& dp, std::vector<float>& probs_out );
+}; \ No newline at end of file
diff --git a/Whisper/Hybrid/KeyValueDownloader.cpp b/Whisper/Hybrid/KeyValueDownloader.cpp
new file mode 100644
index 0000000..ad50136
--- /dev/null
+++ b/Whisper/Hybrid/KeyValueDownloader.cpp
@@ -0,0 +1,32 @@
+#include "stdafx.h"
+#include "KeyValueDownloader.h"
+
+HRESULT KeyValueDownloader::create( const Whisper::sModelParams& mp )
+{
+ const uint32_t n_audio_ctx = mp.n_audio_ctx;
+ const uint32_t n_mem = mp.n_text_layer * mp.n_audio_ctx;
+ const uint32_t n_elements = mp.n_text_state * n_mem;
+
+ CD3D11_BUFFER_DESC desc{ n_elements * 2, 0, D3D11_USAGE_STAGING, D3D11_CPU_ACCESS_READ };
+ ID3D11Device* dev = DirectCompute::device();
+ CHECK( dev->CreateBuffer( &desc, nullptr, &keys ) );
+ CHECK( dev->CreateBuffer( &desc, nullptr, &values ) );
+
+ length = n_elements;
+ return S_OK;
+}
+
+HRESULT KeyValueDownloader::download( const DirectCompute::KeyValueBuffers& source )
+{
+ ID3D11DeviceContext* ctx = DirectCompute::context();
+ ctx->CopyResource( keys, source.keys.getBuffer() );
+ ctx->CopyResource( values, source.values.getBuffer() );
+ return S_OK;
+}
+
+KeyValueDownloader::ReadMap::ReadMap( KeyValueDownloader& owner ) :
+ length( owner.length )
+{
+ check( mappedKeys.map( owner.keys, true ) );
+ check( mappedValues.map( owner.values, true ) );
+} \ No newline at end of file
diff --git a/Whisper/Hybrid/KeyValueDownloader.h b/Whisper/Hybrid/KeyValueDownloader.h
new file mode 100644
index 0000000..e0e9644
--- /dev/null
+++ b/Whisper/Hybrid/KeyValueDownloader.h
@@ -0,0 +1,63 @@
+#pragma once
+#include "../Whisper/sModelParams.h"
+#include "../Whisper/KeyValueBuffers.h"
+#include "../D3D/MappedResource.h"
+#include "../CPU/Tensor.h"
+
+class KeyValueDownloader
+{
+ CComPtr<ID3D11Buffer> keys, values;
+ uint32_t length = 0;
+
+ using E = uint16_t;
+ static constexpr DirectCompute::eDataType dataType = DirectCompute::eDataType::FP16;
+
+public:
+ // Create the staging resources to download kvCross tensors produced by the GPGPU encoder
+ HRESULT create( const Whisper::sModelParams& mp );
+
+ // Download these two tensors from VRAM to the staging buffers in system RAM
+ HRESULT download( const DirectCompute::KeyValueBuffers& source );
+
+ class ReadMap
+ {
+ const uint32_t length;
+ DirectCompute::MappedResource mappedKeys, mappedValues;
+
+ public:
+ ReadMap( KeyValueDownloader& owner );
+ ~ReadMap() = default;
+ ReadMap( const ReadMap& ) = delete;
+
+ // A slice of model.memory_k tensor
+ CpuCompute::Tensor keysView( uint32_t len, uint32_t off ) const
+ {
+ if( len + off <= length )
+ {
+ E* rsi = (E*)mappedKeys.data();
+ rsi += off;
+ return CpuCompute::Tensor::fromData( rsi, dataType, len );
+ }
+ throw E_BOUNDS;
+ }
+
+ // A slice of model.memory_v tensor
+ CpuCompute::Tensor valuesView( uint32_t len, uint32_t off ) const
+ {
+ if( len + off <= length )
+ {
+ E* rsi = (E*)mappedValues.data();
+ rsi += off;
+ return CpuCompute::Tensor::fromData( rsi, dataType, len );
+ }
+ throw E_BOUNDS;
+ }
+ };
+
+ // Map both staging buffers, return RAII object which unmaps when destroyed,
+ // which can supply the data in the shape of CpuCompute::Tensor vector
+ decltype( auto ) map()
+ {
+ return ReadMap( *this );
+ }
+}; \ No newline at end of file
diff --git a/Whisper/Hybrid/Readme.txt b/Whisper/Hybrid/Readme.txt
new file mode 100644
index 0000000..702813d
--- /dev/null
+++ b/Whisper/Hybrid/Readme.txt
@@ -0,0 +1 @@
+The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change BUILD_HYBRID_VERSION macro in stdafx.h \ No newline at end of file
diff --git a/Whisper/MF/AudioBuffer.cpp b/Whisper/MF/AudioBuffer.cpp
new file mode 100644
index 0000000..ba4752d
--- /dev/null
+++ b/Whisper/MF/AudioBuffer.cpp
@@ -0,0 +1,93 @@
+#include "stdafx.h"
+#include "AudioBuffer.h"
+using namespace Whisper;
+
+void AudioBuffer::appendMono( const float* rsi, size_t countFloats )
+{
+ mono.insert( mono.end(), rsi, rsi + countFloats );
+}
+
+void AudioBuffer::appendStereo( const float* rsi, size_t countFloats )
+{
+ assert( 0 == ( countFloats % 2 ) );
+ const size_t countSamples = countFloats / 2;
+
+ const size_t oldLength = mono.size();
+ assert( oldLength * 2 == stereo.size() );
+ mono.resize( oldLength + countSamples );
+ stereo.resize( ( oldLength + countSamples ) * 2 );
+
+ const float* const rsiEnd = rsi + countSamples * 2;
+ const float* const rsiEndAligned = rsiEnd - ( countSamples * 2 ) % 8;
+
+ float* rdiStereo = &stereo[ oldLength * 2 ];
+ float* rdiMono = &mono[ oldLength ];
+
+ const __m128 half = _mm_set1_ps( 0.5f );
+ for( ; rsi < rsiEndAligned; rsi += 8, rdiStereo += 8, rdiMono += 4 )
+ {
+ // Load 4 samples = 8 floats
+ __m128 v0 = _mm_loadu_ps( rsi ); // L0, R0, L1, R1
+ __m128 v1 = _mm_loadu_ps( rsi + 4 );// L2, R2, L3, R3
+
+ // Store into the stereo PCM vector
+ _mm_storeu_ps( rdiStereo, v0 );
+ _mm_storeu_ps( rdiStereo + 4, v1 );
+
+ // Compute and store the average of these channels
+ __m128 left = _mm_shuffle_ps( v0, v1, _MM_SHUFFLE( 2, 0, 2, 0 ) );
+ __m128 right = _mm_shuffle_ps( v0, v1, _MM_SHUFFLE( 3, 1, 3, 1 ) );
+ __m128 sum = _mm_add_ps( left, right );
+ sum = _mm_mul_ps( sum, half );
+ _mm_storeu_ps( rdiMono, sum );
+ }
+
+#pragma loop (no_vector)
+ for( ; rsi < rsiEnd; rsi += 2, rdiStereo += 2, rdiMono++ )
+ {
+ __m128 vec = _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) );
+ _mm_store_sd( (double*)rdiStereo, _mm_castps_pd( vec ) );
+
+ vec = _mm_add_ss( vec, _mm_movehdup_ps( vec ) );
+ vec = _mm_mul_ss( vec, half );
+ _mm_store_ss( rdiMono, vec );
+ }
+}
+
+void AudioBuffer::appendDownmixedStereo( const float* rsi, size_t countFloats )
+{
+ assert( 0 == ( countFloats % 2 ) );
+ const size_t countSamples = countFloats / 2;
+
+ const size_t oldLength = mono.size();
+ mono.resize( oldLength + countSamples );
+
+ const float* const rsiEnd = rsi + countSamples * 2;
+ const float* const rsiEndAligned = rsiEnd - ( countSamples * 2 ) % 8;
+
+ float* rdiMono = &mono[ oldLength ];
+
+ const __m128 half = _mm_set1_ps( 0.5f );
+ for( ; rsi < rsiEndAligned; rsi += 8, rdiMono += 4 )
+ {
+ // Load 4 samples = 8 floats
+ __m128 v0 = _mm_loadu_ps( rsi ); // L0, R0, L1, R1
+ __m128 v1 = _mm_loadu_ps( rsi + 4 );// L2, R2, L3, R3
+
+ // Compute and store the average of these channels
+ __m128 left = _mm_shuffle_ps( v0, v1, _MM_SHUFFLE( 2, 0, 2, 0 ) );
+ __m128 right = _mm_shuffle_ps( v0, v1, _MM_SHUFFLE( 3, 1, 3, 1 ) );
+ __m128 sum = _mm_add_ps( left, right );
+ sum = _mm_mul_ps( sum, half );
+ _mm_storeu_ps( rdiMono, sum );
+ }
+
+#pragma loop (no_vector)
+ for( ; rsi < rsiEnd; rsi += 2, rdiMono++ )
+ {
+ __m128 vec = _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) );
+ vec = _mm_add_ss( vec, _mm_movehdup_ps( vec ) );
+ vec = _mm_mul_ss( vec, half );
+ _mm_store_ss( rdiMono, vec );
+ }
+} \ No newline at end of file
diff --git a/Whisper/MF/AudioBuffer.h b/Whisper/MF/AudioBuffer.h
new file mode 100644
index 0000000..87319dd
--- /dev/null
+++ b/Whisper/MF/AudioBuffer.h
@@ -0,0 +1,41 @@
+#pragma once
+#include <vector>
+
+namespace Whisper
+{
+ struct AudioBuffer
+ {
+ std::vector<float> mono;
+ std::vector<float> stereo;
+
+ void appendMono( const float* rsi, size_t countFloats );
+ void appendDownmixedStereo( const float* rsi, size_t countFloats );
+ void appendStereo( const float* rsi, size_t countFloats );
+
+ using pfnAppendSamples = void( AudioBuffer::* )( const float* rsi, size_t countFloats );
+
+ inline static pfnAppendSamples appendSamplesFunc( bool sourceMono, bool wantStereo )
+ {
+ if( sourceMono )
+ return &AudioBuffer::appendMono;
+ else if( !wantStereo )
+ return &AudioBuffer::appendDownmixedStereo;
+ else
+ return &AudioBuffer::appendStereo;
+ }
+
+ void clear()
+ {
+ mono.clear();
+ stereo.clear();
+ }
+
+ void resize( size_t len )
+ {
+ assert( len <= mono.size() );
+ mono.resize( len );
+ if( !stereo.empty() )
+ stereo.resize( len * 2 );
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/MF/AudioCapture.cpp b/Whisper/MF/AudioCapture.cpp
new file mode 100644
index 0000000..17f34dc
--- /dev/null
+++ b/Whisper/MF/AudioCapture.cpp
@@ -0,0 +1,167 @@
+#include "stdafx.h"
+#include <atlstr.h>
+#include <mfapi.h>
+#include <mfidl.h>
+#include <mfreadwrite.h>
+#include "AudioCapture.h"
+#include "../API/iMediaFoundation.cl.h"
+#include "../ComLightLib/comLightServer.h"
+#pragma comment(lib, "Mf.lib")
+
+namespace
+{
+ struct Strings
+ {
+ CString displayName, endpoint;
+ };
+
+ HRESULT getAllocString( IMFActivate* activate, const GUID& id, CString& rdi )
+ {
+ wchar_t* pointer = nullptr;
+ UINT32 cchName;
+ HRESULT hr = activate->GetAllocatedString( id, &pointer, &cchName );
+ if( SUCCEEDED( hr ) )
+ rdi.SetString( pointer, cchName );
+ CoTaskMemFree( pointer );
+ return hr;
+ }
+
+ HRESULT getInfo( IMFActivate* activate, Strings& rdi )
+ {
+ CHECK( getAllocString( activate, MF_DEVSOURCE_ATTRIBUTE_FRIENDLY_NAME, rdi.displayName ) );
+ CHECK( getAllocString( activate, MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE_AUDCAP_ENDPOINT_ID, rdi.endpoint ) );
+ return S_OK;
+ }
+
+ HRESULT __stdcall supplyDevices( Whisper::pfnFoundCaptureDevices pfn, void* pv, IMFActivate** ppDevices, UINT32 count )
+ {
+ if( ppDevices == nullptr || count == 0 )
+ return pfn( 0, nullptr, pv );
+
+ std::vector<Strings> strings;
+ strings.reserve( count );
+
+ for( UINT i = 0; i < count; i++ )
+ {
+ IMFActivate* const activate = ppDevices[ i ];
+ if( nullptr == activate )
+ continue;
+ Strings info;
+ HRESULT hr = getInfo( activate, info );
+ if( FAILED( hr ) )
+ continue;
+
+ strings.emplace_back( std::move( info ) );
+ }
+
+ const size_t len = strings.size();
+ if( 0 == len )
+ return pfn( 0, nullptr, pv );
+
+ std::vector<Whisper::sCaptureDevice> pointers;
+ pointers.resize( len );
+ for( size_t i = 0; i < len; i++ )
+ {
+ const auto& src = strings[ i ];
+ auto& dest = pointers[ i ];
+ dest.displayName = src.displayName;
+ dest.endpoint = src.endpoint;
+ }
+ return pfn( (int)len, pointers.data(), pv );
+ }
+}
+
+HRESULT __stdcall Whisper::captureDeviceList( pfnFoundCaptureDevices pfn, void* pv )
+{
+ // Create an attribute store to hold the search criteria.
+ CComPtr<IMFAttributes> attrs;
+ CHECK( MFCreateAttributes( &attrs, 1 ) );
+ // Request audio capture devices
+ CHECK( attrs->SetGUID( MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE, MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE_AUDCAP_GUID ) );
+
+ // Enumerate the devices
+ IMFActivate** ppDevices = nullptr;
+ UINT32 count = 0;
+ CHECK( MFEnumDeviceSources( attrs, &ppDevices, &count ) );
+
+ // Feed the data to the caller
+ HRESULT hr = supplyDevices( pfn, pv, ppDevices, count );
+
+ // Free the memory
+ for( DWORD i = 0; i < count; i++ )
+ ppDevices[ i ]->Release();
+ CoTaskMemFree( ppDevices );
+
+ return hr;
+}
+
+namespace
+{
+ using namespace Whisper;
+
+ class Capture : public ComLight::ObjectRoot<iAudioCapture>
+ {
+ CComPtr<IMFSourceReader> reader;
+ CComPtr<iMediaFoundation> mediaFoundation;
+ sCaptureParams captureParams;
+
+ HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const noexcept override final
+ {
+ if( pp == nullptr )
+ return E_POINTER;
+ CComPtr<IMFSourceReader> res = reader;
+ *pp = res.Detach();;
+ return S_OK;
+ }
+ const sCaptureParams& COMLIGHTCALL getParams() const noexcept override final
+ {
+ return captureParams;
+ }
+ public:
+ HRESULT open( iMediaFoundation* owner, const wchar_t* endpoint, const sCaptureParams& cp );
+ };
+
+ HRESULT Capture::open( iMediaFoundation* owner, const wchar_t* endpoint, const sCaptureParams& cp )
+ {
+ // Create an attribute store to hold the search criteria.
+ CComPtr<IMFAttributes> attrs;
+ CHECK( MFCreateAttributes( &attrs, 2 ) );
+ // Request audio capture devices
+ CHECK( attrs->SetGUID( MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE, MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE_AUDCAP_GUID ) );
+ CHECK( attrs->SetString( MF_DEVSOURCE_ATTRIBUTE_SOURCE_TYPE_AUDCAP_ENDPOINT_ID, endpoint ) );
+
+ CComPtr<IMFMediaSource> source;
+ HRESULT hr = MFCreateDeviceSource( attrs, &source );
+ if( FAILED( hr ) )
+ {
+ logErrorHr( hr, u8"MFCreateDeviceSource" );
+ return hr;
+ }
+
+ // TODO: implement IMFSourceReaderCallback, pass into MF_SOURCE_READER_ASYNC_CALLBACK attribute
+ // This is to support cancellation
+ hr = MFCreateSourceReaderFromMediaSource( source, nullptr, &reader );
+ if( FAILED( hr ) )
+ {
+ logErrorHr( hr, u8"MFCreateSourceReaderFromMediaSource" );
+ return hr;
+ }
+
+ captureParams = cp;
+ mediaFoundation = owner;
+ return S_OK;
+ }
+}
+
+HRESULT __stdcall Whisper::captureOpen( iMediaFoundation* owner, const wchar_t* endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) noexcept
+{
+ if( nullptr == endpoint || nullptr == pp )
+ return E_POINTER;
+
+ ComLight::CComPtr<ComLight::Object<Capture>> res;
+ CHECK( ComLight::Object<Capture>::create( res ) );
+ CHECK( res->open( owner, endpoint, captureParams ) );
+
+ res.detach( pp );
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/MF/AudioCapture.h b/Whisper/MF/AudioCapture.h
new file mode 100644
index 0000000..276ee4b
--- /dev/null
+++ b/Whisper/MF/AudioCapture.h
@@ -0,0 +1,12 @@
+#pragma once
+#include "../API/MfStructs.h"
+
+namespace Whisper
+{
+ struct iAudioCapture;
+ struct iMediaFoundation;
+
+ HRESULT __stdcall captureDeviceList( pfnFoundCaptureDevices pfn, void* pv );
+
+ HRESULT __stdcall captureOpen( iMediaFoundation* owner, const wchar_t* endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) noexcept;
+} \ No newline at end of file
diff --git a/Whisper/MF/MediaFoundation.cpp b/Whisper/MF/MediaFoundation.cpp
new file mode 100644
index 0000000..4a4f6a2
--- /dev/null
+++ b/Whisper/MF/MediaFoundation.cpp
@@ -0,0 +1,109 @@
+#include "stdafx.h"
+#include "../API/iMediaFoundation.cl.h"
+#include "mfStartup.h"
+#include "../ComLightLib/comLightServer.h"
+#include "loadAudioFile.h"
+#include <mfidl.h>
+#include <mfreadwrite.h>
+#include "mfUtils.h"
+#include "AudioCapture.h"
+
+namespace Whisper
+{
+ class AudioReader : public ComLight::ObjectRoot<iAudioReader>
+ {
+ CComPtr<IMFSourceReader> reader;
+ bool wantStereo;
+ CComPtr<iMediaFoundation> mediaFoundation;
+
+ HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const noexcept override final
+ {
+ if( pp == nullptr )
+ return E_POINTER;
+ CComPtr<IMFSourceReader> res = reader;
+ *pp = res.Detach();;
+ return S_OK;
+ }
+ HRESULT COMLIGHTCALL requestedStereo() const noexcept override final
+ {
+ return wantStereo ? S_OK : S_FALSE;
+ }
+ HRESULT COMLIGHTCALL getDuration( int64_t& rdi ) const noexcept override final
+ {
+ if( reader )
+ return getStreamDuration( reader, rdi );
+ return OLE_E_BLANK;
+ }
+ public:
+ HRESULT open( iMediaFoundation* owner, LPCTSTR path, bool stereo )
+ {
+ HRESULT hr = MFCreateSourceReaderFromURL( path, nullptr, &reader );
+ if( FAILED( hr ) )
+ {
+ logErrorHr( hr, u8"MFCreateSourceReaderFromURL failed" );
+ return hr;
+ }
+ wantStereo = stereo;
+ mediaFoundation = owner;
+ logDebug16( L"Created source reader from the file \"%s\"", path );
+ return S_OK;
+ }
+ };
+
+ class MediaFoundation : public ComLight::ObjectRoot<iMediaFoundation>
+ {
+ MfStartupRaii raii;
+ DWORD tid = ~(DWORD)0;
+
+ virtual HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ) const noexcept override final
+ {
+ return Whisper::loadAudioFile( path, stereo, pp );
+ }
+ virtual HRESULT COMLIGHTCALL openAudioFile( LPCTSTR path, bool stereo, iAudioReader** pp ) noexcept override final
+ {
+ if( nullptr == path || nullptr == pp )
+ return E_POINTER;
+
+ ComLight::CComPtr<ComLight::Object<AudioReader>> res;
+ CHECK( ComLight::Object<AudioReader>::create( res ) );
+ CHECK( res->open( this, path, stereo ) );
+
+ res.detach( pp );
+ return S_OK;
+ }
+ HRESULT COMLIGHTCALL listCaptureDevices( pfnFoundCaptureDevices pfn, void* pv ) noexcept override final
+ {
+ return captureDeviceList( pfn, pv );
+ }
+ HRESULT COMLIGHTCALL openCaptureDevice( LPCTSTR endpoint, const sCaptureParams& captureParams, iAudioCapture** pp ) noexcept override final
+ {
+ return captureOpen( this, endpoint, captureParams, pp );
+ }
+ protected:
+
+ HRESULT FinalConstruct()
+ {
+ CHECK( raii.startup() );
+ tid = GetCurrentThreadId();
+ return S_OK;
+ }
+
+ public:
+
+ ~MediaFoundation() override
+ {
+ assert( tid == GetCurrentThreadId() );
+ }
+ };
+}
+
+HRESULT COMLIGHTCALL Whisper::initMediaFoundation( iMediaFoundation** pp )
+{
+ if( nullptr == pp )
+ return E_POINTER;
+
+ ComLight::CComPtr<ComLight::Object<MediaFoundation>> obj;
+ CHECK( ComLight::Object<MediaFoundation>::create( obj ) );
+ obj.detach( pp );
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/MF/PcmReader.cpp b/Whisper/MF/PcmReader.cpp
new file mode 100644
index 0000000..ab92fc3
--- /dev/null
+++ b/Whisper/MF/PcmReader.cpp
@@ -0,0 +1,274 @@
+#include "stdafx.h"
+#include "PcmReader.h"
+#include <mfapi.h>
+#include "mfUtils.h"
+
+namespace Whisper
+{
+ __interface iSampleHandler
+ {
+ void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, PcmStereoChunk* pStereo ) const;
+ void moveBufferData( AudioBuffer& rdi, size_t amount ) const;
+ void appendPcm( AudioBuffer& rdi, const float* rsi, size_t countFloats ) const;
+ void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, size_t samples, PcmStereoChunk* pStereo ) const;
+ uint32_t readerChannelsCount() const;
+ };
+}
+
+namespace
+{
+ using namespace Whisper;
+
+ __forceinline void copyMono( PcmMonoChunk* rdi, const AudioBuffer& rsi, size_t sourceOffset, size_t samples )
+ {
+ assert( sourceOffset + samples <= rsi.mono.size() );
+ memcpy( rdi->mono.data(), &rsi.mono[ sourceOffset ], samples * 4 );
+ if( samples < FFT_STEP )
+ memset( rdi->mono.data() + samples, 0, ( FFT_STEP - samples ) * 4 );
+ }
+
+ __forceinline void copyStereo( PcmStereoChunk* rdi, const AudioBuffer& rsi, size_t sourceOffset, size_t samples )
+ {
+ memcpy( rdi->stereo.data(), &rsi.stereo[ sourceOffset * 2 ], samples * 8 );
+ if( samples < FFT_STEP )
+ memset( rdi->stereo.data() + samples * 2, 0, ( FFT_STEP - samples ) * 8 );
+ }
+
+ struct HandlerMono : iSampleHandler
+ {
+ void appendPcm( AudioBuffer& rdi, const float* rsi, size_t countFloats ) const override
+ {
+ rdi.appendMono( rsi, countFloats );
+ }
+ void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, PcmStereoChunk* pStereo ) const override final
+ {
+ copyMono( pMono, rsi, sourceOffset, FFT_STEP );
+ }
+ void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, size_t samples, PcmStereoChunk* pStereo ) const override final
+ {
+ copyMono( pMono, rsi, sourceOffset, samples );
+ }
+ void moveBufferData( AudioBuffer& rdi, size_t amount ) const override final
+ {
+ const size_t len = rdi.mono.size();
+ assert( amount <= len );
+ if( amount < len )
+ {
+ const size_t block = len - amount;
+ memmove( rdi.mono.data(), rdi.mono.data() + amount, block * 4 );
+ rdi.mono.resize( block );
+ }
+ else
+ rdi.mono.clear();
+ }
+ uint32_t readerChannelsCount() const override { return 1; }
+ };
+ struct HandlerDownmixedStereo : HandlerMono
+ {
+ void appendPcm( AudioBuffer& rdi, const float* rsi, size_t countFloats ) const override final
+ {
+ rdi.appendDownmixedStereo( rsi, countFloats );
+ }
+ uint32_t readerChannelsCount() const override final { return 2; }
+ };
+ struct HandlerStereo : iSampleHandler
+ {
+ void appendPcm( AudioBuffer& rdi, const float* rsi, size_t countFloats ) const override final
+ {
+ rdi.appendStereo( rsi, countFloats );
+ }
+ void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, PcmStereoChunk* pStereo ) const override final
+ {
+ copyMono( pMono, rsi, sourceOffset, FFT_STEP );
+ copyStereo( pStereo, rsi, sourceOffset, FFT_STEP );
+ }
+ void copyChunk( PcmMonoChunk* pMono, const AudioBuffer& rsi, size_t sourceOffset, size_t samples, PcmStereoChunk* pStereo ) const override final
+ {
+ copyMono( pMono, rsi, sourceOffset, samples );
+ copyStereo( pStereo, rsi, sourceOffset, samples );
+ }
+ void moveBufferData( AudioBuffer& rdi, size_t amount ) const override final
+ {
+ const size_t len = rdi.mono.size();
+ assert( amount <= len );
+ if( amount < len )
+ {
+ const size_t block = len - amount;
+ memmove( rdi.mono.data(), rdi.mono.data() + amount, block * 4 );
+ rdi.mono.resize( block );
+ memmove( rdi.stereo.data(), rdi.stereo.data() + amount * 2, block * 8 );
+ rdi.mono.resize( block * 2 );
+ }
+ else
+ {
+ rdi.mono.clear();
+ rdi.stereo.clear();
+ }
+ }
+ uint32_t readerChannelsCount() const override final { return 2; }
+ };
+ static const HandlerMono s_mono;
+ static const HandlerDownmixedStereo s_downmix;
+ static const HandlerStereo s_stereo;
+}
+
+PcmReader::PcmReader( IMFSourceReader* reader, bool stereo )
+{
+ if( nullptr == reader )
+ throw E_POINTER;
+ this->reader = reader;
+
+ // Set up media type, and figure out sample handler
+ check( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) );
+ check( reader->SetStreamSelection( MF_SOURCE_READER_FIRST_AUDIO_STREAM, TRUE ) );
+
+ CComPtr<IMFMediaType> mtNative;
+ check( reader->GetNativeMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, MF_SOURCE_READER_CURRENT_TYPE_INDEX, &mtNative ) );
+ UINT32 numChannels;
+ check( mtNative->GetUINT32( MF_MT_AUDIO_NUM_CHANNELS, &numChannels ) );
+
+ const bool sourceMono = numChannels < 2;
+ if( sourceMono )
+ sampleHandler = &s_mono;
+ else if( !stereo )
+ sampleHandler = &s_downmix;
+ else
+ {
+ sampleHandler = &s_stereo;
+ m_stereoOutput = true;
+ }
+
+ CComPtr<IMFMediaType> mt;
+ check( createMediaType( !sourceMono, &mt ) );
+ check( reader->SetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, nullptr, mt ) );
+
+ // Find out the length
+ int64_t durationTicks;
+ check( getStreamDuration( reader, durationTicks ) );
+
+ // Convert length to chunks
+ // Seconds = Ticks / 10^7
+ // Samples = Seconds * SAMPLE_RATE = Ticks * SAMPLE_RATE / 10^7
+ // Chunks = Samples / FFT_STEP = Ticks * SAMPLE_RATE / ( FFT_STEP * 10^7 ), and we want that integer rounded down
+ constexpr __int64 mul = SAMPLE_RATE;
+ constexpr __int64 div = (__int64)FFT_STEP * 10'000'000;
+ m_length = (size_t)MFllMulDiv( durationTicks, mul, div, 0 );
+}
+
+HRESULT PcmReader::readNextSample()
+{
+ const size_t off = bufferReadOffset;
+ const size_t availableSamples = pcm.mono.size() - off;
+
+ // If needed, move the remaining PCM data to the start of these vectors
+ if( availableSamples > 0 )
+ {
+ if( 0 != off )
+ sampleHandler->moveBufferData( pcm, off );
+ }
+ else
+ pcm.clear();
+ bufferReadOffset = 0;
+
+ while( true )
+ {
+ DWORD dwFlags = 0;
+ CComPtr<IMFSample> sample;
+
+ // Read the next sample
+ HRESULT hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample );
+ if( FAILED( hr ) )
+ {
+ logErrorHr( hr, u8"IMFSourceReader.ReadSample" );
+ return hr;
+ }
+
+ if( dwFlags & MF_SOURCE_READERF_CURRENTMEDIATYPECHANGED )
+ {
+ // logError( u8"Media type changes ain’t supported by the library." );
+ // return E_UNEXPECTED;
+
+ // This happens for some video files at the very start of the reading, with Dolby AC3 audio track.
+ // Instead of failing the transcribe process, verify the important attributes (FP32 samples, sample rate, count of channels) haven’t changed.
+ CHECK( validateCurrentMediaType( reader, sampleHandler->readerChannelsCount() ) );
+ }
+
+ if( dwFlags & MF_SOURCE_READERF_ENDOFSTREAM )
+ return E_EOF;
+
+ if( !sample )
+ {
+ // printf( "No sample\n" );
+ continue;
+ }
+
+ // Get a pointer to the audio data in the sample.
+ CComPtr<IMFMediaBuffer> buffer;
+ hr = sample->ConvertToContiguousBuffer( &buffer );
+ if( FAILED( hr ) )
+ return hr;
+
+ const float* pAudioData = nullptr;
+ DWORD cbBuffer;
+ hr = buffer->Lock( (BYTE**)&pAudioData, nullptr, &cbBuffer );
+ if( FAILED( hr ) )
+ return hr;
+
+ try
+ {
+ assert( 0 == ( cbBuffer % sizeof( float ) ) );
+ const size_t countFloats = cbBuffer / sizeof( float );
+ sampleHandler->appendPcm( pcm, pAudioData, countFloats );
+ }
+ catch( const std::bad_alloc& )
+ {
+ buffer->Unlock();
+ return E_OUTOFMEMORY;
+ }
+
+ // Unlock the buffer
+ hr = buffer->Unlock();
+ if( FAILED( hr ) )
+ return hr;
+
+ return S_OK;
+ }
+}
+
+HRESULT PcmReader::readChunk( PcmMonoChunk& mono, PcmStereoChunk* stereo )
+{
+ while( true )
+ {
+ const size_t off = bufferReadOffset;
+ const size_t availableSamples = pcm.mono.size() - off;
+ if( availableSamples >= FFT_STEP )
+ {
+ // We have enough data in the buffer
+ sampleHandler->copyChunk( &mono, pcm, off, stereo );
+ bufferReadOffset = off + FFT_STEP;
+ return S_OK;
+ }
+
+ if( !m_readerEndOfFile )
+ {
+ // We don't have enough data, but the stream has not ended yet, can load moar samples from the reader
+ HRESULT hr = readNextSample();
+ if( SUCCEEDED( hr ) )
+ continue;
+ if( hr != E_EOF )
+ return hr;
+ m_readerEndOfFile = true;
+ }
+
+ if( availableSamples > 0 )
+ {
+ // We have reached the end of stream of the reader, but the buffer still has a few samples.
+ // Return the final incomplete chunk padded with zeros
+ sampleHandler->copyChunk( &mono, pcm, off, availableSamples, stereo );
+ bufferReadOffset = off + availableSamples;
+ return S_OK;
+ }
+
+ return E_EOF;
+ }
+} \ No newline at end of file
diff --git a/Whisper/MF/PcmReader.h b/Whisper/MF/PcmReader.h
new file mode 100644
index 0000000..9e3757e
--- /dev/null
+++ b/Whisper/MF/PcmReader.h
@@ -0,0 +1,63 @@
+#pragma once
+#include "../Whisper/audioConstants.h"
+#include <mfidl.h>
+#include <mfreadwrite.h>
+#include "AudioBuffer.h"
+
+namespace Whisper
+{
+ // PCM buffer with 10 milliseconds of single-channel audio
+ struct PcmMonoChunk
+ {
+ std::array<float, FFT_STEP> mono;
+ };
+ // PCM buffer with 10 milliseconds of interleaved stereo
+ struct PcmStereoChunk
+ {
+ std::array<float, FFT_STEP * 2> stereo;
+ };
+
+ __interface iSampleHandler;
+
+ constexpr HRESULT E_EOF = HRESULT_FROM_WIN32( ERROR_HANDLE_EOF );
+
+ // Utility class which reads chunks of FFT_STEP FP32 PCM samples from the MF source reader
+ // The class always delivers mono chunks, and can optionally deliver stereo in a separate buffer.
+ class PcmReader
+ {
+ // A small intermediate buffer with PCM data for complete media foundation samples
+ AudioBuffer pcm;
+ // Index of the first unconsumed sample in the pcm buffer
+ size_t bufferReadOffset = 0;
+ // Utility object to abstract away mono versus stereo shenanigans
+ const iSampleHandler* sampleHandler;
+ // The underlying MF source reader which delivers audio data
+ CComPtr<IMFSourceReader> reader;
+ // True after we consumed all available media samples from the reader
+ bool m_readerEndOfFile = false;
+ // True if this object delivers stereo samples
+ bool m_stereoOutput = false;
+ // The count of chunks we expect to get from the reader
+ size_t m_length = 0;
+ // Read next sample from the reader, store in the PCM buffer in this class
+ HRESULT readNextSample();
+
+ public:
+
+ PcmReader( IMFSourceReader* source, bool stereo );
+
+ // Count of chunks in the MEL spectrogram.
+ // The PCM audio is generally slightly longer than that, due to the incomplete last chunk.
+ size_t getLength() const noexcept
+ {
+ return m_length;
+ }
+
+ // True when the stereo flag passed to constructor, and the audio stream actually has 2 or more audio channels
+ bool outputsStereo() const { return m_stereoOutput; }
+
+ // Load another 10ms chunk from the stream
+ // For the last chunk in the stream, the output buffers are padded with zeros
+ HRESULT readChunk( PcmMonoChunk& mono, PcmStereoChunk* stereo );
+ };
+} \ No newline at end of file
diff --git a/Whisper/MF/loadAudioFile.cpp b/Whisper/MF/loadAudioFile.cpp
new file mode 100644
index 0000000..d1a439a
--- /dev/null
+++ b/Whisper/MF/loadAudioFile.cpp
@@ -0,0 +1,151 @@
+#include "stdafx.h"
+#include "../ComLightLib/comLightServer.h"
+#include "loadAudioFile.h"
+#include "mfUtils.h"
+#include "AudioBuffer.h"
+#include <mfidl.h>
+#include <mfreadwrite.h>
+#include <mfapi.h>
+#pragma comment(lib, "Mfreadwrite.lib")
+#pragma comment(lib, "mfuuid.lib")
+
+namespace Whisper
+{
+ class MediaFileBuffer : public ComLight::ObjectRoot<iAudioBuffer>
+ {
+ AudioBuffer pcm;
+ uint32_t channels = 0;
+
+ uint32_t COMLIGHTCALL countSamples() const noexcept override final
+ {
+ return (uint32_t)( pcm.mono.size() );
+ }
+ const float* COMLIGHTCALL getPcmMono() const noexcept override final
+ {
+ if( !pcm.mono.empty() )
+ return pcm.mono.data();
+ return nullptr;
+ }
+ const float* COMLIGHTCALL getPcmStereo() const noexcept override final
+ {
+ if( !pcm.stereo.empty() )
+ return pcm.stereo.data();
+ return nullptr;
+ }
+ HRESULT COMLIGHTCALL getTime( int64_t& rdi ) const noexcept override final
+ {
+ rdi = 0;
+ return S_OK;
+ }
+ public:
+ HRESULT load( LPCTSTR path, bool stereo );
+ };
+
+ HRESULT MediaFileBuffer::load( LPCTSTR path, bool stereo )
+ {
+ CComPtr<IMFSourceReader> reader;
+ HRESULT hr = MFCreateSourceReaderFromURL( path, nullptr, &reader );
+ if( FAILED( hr ) )
+ {
+ logErrorHr( hr, u8"MFCreateSourceReaderFromURL failed" );
+ return hr;
+ }
+
+ CHECK( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) );
+ CHECK( reader->SetStreamSelection( MF_SOURCE_READER_FIRST_AUDIO_STREAM, TRUE ) );
+
+ CComPtr<IMFMediaType> mtNative;
+ CHECK( reader->GetNativeMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, MF_SOURCE_READER_CURRENT_TYPE_INDEX, &mtNative ) );
+ UINT32 numChannels;
+ CHECK( mtNative->GetUINT32( MF_MT_AUDIO_NUM_CHANNELS, &numChannels ) );
+ const bool sourceMono = numChannels == 1;
+ const AudioBuffer::pfnAppendSamples pfn = AudioBuffer::appendSamplesFunc( sourceMono, stereo );
+ channels = ( stereo && !sourceMono ) ? 2 : 1;
+
+ CComPtr<IMFMediaType> mt;
+ CHECK( createMediaType( !sourceMono, &mt ) );
+
+ CHECK( reader->SetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, nullptr, mt ) );
+
+ while( true )
+ {
+ DWORD dwFlags = 0;
+ CComPtr<IMFSample> sample;
+
+ // Read the next sample.
+ hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample );
+ if( FAILED( hr ) )
+ {
+ logErrorHr( hr, u8"IMFSourceReader.ReadSample" );
+ return hr;
+ }
+
+ if( dwFlags & MF_SOURCE_READERF_CURRENTMEDIATYPECHANGED )
+ {
+ logError( u8"Media type changes ain’t supported by the library." );
+ return E_UNEXPECTED;
+ }
+
+ if( dwFlags & MF_SOURCE_READERF_ENDOFSTREAM )
+ break;
+
+ if( !sample )
+ {
+ // printf( "No sample\n" );
+ continue;
+ }
+
+ // Get a pointer to the audio data in the sample.
+ CComPtr<IMFMediaBuffer> buffer;
+ hr = sample->ConvertToContiguousBuffer( &buffer );
+ if( FAILED( hr ) )
+ return hr;
+
+ const float* pAudioData = nullptr;
+ DWORD cbBuffer;
+ hr = buffer->Lock( (BYTE**)&pAudioData, nullptr, &cbBuffer );
+ if( FAILED( hr ) )
+ return hr;
+
+ try
+ {
+ const size_t countFloats = cbBuffer / sizeof( float );
+ ( pcm.*pfn )( pAudioData, countFloats );
+ }
+ catch( const std::bad_alloc& )
+ {
+ return E_OUTOFMEMORY;
+ }
+
+ // Unlock the buffer
+ hr = buffer->Unlock();
+ if( FAILED( hr ) )
+ return hr;
+ }
+
+ const size_t len = pcm.mono.size();
+ if( len == 0 )
+ {
+ logError16( L"The audio file \"%s\" has no samples", path );
+ return E_INVALIDARG;
+ }
+ if( len < SAMPLE_RATE / 2 )
+ logError16( L"The file \"%s\" only has %zu samples, less than 0.5 seconds of audio", path, len );
+ else
+ logDebug16( L"Loaded audio file from \"%s\": %zu samples, %g seconds", path, len, (int)len * ( 1.0 / SAMPLE_RATE ) );
+ return S_OK;
+
+ }
+}
+
+HRESULT COMLIGHTCALL Whisper::loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp )
+{
+ if( nullptr == path || nullptr == pp )
+ return E_POINTER;
+
+ ComLight::CComPtr<ComLight::Object<MediaFileBuffer>> obj;
+ CHECK( ComLight::Object<MediaFileBuffer>::create( obj ) );
+ CHECK( obj->load( path, stereo ) );
+ obj.detach( pp );
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/MF/loadAudioFile.h b/Whisper/MF/loadAudioFile.h
new file mode 100644
index 0000000..9736ccd
--- /dev/null
+++ b/Whisper/MF/loadAudioFile.h
@@ -0,0 +1,7 @@
+#pragma once
+#include "../API/iMediaFoundation.cl.h"
+
+namespace Whisper
+{
+ HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp );
+} \ No newline at end of file
diff --git a/Whisper/MF/mfStartup.cpp b/Whisper/MF/mfStartup.cpp
new file mode 100644
index 0000000..b7ab829
--- /dev/null
+++ b/Whisper/MF/mfStartup.cpp
@@ -0,0 +1,128 @@
+#include "stdafx.h"
+#include "mfStartup.h"
+#include <atlbase.h>
+#include <mfapi.h>
+#pragma comment(lib, "Mfplat.lib")
+
+namespace
+{
+ struct sCoInitStatus
+ {
+ // Possible state:
+ // -1 is the initial state, coInitialize never called
+ // S_OK - CoInitializeEx succeeded, in this state the counter tracks the count of coInitialize() for the current thread
+ // S_FALSE - CoInitializeEx failed with RPC_E_CHANGED_MODE, or did nothing because already initialized for the current thread
+ // Error status - CoInitializeEx failed for some other reason
+ HRESULT code = -1;
+ uint32_t counter = 0;
+ };
+ thread_local sCoInitStatus coInitStatus;
+
+ static HRESULT coInitialize()
+ {
+ sCoInitStatus& cis = coInitStatus;
+ HRESULT hr = cis.code;
+ if( SUCCEEDED( hr ) )
+ {
+ if( S_OK == hr )
+ cis.counter++;
+ return S_FALSE;
+ }
+
+ if( hr == HRESULT( -1 ) )
+ {
+ hr = CoInitializeEx( nullptr, COINIT_MULTITHREADED );
+ if( S_OK == hr )
+ {
+ cis.counter = 1;
+ return cis.code = S_OK;
+ }
+ if( S_FALSE == hr || RPC_E_CHANGED_MODE == hr )
+ {
+ return cis.code = S_FALSE;
+ }
+ cis.code = hr;
+ return hr;
+ }
+
+ return hr;
+ }
+
+ static void coUninitialize()
+ {
+ sCoInitStatus& cis = coInitStatus;
+ if( cis.code == S_OK )
+ {
+ assert( cis.counter > 0 );
+ cis.counter--;
+ if( 0 == cis.counter )
+ CoUninitialize();
+ }
+ }
+
+ static CComAutoCriticalSection s_lock;
+#define LOCK() CComCritSecLock<CComAutoCriticalSection> lock{ s_lock }
+ static uint32_t mfStartupCounter = 0;
+
+ constexpr uint8_t FlagCOM = 1;
+ constexpr uint8_t FlagMF = 0x10;
+}
+
+using namespace Whisper;
+
+MfStartupRaii::~MfStartupRaii()
+{
+ if( 0 != ( successFlags & FlagMF ) )
+ {
+ LOCK();
+ assert( mfStartupCounter > 0 );
+ mfStartupCounter--;
+ if( mfStartupCounter > 0 )
+ return;
+ MFShutdown();
+ successFlags &= ~FlagMF;
+ }
+
+ if( 0 != ( successFlags & FlagCOM ) )
+ {
+ coUninitialize();
+ successFlags &= ~FlagCOM;
+ }
+}
+
+HRESULT MfStartupRaii::startup()
+{
+ if( 0 != ( successFlags & FlagMF ) )
+ return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
+
+ HRESULT hr = coInitialize();
+ CHECK( hr );
+ if( hr == S_OK )
+ successFlags |= FlagCOM;
+
+ LOCK();
+
+ if( 0 == mfStartupCounter )
+ {
+ HRESULT hr = MFStartup( MF_VERSION, MFSTARTUP_LITE );
+ if( SUCCEEDED( hr ) )
+ {
+ mfStartupCounter = 1;
+ successFlags |= FlagMF;
+ return S_OK;
+ }
+
+ if( 0 != ( successFlags & FlagCOM ) )
+ {
+ coUninitialize();
+ successFlags &= ~FlagCOM;
+ }
+ return hr;
+ }
+ else
+ {
+ mfStartupCounter++;
+ successFlags |= FlagMF;
+ return S_FALSE;
+ }
+} \ No newline at end of file
diff --git a/Whisper/MF/mfStartup.h b/Whisper/MF/mfStartup.h
new file mode 100644
index 0000000..1434ffc
--- /dev/null
+++ b/Whisper/MF/mfStartup.h
@@ -0,0 +1,15 @@
+#pragma once
+
+namespace Whisper
+{
+ class MfStartupRaii
+ {
+ uint8_t successFlags = 0;
+ public:
+ MfStartupRaii() = default;
+ ~MfStartupRaii();
+ MfStartupRaii( const MfStartupRaii& ) = delete;
+
+ HRESULT startup();
+ };
+} \ No newline at end of file
diff --git a/Whisper/MF/mfUtils.cpp b/Whisper/MF/mfUtils.cpp
new file mode 100644
index 0000000..e739079
--- /dev/null
+++ b/Whisper/MF/mfUtils.cpp
@@ -0,0 +1,69 @@
+#include "stdafx.h"
+#include "mfUtils.h"
+#include <mfapi.h>
+
+HRESULT Whisper::createMediaType( bool stereo, IMFMediaType** pp )
+{
+ if( nullptr == pp )
+ return E_POINTER;
+
+ CComPtr<IMFMediaType> mt;
+ CHECK( MFCreateMediaType( &mt ) );
+ CHECK( mt->SetGUID( MF_MT_MAJOR_TYPE, MFMediaType_Audio ) );
+ CHECK( mt->SetGUID( MF_MT_SUBTYPE, MFAudioFormat_Float ) );
+ CHECK( mt->SetUINT32( MF_MT_AUDIO_SAMPLES_PER_SECOND, SAMPLE_RATE ) );
+
+ const uint32_t channels = stereo ? 2 : 1;
+ CHECK( mt->SetUINT32( MF_MT_AUDIO_NUM_CHANNELS, channels ) );
+ CHECK( mt->SetUINT32( MF_MT_AUDIO_BLOCK_ALIGNMENT, channels * 4 ) );
+ CHECK( mt->SetUINT32( MF_MT_AUDIO_AVG_BYTES_PER_SECOND, channels * 4 * SAMPLE_RATE ) );
+ CHECK( mt->SetUINT32( MF_MT_AUDIO_BITS_PER_SAMPLE, 32 ) );
+ CHECK( mt->SetUINT32( MF_MT_ALL_SAMPLES_INDEPENDENT, TRUE ) );
+
+ *pp = mt.Detach();
+
+ return S_OK;
+}
+
+HRESULT Whisper::getStreamDuration( IMFSourceReader* reader, int64_t& duration )
+{
+ PROPVARIANT var;
+ PropVariantInit( &var );
+ CHECK( reader->GetPresentationAttribute( MF_SOURCE_READER_MEDIASOURCE, MF_PD_DURATION, &var ) );
+
+ if( var.vt == VT_UI8 )
+ {
+ // The documentation says the type of that attribute is UINT64
+ // https://learn.microsoft.com/en-us/windows/win32/medfound/mf-pd-duration-attribute
+ duration = var.uhVal.QuadPart;
+ return S_OK;
+ }
+ logError( u8"Unexpected type of MF_PD_DURATION attribute" );
+ return E_INVALIDARG;
+}
+
+HRESULT Whisper::validateCurrentMediaType( IMFSourceReader* reader, uint32_t expectedChannels )
+{
+ CComPtr<IMFMediaType> mt;
+ CHECK( reader->GetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, &mt ) );
+
+ GUID guid;
+ CHECK( mt->GetGUID( MF_MT_MAJOR_TYPE, &guid ) );
+ if( guid != MFMediaType_Audio )
+ return E_FAIL;
+
+ CHECK( mt->GetGUID( MF_MT_SUBTYPE, &guid ) );
+ if( guid != MFAudioFormat_Float )
+ return E_FAIL;
+
+ UINT32 u32;
+ CHECK( mt->GetUINT32( MF_MT_AUDIO_SAMPLES_PER_SECOND, &u32 ) );
+ if( u32 != SAMPLE_RATE )
+ return E_FAIL;
+
+ CHECK( mt->GetUINT32( MF_MT_AUDIO_NUM_CHANNELS, &u32 ) );
+ if( u32 != expectedChannels )
+ return E_FAIL;
+
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/MF/mfUtils.h b/Whisper/MF/mfUtils.h
new file mode 100644
index 0000000..c889a92
--- /dev/null
+++ b/Whisper/MF/mfUtils.h
@@ -0,0 +1,15 @@
+#pragma once
+#include <stdint.h>
+#include <mfidl.h>
+#include <mfobjects.h>
+#include <mfreadwrite.h>
+#include "../Whisper/audioConstants.h"
+
+namespace Whisper
+{
+ HRESULT createMediaType( bool stereo, IMFMediaType** pp );
+
+ HRESULT getStreamDuration( IMFSourceReader* reader, int64_t& duration );
+
+ HRESULT validateCurrentMediaType( IMFSourceReader* reader, uint32_t expectedChannels );
+} \ No newline at end of file
diff --git a/Whisper/ML/ConstantBuffer.cpp b/Whisper/ML/ConstantBuffer.cpp
new file mode 100644
index 0000000..5f3bfbe
--- /dev/null
+++ b/Whisper/ML/ConstantBuffer.cpp
@@ -0,0 +1,63 @@
+#include "stdafx.h"
+#include "ConstantBuffer.h"
+#include "../D3D/MappedResource.h"
+using namespace DirectCompute;
+
+HRESULT ConstantBuffer::create()
+{
+ if( nullptr == buffer )
+ {
+ CD3D11_BUFFER_DESC desc{ 16 * 3 * 2, D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE };
+ return device()->CreateBuffer( &desc, nullptr, &buffer );
+ }
+ return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
+}
+
+namespace
+{
+ __forceinline void copy32( __m128i* rdi, const TensorShape& ts )
+ {
+ _mm_storeu_si128( rdi, ts.sizeVec() );
+ _mm_storeu_si128( rdi + 1, ts.stridesVec() );
+ }
+}
+
+HRESULT ConstantBuffer::update( const TensorShape& t0 )
+{
+ MappedResource mapped;
+ CHECK( mapped.map( buffer, false ) );
+
+ __m128i* const rdi = ( __m128i* )mapped.data();
+ copy32( rdi, t0 );
+ return S_OK;
+}
+
+HRESULT ConstantBuffer::update( const TensorShape& t0, const TensorShape& t1 )
+{
+ MappedResource mapped;
+ CHECK( mapped.map( buffer, false ) );
+
+ __m128i* const rdi = ( __m128i* )mapped.data();
+ copy32( rdi, t0 );
+ copy32( rdi + 2, t1 );
+ return S_OK;
+}
+
+HRESULT ConstantBuffer::update( const TensorShape& t0, const TensorShape& t1, const TensorShape& t2 )
+{
+ MappedResource mapped;
+ CHECK( mapped.map( buffer, false ) );
+
+ __m128i* const rdi = ( __m128i* )mapped.data();
+ copy32( rdi, t0 );
+ copy32( rdi + 2, t1 );
+ copy32( rdi + 4, t2 );
+ return S_OK;
+}
+
+void ConstantBuffer::bind() const
+{
+ ID3D11Buffer* p = buffer;
+ assert( nullptr != p );
+ context()->CSSetConstantBuffers( 0, 1, &p );
+} \ No newline at end of file
diff --git a/Whisper/ML/ConstantBuffer.h b/Whisper/ML/ConstantBuffer.h
new file mode 100644
index 0000000..3e6664d
--- /dev/null
+++ b/Whisper/ML/ConstantBuffer.h
@@ -0,0 +1,25 @@
+#pragma once
+#include "../D3D/device.h"
+#include "TensorShape.h"
+
+namespace DirectCompute
+{
+ // 96 bytes dynamic constant buffers, with dimensions and VRAM layout of 2-3 tensors
+ class ConstantBuffer
+ {
+ CComPtr<ID3D11Buffer> buffer;
+
+ public:
+ HRESULT create();
+ HRESULT update( const TensorShape& t0 );
+ HRESULT update( const TensorShape& t0, const TensorShape& t1 );
+ HRESULT update( const TensorShape& t0, const TensorShape& t1, const TensorShape& t2 );
+
+ void bind() const;
+
+ __m128i getMemoryUse() const
+ {
+ return bufferMemoryUsage( buffer );
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/ML/Context.ops.cpp b/Whisper/ML/Context.ops.cpp
new file mode 100644
index 0000000..7dfca9f
--- /dev/null
+++ b/Whisper/ML/Context.ops.cpp
@@ -0,0 +1,280 @@
+#include "stdafx.h"
+#include "MlContext.h"
+#include "testUtils.h"
+using namespace DirectCompute;
+
+Tensor MlContext::createTensor( eDataType type, const std::array<uint32_t, 4>& ne )
+{
+ Tensor res;
+ check( res.create( type, ne ) );
+ return res;
+}
+
+Tensor MlContext::createTensor( eDataType type, std::initializer_list<uint32_t> ne )
+{
+ size_t nDims = ne.size();
+ if( 0 == nDims || nDims > 4 )
+ throw E_INVALIDARG;
+ std::array<uint32_t, 4> arr;
+ for( size_t i = 0; i < nDims; i++ )
+ arr[ i ] = ne.begin()[ i ];
+ for( size_t i = nDims; i < 4; i++ )
+ arr[ i ] = 1;
+ return createTensor( type, arr );
+}
+
+Tensor MlContext::conv_1d_1s( const Tensor& a, const Tensor& b )
+{
+ assert( b.isMatrix() );
+ assert( a.ne[ 1 ] == b.ne[ 1 ] );
+ assert( a.ne[ 3 ] == 1 );
+
+ Tensor res = createTensor( eDataType::FP32, { b.ne[ 0 ], a.ne[ 2 ] } );
+
+ convolution( a, b, res );
+ return res;
+}
+
+Tensor MlContext::conv_1d_2s( const Tensor& a, const Tensor& b )
+{
+ assert( b.isMatrix() );
+ assert( a.ne[ 1 ] == b.ne[ 1 ] );
+ assert( a.ne[ 3 ] == 1 );
+
+ Tensor res = createTensor( eDataType::FP32, { b.ne[ 0 ] / 2, a.ne[ 2 ] } );
+#if 0
+ static PrintUniqueTensorSizes printSize( "conv_1d_2s" );
+ printSize.print( a, b );
+#endif
+ convolution2( a, b, res );
+ return res;
+}
+
+namespace
+{
+ inline bool canRepeat( const TensorShape& t0, const TensorShape& t1 )
+ {
+ return ( t1.ne[ 0 ] % t0.ne[ 0 ] == 0 ) &&
+ ( t1.ne[ 1 ] % t0.ne[ 1 ] == 0 ) &&
+ ( t1.ne[ 2 ] % t0.ne[ 2 ] == 0 ) &&
+ ( t1.ne[ 3 ] % t0.ne[ 3 ] == 0 );
+ }
+}
+
+Tensor MlContext::cwiseBinary( const Tensor& a, const Tensor& b, eComputeShader cs )
+{
+ assert( isSameShape( a, b ) );
+ Tensor res = createTensor( a.getType(), a.ne );
+ cwiseBinary( a, b, res, cs );
+ return res;
+}
+
+Tensor __declspec( noinline ) MlContext::view2d( const Tensor& a, uint32_t ne0, uint32_t ne1, uint32_t nb1, uint32_t offset )
+{
+ if( 0 != offset )
+ throw E_NOTIMPL;
+
+ Tensor res = a;
+ res.ne = { ne0, ne1, 1, 1 };
+
+ res.nb[ 1 ] = nb1;
+ res.nb[ 2 ] = res.nb[ 3 ] = nb1 * ne1;
+ return res;
+}
+
+Tensor MlContext::transpose( const Tensor& a )
+{
+ Tensor result = a;
+ std::swap( result.ne[ 0 ], result.ne[ 1 ] );
+ std::swap( result.nb[ 0 ], result.nb[ 1 ] );
+ return result;
+}
+
+Tensor MlContext::norm( const Tensor& a )
+{
+ Tensor res = createTensor( a.getType(), a.ne );
+ norm( a, res );
+ return res;
+}
+
+Tensor MlContext::mulMat( const Tensor& a, const Tensor& b )
+{
+ if( !canMulMat( a, b ) )
+ throw E_INVALIDARG;
+ Tensor res = createTensor( eDataType::FP32, { a.ne[ 1 ], b.ne[ 1 ], a.ne[ 2 ], b.ne[ 3 ] } );
+ if constexpr( enableInexactOptimizations )
+ mulMatTiled( a, b, res );
+ else
+ mulMat( a, b, res );
+#if 0
+ Tensor testTiled;
+ check( testTiled.create( eDataType::FP32, res.ne ) );
+ mulMatTiled( a, b, testTiled );
+
+ std::vector<float> current, tiled;
+ res.download( current );
+ testTiled.download( tiled );
+ sTensorDiff diff = computeDiff( current.data(), tiled.data(), current.size() );
+ diff.print( "mulMatTiled" );
+#endif
+ return res;
+}
+
+Tensor MlContext::mulMatEx( const Tensor& a, const Tensor& b, const char* tagName )
+{
+ if( !canMulMat( a, b ) )
+ throw E_INVALIDARG;
+ if( 0 != a.nb[ 0 ] )
+ throw E_INVALIDARG; // The first argument is expected to be pre-transposed
+
+ const uint16_t tag = profiler.setNextTag( tagName );
+
+ if( b.ne[ 1 ] != 1 )
+ {
+ if( b.nb[ 0 ] != 0 )
+ {
+ Tensor rhs = reshapePanels( b );
+ profiler.setNextTag( tag );
+ return mulMatTiledEx( a, rhs );
+ }
+ else
+ {
+ // Second argument already reshaped into these panels
+ return mulMatTiledEx( a, b );
+ }
+ }
+ else
+ {
+ if( 0 != b.nb[ 0 ] )
+ return mulMatByRowTiledEx( a, b );
+
+ // That shader requires classic VRAM layout of the second argument, gonna fail with pre-transposed one
+ throw E_INVALIDARG;
+ }
+}
+
+Tensor MlContext::permute( const Tensor& a, uint8_t axis0, uint8_t axis1, uint8_t axis2, uint8_t axis3 )
+{
+ assert( axis0 < 4 );
+ assert( axis1 < 4 );
+ assert( axis2 < 4 );
+ assert( axis3 < 4 );
+
+ assert( axis0 != axis1 );
+ assert( axis0 != axis2 );
+ assert( axis0 != axis3 );
+ assert( axis1 != axis2 );
+ assert( axis1 != axis3 );
+ assert( axis2 != axis3 );
+
+ Tensor res = a;
+ res.ne[ axis0 ] = a.ne[ 0 ];
+ res.ne[ axis1 ] = a.ne[ 1 ];
+ res.ne[ axis2 ] = a.ne[ 2 ];
+ res.ne[ axis3 ] = a.ne[ 3 ];
+
+ res.nb[ axis0 ] = a.nb[ 0 ];
+ res.nb[ axis1 ] = a.nb[ 1 ];
+ res.nb[ axis2 ] = a.nb[ 2 ];
+ res.nb[ axis3 ] = a.nb[ 3 ];
+ return res;
+}
+
+Tensor MlContext::flashAttention( const Tensor& q, const Tensor& k, const Tensor& v, bool masked )
+{
+ if( !canMulMat( k, q ) )
+ throw E_INVALIDARG;
+
+ if constexpr( enableInexactOptimizations )
+ {
+ if( !masked )
+ {
+ profiler.setNextTag( "flashAttn.1" );
+ Tensor tmp = mulMat( k, q );
+
+ const float tempScale = (float)( 1.0 / sqrt( (double)(int)q.ne[ 0 ] ) );
+ softMax( tmp, tempScale );
+
+ profiler.setNextTag( "flashAttn.2" );
+ return mulMat( v, tmp );
+ }
+ }
+
+ Tensor res = createTensor( eDataType::FP32, q.ne );
+ flashAttention( q, k, v, res, masked );
+
+#if 0
+ Tensor tmpMat = mulMat( k, q );
+ float scale = (float)( 1.0 / sqrt( (double)(int)q.ne[ 0 ] ) );
+ softMax( tmpMat, scale );
+ Tensor testRes = mulMat( v, tmpMat );
+ computeDiff( res, testRes ).print( "flashAttention mulmat" );
+#endif
+
+ return res;
+}
+
+Tensor MlContext::copy( const Tensor& a, eDataType type, std::initializer_list<uint32_t> size )
+{
+ const size_t dims = size.size();
+ if( 0 == dims || dims > 4 )
+ throw E_BOUNDS;
+
+ size_t nRequested = 1;
+ for( size_t i = 0; i < dims; i++ )
+ {
+ uint32_t n = size.begin()[ i ];
+ nRequested *= n;
+ }
+ if( nRequested != a.countElements() )
+ throw E_INVALIDARG;
+
+ const eDataType st = a.getType();
+ Tensor res;
+ if( a.isContinuous() && st == type )
+ {
+ // Same type, and it's dense - no need to call any compute shaders, equal to reshape
+ res = a;
+ for( size_t i = 0; i < dims; i++ )
+ res.ne[ i ] = size.begin()[ i ];;
+ for( size_t i = dims; i < 4; i++ )
+ res.ne[ i ] = 1;
+ res.setDenseStrides();
+ }
+ else
+ {
+ // Either converting non-continuous to continuous, or converting types
+ res = createTensor( type, size );
+ copyImpl( a, res, st == eDataType::FP32 && type == eDataType::FP16 );
+ }
+ return res;
+}
+
+void MlContext::copyInPlace( Tensor& dest, const Tensor& a, eDataType type, std::initializer_list<uint32_t> size )
+{
+ assert( type == dest.getType() );
+
+ const size_t dims = size.size();
+ if( 0 == dims || dims > 4 )
+ throw E_BOUNDS;
+
+ size_t nRequested = 1;
+ for( size_t i = 0; i < dims; i++ )
+ {
+ uint32_t n = size.begin()[ i ];
+ nRequested *= n;
+ }
+ if( nRequested != a.countElements() || nRequested != dest.countElements() )
+ throw E_INVALIDARG;
+
+ // Reshape the destination
+ for( size_t i = 0; i < dims; i++ )
+ dest.ne[ i ] = size.begin()[ i ];
+ for( size_t i = dims; i < 4; i++ )
+ dest.ne[ i ] = 1;
+ dest.setDenseStrides();
+
+ // Call the shader
+ const eDataType st = a.getType();
+ copyImpl( a, dest, st == eDataType::FP32 && type == eDataType::FP16 );
+} \ No newline at end of file
diff --git a/Whisper/ML/LookupTables.cpp b/Whisper/ML/LookupTables.cpp
new file mode 100644
index 0000000..2fc1cc8
--- /dev/null
+++ b/Whisper/ML/LookupTables.cpp
@@ -0,0 +1,54 @@
+#include "stdafx.h"
+#include "LookupTables.h"
+#include "LookupTablesData.h"
+#include <memory>
+using namespace DirectCompute;
+
+namespace
+{
+ HRESULT uploadLookupTable( const std::array<uint16_t, 0x10000>& rsi, CComPtr<ID3D11ShaderResourceView>& rdi )
+ {
+ rdi = nullptr;
+ CComPtr<ID3D11Buffer> buffer;
+
+ CD3D11_BUFFER_DESC desc{ 0x10000 * 2, D3D11_BIND_SHADER_RESOURCE, D3D11_USAGE_IMMUTABLE };
+ D3D11_SUBRESOURCE_DATA srd{ rsi.data(), 0, 0 };
+ CHECK( device()->CreateBuffer( &desc, &srd, &buffer ) );
+
+ CD3D11_SHADER_RESOURCE_VIEW_DESC viewDesc{ D3D11_SRV_DIMENSION_BUFFER, DXGI_FORMAT_R16_UINT, 0, 0x10000 };
+ CHECK( device()->CreateShaderResourceView( buffer, &viewDesc, &rdi ) );
+
+ return S_OK;
+ }
+}
+
+HRESULT LookupTables::create()
+{
+ std::unique_ptr<LookupTablesData> data;
+ try
+ {
+ data = std::make_unique<LookupTablesData>();
+ }
+ catch( const std::bad_alloc& )
+ {
+ return E_OUTOFMEMORY;
+ }
+
+ CHECK( uploadLookupTable( data->gelu, m_gelu ) );
+ CHECK( uploadLookupTable( data->exponent, m_exponent ) );
+
+ return S_OK;
+}
+
+void LookupTables::clear()
+{
+ m_gelu = nullptr;
+ m_exponent = nullptr;
+}
+
+__m128i LookupTables::getMemoryUsage() const
+{
+ __m128i v = resourceMemoryUsage( m_gelu );
+ v = _mm_add_epi64( v, resourceMemoryUsage( m_exponent ) );
+ return v;
+} \ No newline at end of file
diff --git a/Whisper/ML/LookupTables.h b/Whisper/ML/LookupTables.h
new file mode 100644
index 0000000..306b548
--- /dev/null
+++ b/Whisper/ML/LookupTables.h
@@ -0,0 +1,22 @@
+#pragma once
+#include "../D3D/device.h"
+
+namespace DirectCompute
+{
+ class LookupTables
+ {
+ CComPtr<ID3D11ShaderResourceView> m_gelu, m_exponent;
+
+ public:
+
+ HRESULT create();
+ void clear();
+ ID3D11ShaderResourceView* gelu() const { return m_gelu; }
+ ID3D11ShaderResourceView* exponent() const { return m_exponent; }
+
+ __m128i getMemoryUsage() const;
+ };
+
+ // Singleton instance, defined in mlStartup.cpp
+ extern const LookupTables& lookupTables;
+} \ No newline at end of file
diff --git a/Whisper/ML/LookupTablesData.cpp b/Whisper/ML/LookupTablesData.cpp
new file mode 100644
index 0000000..263bcf7
--- /dev/null
+++ b/Whisper/ML/LookupTablesData.cpp
@@ -0,0 +1,40 @@
+#include "stdafx.h"
+#include "LookupTablesData.h"
+#include <immintrin.h>
+using namespace DirectCompute;
+
+namespace
+{
+ inline float fp32( uint16_t f16 )
+ {
+ __m128i i = _mm_cvtsi32_si128( f16 );
+ __m128 f = _mm_cvtph_ps( i );
+ return _mm_cvtss_f32( f );
+ }
+
+ inline uint16_t fp16( float fp32 )
+ {
+ __m128 f = _mm_set_ss( fp32 );
+ __m128i i = _mm_cvtps_ph( f, 0 );
+ uint32_t res = (uint32_t)_mm_cvtsi128_si32( i );
+ return (uint16_t)res;
+ }
+
+ constexpr double GELU_COEF_A = 0.044715;
+ constexpr double SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+
+ inline float computeGelu( float x )
+ {
+ return (float)( 0.5 * x * ( 1.0 + tanh( SQRT_2_OVER_PI * x * ( 1.0 + GELU_COEF_A * x * x ) ) ) );
+ }
+}
+
+LookupTablesData::LookupTablesData()
+{
+ for( int i = 0; i < 0x10000; i++ )
+ {
+ const float f = fp32( i );
+ gelu[ i ] = fp16( computeGelu( f ) );
+ exponent[ i ] = fp16( (float)exp( f ) );
+ }
+} \ No newline at end of file
diff --git a/Whisper/ML/LookupTablesData.h b/Whisper/ML/LookupTablesData.h
new file mode 100644
index 0000000..aa9f8ae
--- /dev/null
+++ b/Whisper/ML/LookupTablesData.h
@@ -0,0 +1,14 @@
+#pragma once
+#include <stdint.h>
+#include <array>
+
+namespace DirectCompute
+{
+ struct LookupTablesData
+ {
+ std::array<uint16_t, 0x10000> gelu;
+ std::array<uint16_t, 0x10000> exponent;
+
+ LookupTablesData();
+ };
+} \ No newline at end of file
diff --git a/Whisper/ML/MlContext.cpp b/Whisper/ML/MlContext.cpp
new file mode 100644
index 0000000..5a29b85
--- /dev/null
+++ b/Whisper/ML/MlContext.cpp
@@ -0,0 +1,744 @@
+#include "stdafx.h"
+#include "MlContext.h"
+#include "../D3D/shaderNames.h"
+#include "LookupTables.h"
+#include "../D3D/shaders.h"
+#include "../D3D/Binder.h"
+#include "../D3D/MappedResource.h"
+#include "../D3D/downloadBuffer.h"
+#include "testUtils.h"
+#include "reshapedMultiply.h"
+using namespace DirectCompute;
+
+// TODO: change this to a field, and set to false when the GPU doesn't support FP64 math
+// Most notably, Intel has dropped the support recently:
+// https://www.intel.com/content/www/us/en/developer/articles/guide/lp-api-developer-optimization-guide.html#inpage-nav-3-8-undefined
+// "To improve power and performance", LOL
+constexpr bool usePreciseComputeShaders = true;
+
+MlContext::MlContext( Whisper::ProfileCollection& profileColl ) :
+ profiler( profileColl )
+{
+ check( cb.create() );
+ check( profiler.create() );
+}
+
+void MlContext::bindShader( eComputeShader cs )
+{
+ DirectCompute::bindShader( cs );
+ profiler.computeShader( cs );
+}
+
+void MlContext::mulMatDot( const Tensor& src0, const Tensor& src1, Tensor& res )
+{
+ const auto& size1 = src1.ne;
+ if( 1 != size1[ 3 ] )
+ throw E_UNEXPECTED;
+
+ const size_t tempLength = size1[ 0 ] * size1[ 1 ] * size1[ 2 ] * size1[ 3 ];
+ const TensorGpuViews& tempBuffer = temp.fp16( tempLength );
+ cb.bind();
+
+ bindShader( eComputeShader::mulMatDotReshape );
+ cb.update( src1 );
+ Binder bind;
+ bind.bind( src1, tempBuffer );
+ context()->Dispatch( size1[ 1 ], size1[ 2 ], 1 );
+
+ bindShader( eComputeShader::mulMatDotMain );
+ cb.update( src0, src1, res );
+ bind.bind( src0, tempBuffer, res );
+
+ const auto& size0 = src0.ne;
+ // total rows in src0
+ const uint32_t nr = size0[ 1 ] * size0[ 2 ] * size0[ 3 ];
+ context()->Dispatch( size1[ 1 ], nr, 1 );
+}
+
+void MlContext::mulMatMad( const Tensor& a, const Tensor& b, Tensor& res )
+{
+ // CaptureRaii renderDoc;
+ const uint32_t resultElts = res.countElements();
+ constexpr uint32_t nth = 4;
+
+ uint32_t fp16;
+ TensorGpuViews tempBuffer;
+
+ const eDataType dataType = a.getType();
+ if( dataType == eDataType::FP16 )
+ {
+ fp16 = TRUE;
+ tempBuffer = temp.fp16( resultElts * nth );
+ }
+ else if( dataType == eDataType::FP32 )
+ {
+ fp16 = FALSE;
+ tempBuffer = temp.fp32( resultElts * nth );
+ }
+ else
+ throw E_INVALIDARG;
+
+ TensorShape resultShape = res;
+ resultShape.nb = { fp16, resultElts, 0, 0 };
+
+ cb.update( a, b, resultShape );
+ bindShader( eComputeShader::mulMatMadMain );
+ cb.bind();
+
+ Binder bind;
+ bind.bind( { a, b }, { res, tempBuffer } );
+ context()->Dispatch( b.ne[ 1 ], b.ne[ 2 ], b.ne[ 3 ] );
+}
+
+void MlContext::mulMatTiled( const Tensor& a, const Tensor& b, Tensor& res )
+{
+ cb.update( a, b, res );
+ cb.bind();
+
+ Binder bind;
+ bind.bind( a, b, res );
+
+ if( b.ne[ 1 ] == 1 )
+ {
+ if( b.ne[ 0 ] != 1 )
+ {
+#if 0
+ static PrintUniqueTensorSizes printSize( "mulMatByRow" );
+ printSize.print( a, b );
+#endif
+ // Tensor B is a single row, we have optimized compute shaders for that use case
+ // Even 2 of them, tiled and sequential. Select between these two shaders.
+ constexpr uint32_t minHeightToTile = 2;
+ if( a.ne[ 1 ] < minHeightToTile )
+ {
+ bindShader( eComputeShader::mulMatByRow );
+ context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] );
+ }
+ else
+ {
+ bindShader( eComputeShader::mulMatByRowTiled );
+ uint32_t groupsX;
+ if( gpuInfo.wave64() )
+ {
+ constexpr uint32_t TILE_Y = 128;
+ groupsX = ( a.ne[ 1 ] + TILE_Y - 1 ) / TILE_Y;
+ }
+ else
+ {
+ constexpr uint32_t TILE_Y = 64;
+ groupsX = ( a.ne[ 1 ] + TILE_Y - 1 ) / TILE_Y;
+ }
+ context()->Dispatch( groupsX, a.ne[ 2 ], a.ne[ 3 ] );
+ }
+ }
+ else
+ {
+ // Tensor B is a single element: we have an optimized shader for that as well
+ bindShader( eComputeShader::mulMatByScalar );
+ context()->Dispatch( a.ne[ 2 ], a.ne[ 3 ], 1 );
+ }
+ }
+ else
+ {
+ // According to visual studio debugger, when the second argument of this method is a 2D matrix, the first argument is 2D as well.
+ // Assuming both arguments are 2D matrices.
+ // For optimal VRAM bandwidth utilization, we compute such matrix products in square tiles, a tile is 32x32 elements.
+ // Dispatching one thread group for each tile of the output matrix.
+ bindShader( eComputeShader::mulMatTiled );
+
+ uint32_t x, y;
+ // These compute shaders correctly handle partial tiles on the right and bottom edges of the output matrix, that's why rounding up.
+ if( gpuInfo.wave64() )
+ {
+ constexpr uint32_t TILE_SIZE = 64;
+ x = ( res.ne[ 0 ] + TILE_SIZE - 1 ) / TILE_SIZE;
+ y = ( res.ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE;
+ }
+ else
+ {
+ constexpr uint32_t TILE_SIZE = 32;
+ x = ( res.ne[ 0 ] + TILE_SIZE - 1 ) / TILE_SIZE;
+ y = ( res.ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE;
+ }
+
+ const uint32_t z = res.ne[ 2 ] * res.ne[ 3 ];
+ context()->Dispatch( x, y, z );
+ }
+}
+
+void MlContext::mulMat( const Tensor& src0, const Tensor& src1, Tensor& res )
+{
+ const uint32_t nb00 = src0.nb[ 0 ];
+ const uint32_t nb01 = src0.nb[ 1 ];
+ if( nb01 >= nb00 )
+ mulMatDot( src0, src1, res );
+ else
+ mulMatMad( src0, src1, res );
+}
+
+namespace
+{
+ // Must match the HLSL in flashAttention.hlsl
+ struct sFlashAttentionConstants
+ {
+ TensorShape q, k, v, res;
+ BOOL masked;
+ float scale;
+ uint32_t tempBufferStride;
+ uint32_t zzPadding;
+ };
+
+ struct sFlashAttnDispatchInfo
+ {
+ uint32_t tempStride;
+ uint32_t groupsCount;
+ };
+
+ sFlashAttnDispatchInfo makeFlashAttentionConstants( CComPtr<ID3D11Buffer>& buffer, const Tensor& q, const Tensor& k, const Tensor& v, Tensor& res, bool masked )
+ {
+ if( nullptr == buffer )
+ {
+ CD3D11_BUFFER_DESC desc{ sizeof( sFlashAttentionConstants ), D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE };
+ check( device()->CreateBuffer( &desc, nullptr, &buffer ) );
+ }
+
+ sFlashAttnDispatchInfo result;
+
+ sFlashAttentionConstants cb;
+ cb.q = q;
+ cb.k = k;
+ cb.v = v;
+ cb.res = res;
+ cb.masked = masked ? TRUE : FALSE;
+
+ const int neq0 = (int)cb.q.ne[ 0 ];
+ const int D = neq0;
+ cb.scale = (float)( 1.0 / sqrt( (double)(int)D ) );
+
+ const uint32_t nek1 = cb.k.ne[ 1 ];
+ constexpr uint32_t align = 32 / 4;
+ result.tempStride = ( ( nek1 + align - 1 ) / align ) * align;
+ cb.tempBufferStride = result.tempStride;
+ cb.zzPadding = 0;
+ result.groupsCount = cb.q.ne[ 1 ] * cb.q.ne[ 2 ] * cb.q.ne[ 3 ];
+
+ MappedResource mapped;
+ check( mapped.map( buffer, false ) );
+ memcpy( mapped.data(), &cb, sizeof( cb ) );
+ return result;
+ }
+}
+
+void MlContext::flashAttention( const Tensor& q, const Tensor& k, const Tensor& v, Tensor& res, bool masked )
+{
+ sFlashAttnDispatchInfo di = makeFlashAttentionConstants( flashAttentionConstants, q, k, v, res, masked );
+
+ const uint32_t tempLength = di.tempStride * di.groupsCount;
+ const TensorGpuViews& tb = temp.fp32( tempLength );
+
+ csSetCB( flashAttentionConstants );
+ ID3D11DeviceContext* const ctx = context();
+
+ Binder bind;
+ bind.bind( { q, k, v, lookupTables.exponent() }, { res, tb } );
+
+ if constexpr( usePreciseComputeShaders && !enableInexactOptimizations )
+ {
+ bindShader( eComputeShader::flashAttentionCompat1 );
+ ctx->Dispatch( di.groupsCount, 1, 1 );
+
+ bindShader( eComputeShader::flashAttentionCompat2 );
+ ctx->Dispatch( ( di.groupsCount + 31 ) / 32, 1, 1 );
+
+ bindShader( eComputeShader::flashAttentionCompat3 );
+ ctx->Dispatch( di.groupsCount, 1, 1 );
+ }
+ else
+ {
+ // This version is not too bad, e.g. maxAbsDiff = 2.7895e-05, avgDiffSquared = 1.61783e-14
+ // And probably much faster.
+ // But still, it does not deliver bitwise equality with the reference CPU version
+ bindShader( eComputeShader::flashAttention );
+ ctx->Dispatch( di.groupsCount, 1, 1 );
+ }
+}
+
+namespace
+{
+ // Round up the number to be a multiple of 32
+ inline uint32_t roundUp32( uint32_t x )
+ {
+ return ( x + 31 ) & ( ~31u );
+ }
+}
+
+void MlContext::convolutionImpl( const Tensor& a, const Tensor& b, Tensor& res, bool is2 )
+{
+ const uint32_t ne00 = a.ne[ 0 ];
+ const uint32_t ne01 = a.ne[ 1 ];
+ const uint32_t ne02 = a.ne[ 2 ];
+
+ const uint32_t ne10 = b.ne[ 0 ];
+ const uint32_t ne11 = b.ne[ 1 ];
+
+ const uint32_t nb00 = a.nb[ 0 ];
+ const uint32_t nb01 = a.nb[ 1 ];
+ const uint32_t nb02 = a.nb[ 2 ];
+
+ const uint32_t nb10 = b.nb[ 0 ];
+ const uint32_t nb11 = b.nb[ 1 ];
+
+ const uint32_t nb1 = res.nb[ 1 ];
+
+ const uint32_t ew0 = roundUp32( ne01 );
+
+ const uint32_t nk = ne00;
+ const uint32_t nh = nk / 2;
+
+ const uint32_t lenTemp1 = ne02 * ew0 * ne00;
+ const uint32_t lenTemp2 = ( ne10 + ne00 ) * ew0;
+
+ const TensorGpuViews& temp1 = temp.fp16( lenTemp1, true );
+ const TensorGpuViews& temp2 = temp.fp16_2( lenTemp2, true );
+
+ cb.bind();
+
+ bindShader( eComputeShader::convolutionPrep1 );
+ cb.update( a );
+ Binder bind;
+ bind.bind( a, temp1 );
+ context()->Dispatch( ne01, ne02, 1 );
+
+ bindShader( eComputeShader::convolutionPrep2 );
+ cb.update( a, b );
+ bind.bind( b, temp2 );
+ context()->Dispatch( ne11, 1, 1 );
+
+ cb.update( a, b, res );
+ bind.bind( temp1, temp2, res );
+ if( is2 )
+ {
+ if constexpr( enableInexactOptimizations )
+ {
+ constexpr uint32_t KERNEL = 3;
+ constexpr uint32_t TILE_Y = 8;
+ if( a.ne[ 0 ] == KERNEL )
+ {
+ const uint32_t x = ( ( ne10 / 2 ) + TILE_Y - 1 ) / TILE_Y;
+ bindShader( eComputeShader::convolutionMain2Fixed );
+ context()->Dispatch( x, ne02, 1 );
+ return;
+ }
+ }
+ bindShader( eComputeShader::convolutionMain2 );
+ context()->Dispatch( ne10 / 2, ne02, 1 );
+ }
+ else
+ {
+ bindShader( eComputeShader::convolutionMain );
+ context()->Dispatch( ne10, ne02, 1 );
+ }
+#if 0
+ std::vector<uint16_t> tmp;
+ downloadBuffer( temp1, tmp );
+ dbgWriteBinaryFile( L"conv-gpu-arg1.bin", tmp.data(), lenTemp1 * 2 );
+ downloadBuffer( temp2, tmp );
+ dbgWriteBinaryFile( L"conv-gpu-arg2.bin", tmp.data(), lenTemp1 * 2 );
+ res.download( tempVector );
+ dbgWriteBinaryFile( L"conv-gpu-result.bin", tempVector.data(), tempVector.size() * 4 );
+#endif
+}
+
+void MlContext::norm( const Tensor& a, Tensor& res )
+{
+ const uint32_t ne01 = a.ne[ 1 ];
+ const uint32_t ne02 = a.ne[ 2 ];
+ const uint32_t ne03 = a.ne[ 3 ];
+
+ cb.bind();
+ cb.update( a, res );
+ Binder bind;
+ bind.bind( a, res );
+
+ if constexpr( usePreciseComputeShaders && !enableInexactOptimizations )
+ {
+ bindShader( eComputeShader::normCompat );
+ context()->Dispatch( ( ne01 + 31 ) / 32, ne02, ne03 );
+ }
+ else
+ {
+ constexpr uint32_t FIXED_ROW_SIZE = 1024;
+ eComputeShader cs = ( a.ne[ 0 ] == FIXED_ROW_SIZE ) ? eComputeShader::normFixed : eComputeShader::norm;
+ bindShader( cs );
+ context()->Dispatch( ne01, ne02, ne03 );
+ }
+}
+
+void MlContext::cwiseBinary( const Tensor& a, const Tensor& b, Tensor& res, eComputeShader cs )
+{
+ assert( isSameShape( a, b ) );
+ assert( isSameShape( a, res ) );
+
+ bindShader( cs );
+ cb.bind();
+ check( cb.update( a, b, res ) );
+ Binder bind;
+ bind.bind( a, b, res );
+
+ uint32_t rows = a.countRows();
+ context()->Dispatch( rows, 1, 1 );
+}
+
+Tensor MlContext::add( const Tensor& a, const Tensor& b )
+{
+ return cwiseBinary( a, b, eComputeShader::add );
+}
+
+void MlContext::addInPlace( Tensor& a, const Tensor& b )
+{
+ if( !isSameShape( a, b ) )
+ throw E_INVALIDARG;
+ assert( a.getType() == eDataType::FP32 );
+
+ check( cb.update( a, b ) );
+ bindShader( eComputeShader::addInPlace );
+ cb.bind();
+
+ Binder bind;
+ bind.bind( b, a );
+ context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] );
+}
+
+void MlContext::copyImpl( const Tensor& a, Tensor& res, bool downcastFp32 )
+{
+ assert( res.isContinuous() );
+ const eComputeShader cs = a.isContinuous() ? eComputeShader::copyConvert : eComputeShader::copyTranspose;
+ bindShader( cs );
+
+ cb.bind();
+ // These two shaders don't need shape of the destination because dense, but they wants a boolean flag whether to implement rounding while downcasting
+ TensorShape dummyShape;
+ dummyShape.setZero();
+ dummyShape.ne[ 0 ] = downcastFp32 ? TRUE : FALSE;
+ check( cb.update( a, dummyShape ) );
+
+ Binder bind;
+ bind.bind( a, res );
+ context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] );
+}
+
+namespace
+{
+ uint32_t bitcast( float val )
+ {
+ __m128 f = _mm_set_ss( val );
+ __m128i i = _mm_castps_si128( f );
+ return (uint32_t)_mm_cvtsi128_si32( i );
+ }
+}
+
+void MlContext::scale( Tensor& a, float mul )
+{
+ if( !a.isContinuous() )
+ throw E_INVALIDARG;
+
+ bindShader( eComputeShader::scaleInPlace );
+ cb.bind();
+ TensorShape dummyShape;
+ dummyShape.setZero();
+ dummyShape.ne[ 0 ] = bitcast( mul );
+ check( cb.update( a, dummyShape ) );
+
+ Binder bind;
+ bind.bind( a );
+ context()->Dispatch( a.countRows(), 1, 1 );
+}
+
+void MlContext::addRepeat( Tensor& a, const Tensor& b )
+{
+ check( cb.update( a, b ) );
+ bindShader( eComputeShader::addRepeat );
+ cb.bind();
+
+ Binder bind;
+ bind.bind( b, a );
+ context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] );
+}
+
+void MlContext::addRepeatScale( Tensor& a, const Tensor& b, float scale )
+{
+#if 0
+ addRepeat( a, b );
+ this->scale( a, scale );
+ return;
+#endif
+
+ TensorShape dummyShape;
+ dummyShape.setZero();
+ dummyShape.ne[ 0 ] = bitcast( scale );
+ check( cb.update( a, b, dummyShape ) );
+ bindShader( eComputeShader::addRepeatScale );
+ cb.bind();
+
+ Binder bind;
+ bind.bind( b, a );
+ context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] );
+}
+
+void MlContext::fmaRepeat( Tensor& a, const Tensor& mul, const Tensor& add )
+{
+ eComputeShader cs;
+ if( isSameShapeAndLayout( mul, add ) )
+ {
+ cs = eComputeShader::fmaRepeat1;
+ check( cb.update( a, mul ) );
+ }
+ else
+ {
+ cs = eComputeShader::fmaRepeat2;
+ check( cb.update( a, mul, add ) );
+ }
+
+ bindShader( cs );
+ cb.bind();
+ Binder bind;
+ bind.bind( mul, add, a );
+ context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] );
+}
+
+void MlContext::diagMaskInf( Tensor& a, uint32_t n_past )
+{
+ if( !a.isContinuous() )
+ throw E_INVALIDARG;
+
+ bindShader( eComputeShader::diagMaskInf );
+ TensorShape dummyShape;
+ dummyShape.setZero();
+ dummyShape.ne[ 0 ] = n_past;
+
+ cb.bind();
+ check( cb.update( a, dummyShape ) );
+
+ Binder bind;
+ bind.bind( a );
+
+ const uint32_t n = a.countRows();
+ const uint32_t nr = a.ne[ 1 ];
+ const uint32_t nz = n / nr;
+ context()->Dispatch( nr, nz, 1 );
+}
+
+void MlContext::softMax( Tensor& a, float inputScale )
+{
+ if( !a.isContinuous() )
+ throw E_INVALIDARG;
+
+ if constexpr( usePreciseComputeShaders && !enableInexactOptimizations )
+ {
+ assert( inputScale == 1.0f );
+ bindShader( eComputeShader::softMaxCompat );
+ const uint32_t nr = a.countRows();
+ TensorShape dummyShape;
+ dummyShape.setZero();
+ dummyShape.ne[ 0 ] = nr;
+
+ cb.bind();
+ check( cb.update( a, dummyShape ) );
+
+ Binder bind;
+ bind.bind( lookupTables.exponent(), a );
+ context()->Dispatch( ( nr + 31 ) / 32, 1, 1 );
+ }
+ else
+ {
+#if 0
+ static PrintUniqueTensorSizes printSizes( "softMax" );
+ printSizes.print( a );
+#endif
+ constexpr uint32_t FIXED_ROW_SIZE = 1500;
+ eComputeShader cs = ( a.ne[ 0 ] == FIXED_ROW_SIZE ) ? eComputeShader::softMaxFixed : eComputeShader::softMax;
+ bindShader( cs );
+ const uint32_t nr = a.countRows();
+ TensorShape dummyShape;
+ dummyShape.setZero();
+ dummyShape.ne[ 0 ] = nr;
+ dummyShape.ne[ 1 ] = bitcast( inputScale );
+
+ cb.bind();
+ check( cb.update( a, dummyShape ) );
+
+ Binder bind;
+ bind.bind( lookupTables.exponent(), a );
+ context()->Dispatch( nr, 1, 1 );
+ }
+}
+
+void MlContext::addRepeatGelu( Tensor& a, const Tensor& b )
+{
+ check( cb.update( a, b ) );
+ bindShader( eComputeShader::addRepeatGelu );
+ cb.bind();
+
+ Binder bind;
+ bind.bind( b, lookupTables.gelu(), a );
+ context()->Dispatch( a.ne[ 1 ], a.ne[ 2 ], a.ne[ 3 ] );
+}
+
+namespace
+{
+ inline bool canAddRows( const Tensor& tokenEmbedding, const Tensor& positionalEmbedding, const Tensor& embd, uint32_t pastTokensCount )
+ {
+ if( tokenEmbedding.ne[ 0 ] != positionalEmbedding.ne[ 0 ] )
+ return false; // Different row lengths
+ if( embd.ne[ 0 ] + pastTokensCount > positionalEmbedding.ne[ 1 ] )
+ return false; // Too many rows requested, positionalEmbedding matrix doesn't have that many
+ return true;
+ }
+}
+
+Tensor MlContext::addRows( const Tensor& tokenEmbedding, const Tensor& positionalEmbedding, const Tensor& embd, uint32_t pastTokensCount )
+{
+ if( !canAddRows( tokenEmbedding, positionalEmbedding, embd, pastTokensCount ) )
+ throw E_INVALIDARG;
+
+ const uint32_t rowLength = tokenEmbedding.ne[ 0 ];
+ const uint32_t rows = embd.ne[ 0 ];
+ Tensor result = createTensor( eDataType::FP32, { rowLength, rows } );
+
+ TensorShape constants;
+ // rowLength
+ constants.ne[ 0 ] = rowLength;
+ // pastTokensCount
+ constants.ne[ 1 ] = pastTokensCount;
+ // outputRowStride
+ constants.ne[ 2 ] = result.nb[ 1 ];
+ // embStrides
+ constants.nb[ 0 ] = tokenEmbedding.nb[ 0 ];
+ constants.nb[ 1 ] = tokenEmbedding.nb[ 1 ];
+ // posStrides
+ constants.nb[ 2 ] = positionalEmbedding.nb[ 0 ];
+ constants.nb[ 3 ] = positionalEmbedding.nb[ 1 ];
+ check( cb.update( constants ) );
+
+ bindShader( eComputeShader::addRows );
+ cb.bind();
+ Binder bind;
+ bind.bind( { tokenEmbedding, positionalEmbedding, embd }, { result } );
+ context()->Dispatch( rows, 1, 1 );
+ return result;
+}
+
+Tensor MlContext::reshapePanels( const Tensor& a )
+{
+ constexpr uint32_t TILE_SIZE = ReshapedMultiply::TILE_SIZE;
+
+ const eDataType dataType = a.getType();
+ // Reshaping into column major horizontal panels, height = TILE_SIZE, width = width of the source matrix
+
+ // Round height to multiple of tile size
+ std::array<uint32_t, 4> ne = a.ne;
+ // Dispatch a group of threads thread per panel
+ const uint32_t groupsX = ( ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE;
+ ne[ 1 ] = groupsX * TILE_SIZE;;
+ // Each panel has [ size.x, TILE_SIZE ] elements
+ const uint32_t panelSize = ne[ 0 ] * TILE_SIZE;
+
+ Tensor result = createTensor( dataType, ne );
+
+ TensorShape constants;
+ constants.setZero();
+ // uint panelSize : packoffset( c2.y );
+ constants.ne[ 1 ] = panelSize;
+ // uint2 layerStrides: packoffset( c2.z );
+ constants.ne[ 2 ] = result.nb[ 2 ];
+ constants.ne[ 3 ] = result.nb[ 3 ];
+
+ check( cb.update( a, constants ) );
+ bindShader( eComputeShader::matReshapePanels );
+ cb.bind();
+
+ Binder bind;
+ bind.bind( a, result );
+ context()->Dispatch( groupsX, a.ne[ 2 ], a.ne[ 3 ] );
+
+#if 0
+ if( dataType == eDataType::FP32 )
+ {
+ std::vector<float> v1, v2;
+ a.download( v1 );
+ result.download( v2 );
+ __debugbreak();
+ }
+ else if( dataType == eDataType::FP16 )
+ {
+ std::vector<uint16_t> v1, v2;
+ a.download( v1 );
+ result.download( v2 );
+ __debugbreak();
+ }
+#endif
+
+ // Set up size and stride expected by the mulMatTiledEx compute shader
+ result.ne = a.ne;
+ result.nb[ 0 ] = 0;
+ result.nb[ 1 ] = panelSize;
+ return result;
+}
+
+Tensor MlContext::mulMatTiledEx( const Tensor& a, const Tensor& b )
+{
+ constexpr uint32_t TILE_SIZE = ReshapedMultiply::TILE_SIZE;
+
+ if( !canMulMat( a, b ) )
+ throw E_INVALIDARG; // Wrong size
+ if( 0 != ( a.nb[ 0 ] | b.nb[ 0 ] ) )
+ throw E_INVALIDARG; // Both tensors are expected to be pre-transposed into these panels
+
+ Tensor res = createTensor( eDataType::FP32, { a.ne[ 1 ], b.ne[ 1 ], a.ne[ 2 ], b.ne[ 3 ] } );
+
+ check( cb.update( a, b, res ) );
+ bindShader( eComputeShader::mulMatTiledEx );
+ cb.bind();
+
+ Binder bind;
+ bind.bind( a, b, res );
+
+ const uint32_t x = ( res.ne[ 0 ] + TILE_SIZE - 1 ) / TILE_SIZE;
+ const uint32_t y = ( res.ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE;
+ const uint32_t z = res.ne[ 2 ] * res.ne[ 3 ];
+ context()->Dispatch( x, y, z );
+
+ return res;
+}
+
+Tensor MlContext::mulMatByRowTiledEx( const Tensor& a, const Tensor& b )
+{
+ constexpr uint32_t TILE_SIZE = ReshapedMultiply::TILE_SIZE;
+ assert( canMulMat( a, b ) );
+ assert( b.ne[ 1 ] == 1 );
+
+ Tensor res = createTensor( eDataType::FP32, { a.ne[ 1 ], 1, a.ne[ 2 ], b.ne[ 3 ] } );
+
+ check( cb.update( a, b, res ) );
+ bindShader( eComputeShader::mulMatByRowTiledEx );
+ cb.bind();
+
+ Binder bind;
+ bind.bind( a, b, res );
+
+ const uint32_t x = ( res.ne[ 0 ] + TILE_SIZE - 1 ) / TILE_SIZE;
+ const uint32_t y = res.ne[ 2 ];
+ const uint32_t z = res.ne[ 3 ];
+ context()->Dispatch( x, y, z );
+
+ return res;
+}
+
+__m128i MlContext::getMemoryUse() const
+{
+ __m128i v = cb.getMemoryUse();
+ v = _mm_add_epi64( v, temp.getMemoryUse() );
+ v = _mm_add_epi64( v, bufferMemoryUsage( flashAttentionConstants ) );
+ v = _mm_add_epi64( v, lookupTables.getMemoryUsage() );
+ return v;
+} \ No newline at end of file
diff --git a/Whisper/ML/MlContext.dbg.cpp b/Whisper/ML/MlContext.dbg.cpp
new file mode 100644
index 0000000..df095aa
--- /dev/null
+++ b/Whisper/ML/MlContext.dbg.cpp
@@ -0,0 +1,59 @@
+#include "stdafx.h"
+#include "MlContext.h"
+#include "../source/ggml.h"
+#include "testUtils.h"
+using namespace DirectCompute;
+
+#define E_TYPE HRESULT_FROM_WIN32( ERROR_DATATYPE_MISMATCH )
+
+static void dbgPrintSizeDiff( const char* what, __m128i ref, __m128i gpu )
+{
+ std::array<int, 8> a;
+ _mm_storeu_si128( ( __m128i* ) & a[ 0 ], ref );
+ _mm_storeu_si128( ( __m128i* ) & a[ 4 ], gpu );
+ printf( "%s; reference [ %i, %i, %i, %i ], GPGPU [ %i, %i, %i, %i ]\n",
+ what,
+ a[ 0 ], a[ 1 ], a[ 2 ], a[ 3 ],
+ a[ 4 ], a[ 5 ], a[ 6 ], a[ 7 ] );
+}
+
+void MlContext::dbgPrintDifference( const ggml_tensor* reference, const Tensor& gpu, const char* what, bool trapToDebugger )
+{
+ sTensorDiff diff;
+ const __m128i gpuSize = gpu.sizeVec();
+ const __m128i gpuStrides = gpu.stridesVec();
+ __m128i expectedStrides;
+ if( reference->type == GGML_TYPE_F32 )
+ {
+ if( gpu.getType() != eDataType::FP32 )
+ throw E_TYPE;
+ expectedStrides = _mm_slli_epi32( gpuStrides, 2 );
+
+ std::vector<float> v;
+ gpu.download( v );
+ diff = computeDiff( v.data(), (const float*)reference->data, v.size() );
+ }
+ else if( reference->type == GGML_TYPE_F16 )
+ {
+ if( gpu.getType() != eDataType::FP16 )
+ throw E_TYPE;
+ expectedStrides = _mm_slli_epi32( gpuStrides, 1 );
+
+ std::vector<uint16_t> v;
+ gpu.download( v );
+ diff = computeDiff( v.data(), (const uint16_t*)reference->data, v.size() );
+ }
+ else
+ throw E_NOTIMPL;
+
+ const __m128i ggmlSize = _mm_loadu_si128( ( const __m128i* ) & reference->ne[ 0 ] );
+ const __m128i ggmlStrides = _mm_loadu_si128( ( const __m128i* ) & reference->nb[ 0 ] );
+ if( !vectorEqual( gpuSize, ggmlSize ) )
+ dbgPrintSizeDiff( "dbgPrintDifference - size is different", ggmlSize, gpuSize );
+ // if( !vectorEqual( expectedStrides, ggmlStrides ) ) dbgPrintSizeDiff( "dbgPrintDifference - stride is different", ggmlStrides, expectedStrides );
+
+ diff.print( what );
+
+ if( trapToDebugger )
+ __debugbreak();
+} \ No newline at end of file
diff --git a/Whisper/ML/MlContext.h b/Whisper/ML/MlContext.h
new file mode 100644
index 0000000..d0c2d9e
--- /dev/null
+++ b/Whisper/ML/MlContext.h
@@ -0,0 +1,111 @@
+#pragma once
+#include <vector>
+#include "TempBuffers.h"
+#include "ConstantBuffer.h"
+#include "Tensor.h"
+#include "../Utils/GpuProfiler.h"
+#include "../Utils/ProfileCollection.h"
+
+namespace DirectCompute
+{
+ enum struct eComputeShader : uint16_t;
+
+ class MlContext
+ {
+ // When false, the implementation is 100% compatible with the CPU-running code written by Georgi Gerganov
+ // When true, the implementation is much faster, and doesn't require FP64 support in the compute shaders.
+ // FP64 is an optional feature, not all GPUs support that.
+ static constexpr bool enableInexactOptimizations = true;
+
+ ConstantBuffer cb;
+ TempBuffers temp;
+ CComPtr<ID3D11Buffer> flashAttentionConstants;
+
+ void convolutionImpl( const Tensor& a, const Tensor& b, Tensor& res, bool is2 );
+
+ void cwiseBinary( const Tensor& a, const Tensor& b, Tensor& res, eComputeShader cs );
+ Tensor cwiseBinary( const Tensor& a, const Tensor& b, eComputeShader cs );
+
+ void mulMatDot( const Tensor& a, const Tensor& b, Tensor& res );
+ void mulMatMad( const Tensor& a, const Tensor& b, Tensor& res );
+ void mulMatTiled( const Tensor& a, const Tensor& b, Tensor& res );
+
+ void bindShader( eComputeShader cs );
+
+ protected:
+ void copyImpl( const Tensor& a, Tensor& res, bool downcastFp32 );
+
+ // Create a dense output tensor for the results of a computation
+ // Override this method to implement a pool of these tensors
+ virtual Tensor createTensor( eDataType type, const std::array<uint32_t, 4>& ne );
+
+ Tensor createTensor( eDataType type, std::initializer_list<uint32_t> ne );
+
+ GpuProfiler profiler;
+
+ public:
+ MlContext( Whisper::ProfileCollection& profileColl );
+ MlContext( const MlContext& ) = delete;
+
+ // res = a * b
+ void mulMat( const Tensor& a, const Tensor& b, Tensor& res );
+
+ void flashAttention( const Tensor& q, const Tensor& k, const Tensor& v, Tensor& res, bool masked );
+
+ inline void convolution( const Tensor& a, const Tensor& b, Tensor& res )
+ {
+ convolutionImpl( a, b, res, false );
+ }
+ void convolution2( const Tensor& a, const Tensor& b, Tensor& res )
+ {
+ convolutionImpl( a, b, res, true );
+ }
+
+ void norm( const Tensor& a, Tensor& res );
+
+ Tensor conv_1d_1s( const Tensor& a, const Tensor& b );
+ Tensor conv_1d_2s( const Tensor& a, const Tensor& b );
+
+ Tensor add( const Tensor& a, const Tensor& b );
+ void addInPlace( Tensor& a, const Tensor& b );
+
+ Tensor view2d( const Tensor& a, uint32_t ne0, uint32_t ne1, uint32_t nb1, uint32_t offset );
+ Tensor transpose( const Tensor& a );
+
+ Tensor norm( const Tensor& a );
+ Tensor mulMat( const Tensor& a, const Tensor& b );
+ Tensor mulMatEx( const Tensor& a, const Tensor& b, const char* tagName );
+ Tensor permute( const Tensor& a, uint8_t axis0, uint8_t axis1, uint8_t axis2, uint8_t axis3 );
+ Tensor flashAttention( const Tensor& q, const Tensor& k, const Tensor& v, bool masked );
+
+ Tensor copy( const Tensor& a, eDataType type, std::initializer_list<uint32_t> size );
+ void copyInPlace( Tensor& dest, const Tensor& a, eDataType type, std::initializer_list<uint32_t> size );
+
+ void dbgPrintDifference( const ggml_tensor* reference, const Tensor& gpu, const char * what, bool trapToDebugger = true );
+
+ void scale( Tensor& a, float mul );
+
+ void addRepeat( Tensor& a, const Tensor& b );
+ void addRepeatScale( Tensor& a, const Tensor& b, float scale );
+ void fmaRepeat( Tensor& a, const Tensor& mul, const Tensor& add );
+
+ // ggml_diag_mask_inf
+ void diagMaskInf( Tensor& a, uint32_t n_past );
+ // ggml_soft_max
+ void softMax( Tensor& a, float inputScale = 1.0f );
+
+ void addRepeatGelu( Tensor& a, const Tensor& b );
+
+ // Extract rows from tokenEmbedding matrix, row indices are taken from the `embd` R32_UINT row vector
+ // Extract same count of rows from positionalEmbedding matrix, starting at the `pastTokensCount` row
+ // Return a new FP32 matrix with the sum of these rows
+ Tensor addRows( const Tensor& tokenEmbedding, const Tensor& positionalEmbedding, const Tensor& embd, uint32_t pastTokensCount );
+
+ Tensor reshapePanels( const Tensor& a );
+
+ Tensor mulMatTiledEx( const Tensor& a, const Tensor& b );
+ Tensor mulMatByRowTiledEx( const Tensor& a, const Tensor& b );
+
+ __m128i getMemoryUse() const;
+ };
+} \ No newline at end of file
diff --git a/Whisper/ML/Reshaper.cpp b/Whisper/ML/Reshaper.cpp
new file mode 100644
index 0000000..af66929
--- /dev/null
+++ b/Whisper/ML/Reshaper.cpp
@@ -0,0 +1,80 @@
+#include "stdafx.h"
+#include "Reshaper.h"
+#include "../D3D/MappedResource.h"
+#include "../D3D/Binder.h"
+#include "../D3D/shaders.h"
+#include "reshapedMultiply.h"
+
+namespace
+{
+ using namespace DirectCompute;
+ struct Constants
+ {
+ // Size and strides of the source tensor
+ TensorShape arg0;
+ uint32_t zzPadding;
+ // Count of elements per panel
+ uint32_t panelSize;
+ // Layer strides of the output matrix
+ std::array<uint32_t, 2> layerStrides;
+ };
+}
+
+HRESULT DirectCompute::Reshaper::createConstants()
+{
+ constexpr uint32_t cb = sizeof( Constants );
+ CD3D11_BUFFER_DESC desc{ cb, D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE };
+ return device()->CreateBuffer( &desc, nullptr, &constantBuffer );
+}
+
+HRESULT DirectCompute::Reshaper::makePanels( Tensor& tensor, eDataType dataType )
+{
+ if( !constantBuffer )
+ CHECK( createConstants() );
+
+ constexpr uint32_t TILE_SIZE = ReshapedMultiply::TILE_SIZE;
+
+ // Reshaping into column major horizontal panels, height = TILE_SIZE, width = width of the source matrix
+
+ std::array<uint32_t, 4> ne = tensor.ne;
+ const uint32_t groupsX = ( ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE;
+ ne[ 1 ] = groupsX * TILE_SIZE;;
+ // Each panel has [ size.x, TILE_SIZE ] elements
+ const uint32_t panelSize = ne[ 0 ] * TILE_SIZE;
+
+ Tensor result;
+ result.create( dataType, ne );
+
+ {
+ MappedResource mapped;
+ CHECK( mapped.map( constantBuffer, false ) );
+ Constants& cb = *(Constants*)mapped.data();
+
+ store( cb.arg0.ne, tensor.sizeVec() );
+ store( cb.arg0.nb, tensor.stridesVec() );
+ cb.panelSize = panelSize;
+ cb.layerStrides[ 0 ] = result.nb[ 2 ];
+ cb.layerStrides[ 1 ] = result.nb[ 3 ];
+ }
+
+ csSetCB( constantBuffer );
+ {
+ Binder bind;
+ bind.bind( tensor, result );
+ bindShader( eComputeShader::matReshapePanels );
+ context()->Dispatch( groupsX, tensor.ne[ 2 ], tensor.ne[ 3 ] );
+ }
+
+ tensor.nb[ 0 ] = 0;
+ tensor.nb[ 1 ] = panelSize;
+ tensor.nb[ 2 ] = result.nb[ 2 ];
+ tensor.nb[ 3 ] = result.nb[ 3 ];
+ tensor.setGpuViews( result );
+ return S_OK;
+}
+
+DirectCompute::Reshaper::~Reshaper()
+{
+ if( constantBuffer )
+ csSetCB( nullptr );
+} \ No newline at end of file
diff --git a/Whisper/ML/Reshaper.h b/Whisper/ML/Reshaper.h
new file mode 100644
index 0000000..3565cbb
--- /dev/null
+++ b/Whisper/ML/Reshaper.h
@@ -0,0 +1,17 @@
+#pragma once
+#include "Tensor.h"
+
+namespace DirectCompute
+{
+ // This class reshapes some of the model’s tensor, immediately after they’re loaded.
+ // That feature is used on all AMD GPUs.
+ class Reshaper
+ {
+ CComPtr<ID3D11Buffer> constantBuffer;
+ HRESULT createConstants();
+
+ public:
+ ~Reshaper();
+ HRESULT makePanels( Tensor& tensor, eDataType dataType );
+ };
+} \ No newline at end of file
diff --git a/Whisper/ML/TempBuffers.cpp b/Whisper/ML/TempBuffers.cpp
new file mode 100644
index 0000000..8823364
--- /dev/null
+++ b/Whisper/ML/TempBuffers.cpp
@@ -0,0 +1,88 @@
+#include "stdafx.h"
+#include "TempBuffers.h"
+#include "../D3D/createBuffer.h"
+#include "../D3D/MappedResource.h"
+#include "../D3D/shaders.h"
+using namespace DirectCompute;
+
+#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; }
+
+HRESULT TempBuffers::Buffer::resize( DXGI_FORMAT format, size_t elements, size_t cbElement, bool zeroMemory, CComPtr<ID3D11Buffer>& cb )
+{
+ if( elements <= capacity )
+ {
+ if( zeroMemory )
+ TempBuffers::zeroMemory( *this, (uint32_t)elements, cb );
+ return S_OK;
+ }
+ clear();
+
+ CComPtr<ID3D11Buffer> buffer;
+ const size_t totalBytes = elements * cbElement;
+ CHECK( createBuffer( eBufferUse::ReadWrite, totalBytes, &buffer, nullptr, nullptr ) );
+ CHECK( TensorGpuViews::create( buffer, format, elements, true ) );
+ capacity = elements;
+ return S_OK;
+}
+
+void TempBuffers::zeroMemory( ID3D11UnorderedAccessView* uav, uint32_t length, CComPtr<ID3D11Buffer>& cb )
+{
+ const __m128i cbData = _mm_cvtsi32_si128( (int)length );
+ if( cb )
+ {
+ MappedResource mapped;
+ check( mapped.map( cb, false ) );
+ store16( mapped.data(), cbData );
+ }
+ else
+ {
+ CD3D11_BUFFER_DESC desc{ 16, D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE };
+ std::array<uint32_t, 4> cbBuffer;
+ store( cbBuffer, cbData );
+ D3D11_SUBRESOURCE_DATA srd{ cbBuffer.data(), 0, 0 };
+ check( device()->CreateBuffer( &desc, &srd, &cb ) );
+ }
+
+ ID3D11DeviceContext* ctx = context();
+ ctx->CSSetUnorderedAccessViews( 0, 1, &uav, nullptr );
+ csSetCB( cb );
+
+ constexpr uint32_t THREADS = 512;
+ constexpr uint32_t ITERATIONS = 128;
+ constexpr uint32_t elementsPerGroup = THREADS * ITERATIONS;
+ const uint32_t countGroups = ( length + elementsPerGroup - 1 ) / elementsPerGroup;
+ bindShader( eComputeShader::zeroMemory );
+ ctx->Dispatch( countGroups, 1, 1 );
+}
+
+const TensorGpuViews& TempBuffers::fp16( size_t countElements, bool zeroMemory )
+{
+ HRESULT hr = m_fp16.resize( DXGI_FORMAT_R16_FLOAT, countElements, 2, zeroMemory, smallCb );
+ if( FAILED( hr ) )
+ throw hr;
+ return m_fp16;
+}
+
+const TensorGpuViews& TempBuffers::fp16_2( size_t countElements, bool zeroMemory )
+{
+ HRESULT hr = m_fp16_2.resize( DXGI_FORMAT_R16_FLOAT, countElements, 2, zeroMemory, smallCb );
+ if( FAILED( hr ) )
+ throw hr;
+ return m_fp16_2;
+}
+
+const TensorGpuViews& TempBuffers::fp32( size_t countElements, bool zeroMemory )
+{
+ HRESULT hr = m_fp32.resize( DXGI_FORMAT_R32_FLOAT, countElements, 4, zeroMemory, smallCb );
+ if( FAILED( hr ) )
+ throw hr;
+ return m_fp32;
+}
+
+__m128i TempBuffers::getMemoryUse() const
+{
+ size_t cb = m_fp16.getCapacity() * 2;
+ cb += m_fp16_2.getCapacity() * 2;
+ cb += m_fp32.getCapacity() * 4;
+ return setHigh_size( cb );
+} \ No newline at end of file
diff --git a/Whisper/ML/TempBuffers.h b/Whisper/ML/TempBuffers.h
new file mode 100644
index 0000000..f376ed6
--- /dev/null
+++ b/Whisper/ML/TempBuffers.h
@@ -0,0 +1,48 @@
+#pragma once
+#include "TensorGpuViews.h"
+
+namespace DirectCompute
+{
+ class TempBuffers
+ {
+ class Buffer : public TensorGpuViews
+ {
+ size_t capacity = 0;
+
+ public:
+
+ void clear()
+ {
+ TensorGpuViews::clear();
+ capacity = 0;
+ }
+
+ HRESULT resize( DXGI_FORMAT format, size_t elements, size_t cbElement, bool zeroMemory, CComPtr<ID3D11Buffer>& cb );
+
+ size_t getCapacity() const { return capacity; }
+ };
+
+ Buffer m_fp16;
+ Buffer m_fp16_2;
+ Buffer m_fp32;
+
+ public:
+
+ CComPtr<ID3D11Buffer> smallCb;
+
+ static void zeroMemory( ID3D11UnorderedAccessView* uav, uint32_t length, CComPtr<ID3D11Buffer>& cb );
+
+ const TensorGpuViews& fp16( size_t countElements, bool zeroMemory = false );
+ const TensorGpuViews& fp16_2( size_t countElements, bool zeroMemory = false );
+ const TensorGpuViews& fp32( size_t countElements, bool zeroMemory = false );
+
+ void clear()
+ {
+ m_fp16.clear();
+ m_fp16_2.clear();
+ m_fp32.clear();
+ }
+
+ __m128i getMemoryUse() const;
+ };
+} \ No newline at end of file
diff --git a/Whisper/ML/Tensor.cpp b/Whisper/ML/Tensor.cpp
new file mode 100644
index 0000000..5542f5a
--- /dev/null
+++ b/Whisper/ML/Tensor.cpp
@@ -0,0 +1,340 @@
+#include "stdafx.h"
+#include "Tensor.h"
+#include "../D3D/MappedResource.h"
+#include "../D3D/createBuffer.h"
+#include "../source/ggml.h"
+using namespace DirectCompute;
+
+Tensor::Tensor( const Tensor& that )
+{
+ ne = that.ne;
+ nb = that.nb;
+ srv = that.srv;
+ uav = that.uav;
+#ifdef _DEBUG
+ dbgType = that.dbgType;
+#endif
+}
+
+Tensor::Tensor( Tensor&& that ) noexcept
+{
+ ne = that.ne;
+ nb = that.nb;
+ srv.Attach( that.srv.Detach() );
+ uav.Attach( that.uav.Detach() );
+#ifdef _DEBUG
+ dbgType = that.dbgType;
+#endif
+}
+
+Tensor& Tensor::operator=( const Tensor& that )
+{
+ ne = that.ne;
+ nb = that.nb;
+ srv = that.srv;
+ uav = that.uav;
+#ifdef _DEBUG
+ dbgType = that.dbgType;
+#endif
+ return *this;
+}
+
+Tensor& Tensor::operator=( Tensor&& that ) noexcept
+{
+ ne = that.ne;
+ nb = that.nb;
+ srv.Attach( that.srv.Detach() );
+ uav.Attach( that.uav.Detach() );
+#ifdef _DEBUG
+ dbgType = that.dbgType;
+#endif
+ return *this;
+}
+
+Tensor::Tensor( const TensorShape& shape, CComPtr<ID3D11ShaderResourceView>& srv, CComPtr<ID3D11UnorderedAccessView>& uav ) noexcept :
+ TensorShape( shape )
+{
+ TensorGpuViews::srv.Attach( srv.Detach() );
+ TensorGpuViews::uav.Attach( uav.Detach() );
+}
+
+Tensor::Tensor( const TensorShape& shape, const TensorGpuViews& views ) :
+ TensorShape( shape )
+{
+ srv = views;
+ uav = views;
+}
+
+HRESULT Tensor::create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData )
+{
+ TensorGpuViews::clear();
+
+ switch( usage )
+ {
+ case eBufferUse::Immutable:
+ case eBufferUse::ReadWriteDownload:
+ break;
+ default:
+ return E_INVALIDARG;
+ }
+
+ CComPtr<ID3D11Buffer> buffer;
+
+ CHECK( TensorShape::create( ggml ) );
+ const ggml_type dataType = ggml.type;
+ const uint32_t cbElement = (uint32_t)ggml_type_size( dataType );
+
+ const size_t totalBytes = ggml_nbytes( &ggml );
+ if( totalBytes > INT_MAX )
+ return DISP_E_OVERFLOW;
+ const uint32_t countElements = (uint32_t)( totalBytes / cbElement );
+
+ {
+ const void* const rsi = uploadData ? ggml.data : nullptr;
+ CHECK( createBuffer( usage, totalBytes, &buffer, rsi, nullptr ) );
+ }
+
+ DXGI_FORMAT format;
+ eDataType type;
+ switch( dataType )
+ {
+ case GGML_TYPE_F16:
+ format = DXGI_FORMAT_R16_FLOAT;
+ type = eDataType::FP16;
+ break;
+ case GGML_TYPE_F32:
+ format = DXGI_FORMAT_R32_FLOAT;
+ type = eDataType::FP32;
+ break;
+ default:
+ return E_NOTIMPL;
+ }
+
+ const bool makeUav = ( usage == eBufferUse::ReadWrite );
+
+ CHECK( TensorGpuViews::create( buffer, format, totalBytes / cbElement, makeUav ) );
+#ifdef _DEBUG
+ dbgType.type = type;
+ dbgType.usage = usage;
+ dbgType.hasInitialData = uploadData;
+#endif
+ return S_OK;
+}
+
+HRESULT Tensor::createImmutable( eDataType type, const std::array<int, 4>& size, const void* rsi )
+{
+ size_t elts = (uint32_t)size[ 0 ];
+ elts *= (uint32_t)size[ 1 ];
+ elts *= (uint32_t)size[ 2 ];
+ elts *= (uint32_t)size[ 3 ];
+
+ DXGI_FORMAT format;
+ size_t cbElement;
+ switch( type )
+ {
+ case eDataType::FP16:
+ format = DXGI_FORMAT_R16_FLOAT;
+ cbElement = 2;
+ break;
+ case eDataType::FP32:
+ format = DXGI_FORMAT_R32_FLOAT;
+ cbElement = 4;
+ break;
+ default:
+ return E_NOTIMPL;
+ }
+
+ CComPtr<ID3D11Buffer> buffer;
+ CHECK( createBuffer( eBufferUse::Immutable, cbElement * elts, &buffer, rsi, nullptr ) );
+ CHECK( TensorGpuViews::create( buffer, format, elts, false ) );
+
+ __m128i v = _mm_loadu_si128( ( const __m128i* )size.data() );
+ _mm_storeu_si128( ( __m128i* )ne.data(), v );
+ setDenseStrides();
+ return S_OK;
+}
+
+HRESULT Tensor::create( eDataType type, std::initializer_list<uint32_t> sizeElements, eBufferUse usage, CComPtr<ID3D11Buffer>& buffer, const void* rsi, ID3D11Buffer** ppStagingBuffer )
+{
+ TensorGpuViews::clear();
+
+ size_t nDims = sizeElements.size();
+ if( 0 == nDims || nDims > 4 )
+ return E_INVALIDARG;
+ nDims = std::min( nDims, (size_t)4 );
+ size_t totalElements = 1;
+ for( size_t i = 0; i < nDims; i++ )
+ {
+ uint32_t n = sizeElements.begin()[ i ];
+ if( n == 0 )
+ return E_INVALIDARG;
+ ne[ i ] = n;
+ totalElements *= n;
+ }
+
+ DXGI_FORMAT format;
+ size_t cbElement;
+ switch( type )
+ {
+ case eDataType::FP32:
+ format = DXGI_FORMAT_R32_FLOAT;
+ cbElement = 4;
+ break;
+ case eDataType::FP16:
+ format = DXGI_FORMAT_R16_FLOAT;
+ cbElement = 2;
+ break;
+ case eDataType::U32:
+ format = DXGI_FORMAT_R32_UINT;
+ cbElement = 4;
+ break;
+ default:
+ return E_NOTIMPL;
+ }
+
+ const size_t totalBytes = cbElement * totalElements;
+ if( totalBytes > INT_MAX )
+ return DISP_E_OVERFLOW;
+
+ for( size_t i = nDims; i < 4; i++ )
+ ne[ i ] = 1;
+ TensorShape::setDenseStrides();
+
+ CHECK( createBuffer( usage, totalBytes, &buffer, rsi, ppStagingBuffer ) );
+
+ CHECK( TensorGpuViews::create( buffer, format, totalBytes / cbElement, true ) );
+#ifdef _DEBUG
+ dbgType.type = type;
+ dbgType.usage = usage;
+ dbgType.hasInitialData = ( nullptr != rsi );
+#endif
+ return S_OK;
+}
+
+HRESULT Tensor::create( eDataType type, std::initializer_list<uint32_t> sizeElements )
+{
+ CComPtr<ID3D11Buffer> buffer;
+ return create( type, sizeElements, eBufferUse::ReadWrite, buffer, nullptr, nullptr );
+}
+
+HRESULT Tensor::create( eDataType type, const std::array<uint32_t, 4>& sizeElements )
+{
+ std::initializer_list<uint32_t> il( sizeElements.data(), sizeElements.data() + 4 );
+ return create( type, il );
+}
+
+eDataType Tensor::getType() const
+{
+ ID3D11ShaderResourceView* const srv = *this;
+ if( nullptr == srv )
+ throw OLE_E_BLANK;
+
+ D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc;
+ srv->GetDesc( &viewDesc );
+ const DXGI_FORMAT format = viewDesc.Format;
+ switch( format )
+ {
+ case DXGI_FORMAT_R32_FLOAT:
+ return eDataType::FP32;
+ case DXGI_FORMAT_R16_FLOAT:
+ return eDataType::FP16;
+ case DXGI_FORMAT_R32_UINT:
+ return eDataType::U32;
+ }
+ throw E_NOTIMPL;
+}
+
+CComPtr<ID3D11Buffer> Tensor::getBuffer() const
+{
+ ID3D11ShaderResourceView* const srv = *this;
+ if( nullptr == srv )
+ throw OLE_E_BLANK;
+
+ CComPtr<ID3D11Resource> res;
+ srv->GetResource( &res );
+
+ CComPtr<ID3D11Buffer> buff;
+ check( res.QueryInterface( &buff ) );
+ return buff;
+}
+
+uint32_t Tensor::dxgiSizeof( DXGI_FORMAT format )
+{
+ switch( format )
+ {
+ case DXGI_FORMAT_R16_FLOAT:
+ return 2;
+ case DXGI_FORMAT_R32_FLOAT:
+ case DXGI_FORMAT_R32_UINT:
+ return 4;
+ }
+ throw E_INVALIDARG;
+}
+
+void Tensor::downloadImpl( const D3D11_SHADER_RESOURCE_VIEW_DESC& viewDesc, uint32_t countElements, size_t cbElement, void* rdi ) const
+{
+ assert( viewDesc.ViewDimension == D3D_SRV_DIMENSION_BUFFER );
+ const uint32_t idxFirst = viewDesc.Buffer.FirstElement;
+
+ CComPtr<ID3D11Buffer> buff = getBuffer();
+ D3D11_BUFFER_DESC desc;
+ buff->GetDesc( &desc );
+ desc.BindFlags = 0;
+ desc.Usage = D3D11_USAGE_STAGING;
+ desc.CPUAccessFlags = D3D11_CPU_ACCESS_READ;
+
+ CComPtr<ID3D11Buffer> staging;
+ check( device()->CreateBuffer( &desc, nullptr, &staging ) );
+ context()->CopyResource( staging, buff );
+
+ MappedResource mapped;
+ check( mapped.map( staging, true ) );
+ const uint8_t* rsi = (const uint8_t*)mapped.data();
+ rsi += cbElement * idxFirst;
+ memcpy( rdi, rsi, cbElement * countElements );
+}
+
+void Tensor::download( std::vector<float>& vec ) const
+{
+ ID3D11ShaderResourceView* const srv = *this;
+ if( nullptr == srv )
+ throw OLE_E_BLANK;
+
+ D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc;
+ srv->GetDesc( &viewDesc );
+ if( viewDesc.Format != DXGI_FORMAT_R32_FLOAT )
+ throw E_INVALIDARG;
+
+ uint32_t countElements = viewDesc.Buffer.NumElements;
+ vec.resize( countElements );
+ downloadImpl( viewDesc, countElements, 4, vec.data() );
+}
+
+void Tensor::download( std::vector<uint16_t>& vec ) const
+{
+ ID3D11ShaderResourceView* const srv = *this;
+ if( nullptr == srv )
+ throw OLE_E_BLANK;
+
+ D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc;
+ srv->GetDesc( &viewDesc );
+ if( viewDesc.Format != DXGI_FORMAT_R16_FLOAT )
+ throw E_INVALIDARG;
+
+ uint32_t countElements = viewDesc.Buffer.NumElements;
+ vec.resize( countElements );
+ downloadImpl( viewDesc, countElements, 2, vec.data() );
+}
+
+Tensor Tensor::reshape3d( uint32_t ne0, uint32_t ne1, uint32_t ne2 ) const
+{
+ if( !isContinuous() )
+ throw E_NOTIMPL;
+ if( countElements() != ne0 * ne1 * ne2 )
+ throw E_INVALIDARG;
+
+ Tensor res = *this;
+ res.ne = { ne0, ne1, ne2, 1 };
+ res.setDenseStrides();
+ return res;
+} \ No newline at end of file
diff --git a/Whisper/ML/Tensor.h b/Whisper/ML/Tensor.h
new file mode 100644
index 0000000..cc61b7f
--- /dev/null
+++ b/Whisper/ML/Tensor.h
@@ -0,0 +1,78 @@
+#pragma once
+#include "TensorShape.h"
+#include "TensorGpuViews.h"
+#include "../D3D/enums.h"
+
+namespace DirectCompute
+{
+ // A minimal tensor object sufficient to compute things on GPU, with compute shaders
+ // This class only takes 48 bytes in system memory, and is very cheap to make copies 'coz GPU objects are reference counted.
+ class Tensor : public TensorShape, public TensorGpuViews
+ {
+ CComPtr<ID3D11Buffer> getBuffer() const;
+
+ struct TensorType
+ {
+ eDataType type;
+ eBufferUse usage;
+ bool hasInitialData;
+ };
+#ifdef _DEBUG
+ // In debug builds, we include a few pieces of data to this class.
+ TensorType dbgType;
+#endif
+ protected:
+ HRESULT create( eDataType type, std::initializer_list<uint32_t> sizeElements, eBufferUse usage, CComPtr<ID3D11Buffer>& buffer, const void* rsi, ID3D11Buffer** ppStagingBuffer );
+
+ static uint32_t dxgiSizeof( DXGI_FORMAT format );
+
+ void downloadImpl( const D3D11_SHADER_RESOURCE_VIEW_DESC& viewDesc, uint32_t countElements, size_t cbElement, void* rdi ) const;
+
+ public:
+ Tensor() = default;
+
+ // These copy operators don't copy any data, they merely increment ref.counter of the GPU resources
+ Tensor( const Tensor& );
+ Tensor( Tensor&& that ) noexcept;
+ Tensor& operator=( const Tensor& that );
+ Tensor& operator=( Tensor&& that ) noexcept;
+
+ // Move the provided buffer views into this newly created tensor, and assign the shape
+ // This destroys old values in the smart pointers
+ Tensor( const TensorShape& shape, CComPtr<ID3D11ShaderResourceView>& srv, CComPtr<ID3D11UnorderedAccessView>& uav ) noexcept;
+
+ Tensor( const TensorShape& shape, const TensorGpuViews& views );
+
+ // Create a tensor from the GGML's one
+ HRESULT create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData );
+
+ // Create a new dense tensor of the specified size in elements, without initial data
+ HRESULT create( eDataType type, std::initializer_list<uint32_t> sizeElements );
+ HRESULT create( eDataType type, const std::array<uint32_t, 4>& sizeElements );
+ HRESULT createImmutable( eDataType type, const std::array<int, 4>& size, const void* rsi );
+
+ eDataType getType() const;
+
+ // This method should probably only be used to test things
+ // TensorEx is better for production usage, because it creates staging buffer in advance.
+ void download( std::vector<float>& vec ) const;
+ void download( std::vector<uint16_t>& vec ) const;
+
+ // ggml_reshape_3d
+ Tensor reshape3d( uint32_t ne0, uint32_t ne1, uint32_t ne2 ) const;
+
+ inline void dbgSetType( eDataType dt, bool hasData = false, eBufferUse use = eBufferUse::ReadWrite )
+ {
+#ifdef _DEBUG
+ dbgType.type = dt;
+ dbgType.hasInitialData = hasData;
+ dbgType.usage = use;
+#endif
+ }
+
+ __m128i getMemoryUse() const
+ {
+ return resourceMemoryUsage( srv );
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/ML/TensorEx.cpp b/Whisper/ML/TensorEx.cpp
new file mode 100644
index 0000000..97e4e30
--- /dev/null
+++ b/Whisper/ML/TensorEx.cpp
@@ -0,0 +1,97 @@
+#include "stdafx.h"
+#include "TensorEx.h"
+#include "../D3D/createBuffer.h"
+#include "../source/ggml.h"
+#include "../D3D/MappedResource.h"
+using namespace DirectCompute;
+
+HRESULT TensorEx::create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData )
+{
+ TensorGpuViews::clear();
+ buffer = nullptr;
+ stagingBuffer = nullptr;
+
+ CHECK( TensorShape::create( ggml ) );
+ const ggml_type dataType = ggml.type;
+ const uint32_t cbElement = (uint32_t)ggml_type_size( dataType );
+
+ const size_t totalBytes = ggml_nbytes( &ggml );
+ if( totalBytes > INT_MAX )
+ return DISP_E_OVERFLOW;
+ const uint32_t countElements = (uint32_t)( totalBytes / cbElement );
+
+ {
+ const void* const rsi = uploadData ? ggml.data : nullptr;
+ ID3D11Buffer** ppStagingBuffer = ( usage == eBufferUse::ReadWriteDownload ) ? &stagingBuffer : nullptr;
+ CHECK( createBuffer( usage, totalBytes, &buffer, rsi, ppStagingBuffer ) );
+ }
+
+ DXGI_FORMAT format;
+ switch( dataType )
+ {
+ case GGML_TYPE_F16:
+ format = DXGI_FORMAT_R16_FLOAT;
+ break;
+ case GGML_TYPE_F32:
+ format = DXGI_FORMAT_R32_FLOAT;
+ break;
+ default:
+ return E_NOTIMPL;
+ }
+
+ const bool makeUav = usage == eBufferUse::ReadWrite || usage == eBufferUse::ReadWriteDownload;
+ return TensorGpuViews::create( buffer, format, totalBytes / cbElement, makeUav );
+}
+
+HRESULT TensorEx::create( eDataType type, eBufferUse usage, const std::array<uint32_t, 4>& sizeElements )
+{
+ TensorGpuViews::clear();
+ buffer = nullptr;
+ stagingBuffer = nullptr;
+ std::initializer_list<uint32_t> il( sizeElements.data(), sizeElements.data() + 4 );
+
+ ID3D11Buffer** ppStaging = ( usage == eBufferUse::ReadWriteDownload ) ? &stagingBuffer : nullptr;
+ return Tensor::create( type, il, usage, buffer, nullptr, ppStaging );
+}
+
+HRESULT TensorEx::getViewSize( uint32_t& cbElement, uint32_t& countElements ) const
+{
+ ID3D11ShaderResourceView* const srv = *this;
+ if( nullptr == srv )
+ return OLE_E_BLANK;
+
+ D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc;
+ srv->GetDesc( &viewDesc );
+
+ cbElement = dxgiSizeof( viewDesc.Format );
+
+ assert( viewDesc.ViewDimension == D3D_SRV_DIMENSION_BUFFER );
+ assert( viewDesc.Buffer.FirstElement == 0 );
+ countElements = viewDesc.Buffer.NumElements;
+
+ return S_OK;
+}
+
+HRESULT TensorEx::download( void* rdi, size_t cb ) const
+{
+ if( nullptr == stagingBuffer )
+ return HRESULT_FROM_WIN32( ERROR_GPIO_OPERATION_DENIED ); // The requested operation is not supported for the specified handle.
+
+ ID3D11DeviceContext* const ctx = context();
+ ctx->CopyResource( stagingBuffer, buffer );
+
+ MappedResource mapped;
+ CHECK( mapped.map( stagingBuffer, true ) );
+ memcpy( rdi, mapped.data(), cb );
+
+ return S_OK;
+}
+
+HRESULT TensorEx::download( void* rdi ) const
+{
+ uint32_t cbElement, numElements;
+ CHECK( getViewSize( cbElement, numElements ) );
+
+ size_t cb = (size_t)cbElement * numElements;
+ return download( rdi, cb );
+} \ No newline at end of file
diff --git a/Whisper/ML/TensorEx.h b/Whisper/ML/TensorEx.h
new file mode 100644
index 0000000..c82f3d8
--- /dev/null
+++ b/Whisper/ML/TensorEx.h
@@ -0,0 +1,42 @@
+#pragma once
+#include "Tensor.h"
+
+namespace DirectCompute
+{
+ // A tensor which supports dynamic updates from CPU, or downloads from VRAM to system RAM
+ class TensorEx : public Tensor
+ {
+ protected:
+ CComPtr<ID3D11Buffer> buffer;
+ CComPtr<ID3D11Buffer> stagingBuffer;
+
+ HRESULT getViewSize( uint32_t& cbElement, uint32_t& countElements ) const;
+
+ public:
+
+ HRESULT create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData );
+ HRESULT create( eDataType type, eBufferUse usage, const std::array<uint32_t, 4>& sizeElements );
+
+ HRESULT download( void* rdi, size_t cb ) const;
+
+ HRESULT download( void* rdi ) const;
+
+ template<class E>
+ HRESULT download( std::vector<E>& vec ) const
+ {
+ uint32_t cbElement, numElements;
+ CHECK( getViewSize( cbElement, numElements ) );
+
+ try
+ {
+ vec.resize( numElements );
+ }
+ catch( const std::bad_alloc& )
+ {
+ return E_OUTOFMEMORY;
+ }
+
+ return download( vec.data(), (size_t)cbElement * numElements );
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/ML/TensorGpuViews.cpp b/Whisper/ML/TensorGpuViews.cpp
new file mode 100644
index 0000000..2788153
--- /dev/null
+++ b/Whisper/ML/TensorGpuViews.cpp
@@ -0,0 +1,23 @@
+#include "stdafx.h"
+#include "TensorGpuViews.h"
+using namespace DirectCompute;
+
+HRESULT TensorGpuViews::create( ID3D11Buffer* gpuBuffer, DXGI_FORMAT format, size_t countElements, bool makeUav )
+{
+ srv = nullptr;
+ uav = nullptr;
+
+ if( countElements > UINT_MAX )
+ return DISP_E_OVERFLOW;
+
+ CD3D11_SHADER_RESOURCE_VIEW_DESC viewDesc{ D3D11_SRV_DIMENSION_BUFFER, format, 0, (UINT)countElements };
+ CHECK( device()->CreateShaderResourceView( gpuBuffer, &viewDesc, &srv ) );
+
+ if( makeUav )
+ {
+ CD3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc{ D3D11_UAV_DIMENSION_BUFFER, format , 0, (UINT)countElements };
+ CHECK( device()->CreateUnorderedAccessView( gpuBuffer, &uavDesc, &uav ) );
+ }
+
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/ML/TensorGpuViews.h b/Whisper/ML/TensorGpuViews.h
new file mode 100644
index 0000000..ef26473
--- /dev/null
+++ b/Whisper/ML/TensorGpuViews.h
@@ -0,0 +1,32 @@
+#pragma once
+#include <stdint.h>
+#include "../D3D/device.h"
+
+namespace DirectCompute
+{
+ class TensorGpuViews
+ {
+ protected:
+ CComPtr<ID3D11ShaderResourceView> srv;
+ CComPtr<ID3D11UnorderedAccessView> uav;
+
+ public:
+
+ operator ID3D11ShaderResourceView* ( ) const { return srv; }
+ operator ID3D11UnorderedAccessView* ( ) const { return uav; }
+
+ HRESULT create( ID3D11Buffer* buffer, DXGI_FORMAT format, size_t countElements, bool makeUav );
+
+ void clear()
+ {
+ srv = nullptr;
+ uav = nullptr;
+ }
+
+ void setGpuViews( ID3D11ShaderResourceView* read, ID3D11UnorderedAccessView* write = nullptr )
+ {
+ srv = read;
+ uav = write;
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/ML/TensorShape.cpp b/Whisper/ML/TensorShape.cpp
new file mode 100644
index 0000000..7de6fb8
--- /dev/null
+++ b/Whisper/ML/TensorShape.cpp
@@ -0,0 +1,72 @@
+#include "stdafx.h"
+#include "TensorShape.h"
+#include "../source/ggml.h"
+using namespace DirectCompute;
+
+TensorShape::TensorShape()
+{
+ setZero();
+}
+
+TensorShape::TensorShape( const TensorShape& that )
+{
+ _mm_storeu_si128( ( __m128i* )ne.data(), that.sizeVec() );
+ _mm_storeu_si128( ( __m128i* )nb.data(), that.stridesVec() );
+}
+
+void TensorShape::operator=( const TensorShape& that )
+{
+ _mm_storeu_si128( ( __m128i* )ne.data(), that.sizeVec() );
+ _mm_storeu_si128( ( __m128i* )nb.data(), that.stridesVec() );
+}
+
+HRESULT TensorShape::create( const ggml_tensor& ggml )
+{
+ for( size_t i = 0; i < 4; i++ )
+ ne[ i ] = (uint32_t)ggml.ne[ i ];
+
+ const ggml_type dataType = ggml.type;
+ // Verify a few things
+ uint32_t cbElement = (uint32_t)ggml_type_size( dataType );
+ for( size_t i = 0; i < 4; i++ )
+ {
+ size_t stride = ggml.nb[ i ];
+ if( 0 != stride % cbElement )
+ return E_INVALIDARG;
+ size_t nn = stride / cbElement;
+ if( nn > UINT_MAX )
+ return DISP_E_OVERFLOW;
+ nb[ i ] = (uint32_t)nn;
+ }
+ return S_OK;
+}
+
+TensorShape::TensorShape( const ggml_tensor& ggml )
+{
+ HRESULT hr = create( ggml );
+ if( FAILED( hr ) )
+ throw hr;
+}
+
+void TensorShape::setDenseStrides()
+{
+ nb[ 0 ] = 1;
+ nb[ 1 ] = ne[ 0 ];
+ const uint32_t p01 = ne[ 0 ] * ne[ 1 ];
+ nb[ 2 ] = p01;
+ nb[ 3 ] = p01 * ne[ 2 ];
+}
+
+bool DirectCompute::canMulMat( const TensorShape& t0, const TensorShape& t1 )
+{
+ /*
+ return
+ ( t0.ne[ 0 ] == t1.ne[ 0 ] ) &&
+ ( t0.ne[ 2 ] == t1.ne[ 2 ] ) &&
+ ( t0.ne[ 3 ] == t1.ne[ 3 ] ); */
+ __m128i a = t0.sizeVec();
+ __m128i b = t1.sizeVec();
+ __m128i xx = _mm_xor_si128( a, b );
+ xx = _mm_shuffle_epi32( xx, _MM_SHUFFLE( 3, 2, 0, 0 ) );
+ return (bool)_mm_testz_si128( xx, xx );
+} \ No newline at end of file
diff --git a/Whisper/ML/TensorShape.h b/Whisper/ML/TensorShape.h
new file mode 100644
index 0000000..473b0c9
--- /dev/null
+++ b/Whisper/ML/TensorShape.h
@@ -0,0 +1,120 @@
+#pragma once
+#include <stdint.h>
+#include <array>
+#include <smmintrin.h>
+
+struct ggml_tensor;
+using HRESULT = long;
+
+namespace DirectCompute
+{
+ // This POD structure describes the shape of a tensor.
+ // It’s used for both GPU tensors in VRAM, and tensors in system memory used by the Hybrid model.
+ struct TensorShape
+ {
+ // Count of elements, up to 4 coordinates
+ // The unused coordinates are set to 1
+ std::array<uint32_t, 4> ne;
+
+ // Strides of the tensor
+ // For a dense row-major tensor, these numbers are [ 1, ne[0], ne[0]*ne[1], ne[0]*ne[1]*ne[2] ]
+ // Note that unlike GGML code, these numbers are expressed in elements not bytes, but the meaning is the same
+ // GPU matrices reshaped into panels are keeping different values here: [ 0, panelSize, panelSize * panelsCount, panelSize * panelsCount * ne[ 2 ] ]
+ std::array<uint32_t, 4> nb;
+
+ TensorShape();
+ TensorShape( const TensorShape& that );
+ void operator=( const TensorShape& that );
+ HRESULT create( const ggml_tensor& ggml );
+ TensorShape( const ggml_tensor& ggml );
+
+ __m128i sizeVec() const
+ {
+ return load( ne );
+ }
+ __m128i stridesVec() const
+ {
+ return load( nb );
+ }
+
+ uint32_t countRows() const
+ {
+ return ne[ 1 ] * ne[ 2 ] * ne[ 3 ];
+ }
+
+ uint32_t countElements() const
+ {
+ // return ne[ 0 ] * countRows();
+ const __m128i a = sizeVec();
+ const __m128i b = _mm_srli_si128( a, 4 );
+ const __m128i p2 = _mm_mul_epu32( a, b );
+ uint64_t res = (uint64_t)_mm_extract_epi64( p2, 1 );
+ res *= (uint64_t)_mm_cvtsi128_si64( p2 );
+ assert( 0 == ( res >> 32 ) );
+ return (uint32_t)res;
+ }
+
+ // Compute strides from sizes, assuming dense row-major memory layout of the tensor
+ void setDenseStrides();
+
+ bool isMatrix() const
+ {
+ // return ne[ 2 ] == 1 && ne[ 3 ] == 1;
+ const uint64_t num = *(const uint64_t*)&ne[ 2 ];
+ return num == 0x100000001ull;
+ }
+ bool isVector() const
+ {
+ return 1 == ne[ 1 ] && isMatrix();
+ }
+
+ // True of this tensor is dense and row-major
+ bool isContinuous() const
+ {
+ /* return 1 == nb[ 0 ] &&
+ nb[ 1 ] == nb[ 0 ] * ne[ 0 ] &&
+ nb[ 2 ] == nb[ 1 ] * ne[ 1 ] &&
+ nb[ 3 ] == nb[ 2 ] * ne[ 2 ]; */
+
+ const __m128i nbv = stridesVec();
+ const __m128i nev = sizeVec();
+ __m128i tmp = _mm_mullo_epi32( nbv, nev ); // Vertical product of int32 lanes
+ tmp = _mm_shuffle_epi32( tmp, _MM_SHUFFLE( 2, 1, 0, 0 ) ); // Shift left by 1 int32 lane
+ tmp = _mm_insert_epi32( tmp, 1, 0 ); // Reset X lane to 1
+ return vectorEqual( tmp, nbv );
+ }
+
+ // Reset all fields to zero
+ void setZero()
+ {
+ const __m128i z = _mm_setzero_si128();
+ _mm_storeu_si128( ( __m128i* )ne.data(), z );
+ _mm_storeu_si128( ( __m128i* )nb.data(), z );
+ }
+ };
+
+ // True when two tensors have equal count of elements
+ inline bool isSameShape( const TensorShape& t0, const TensorShape& t1 )
+ {
+ __m128i a = t0.sizeVec();
+ __m128i b = t1.sizeVec();
+ return vectorEqual( a, b );
+ }
+
+ // True when two tensors have equal count of elements, and equal VRAM layout too
+ inline bool isSameShapeAndLayout( const TensorShape& t0, const TensorShape& t1 )
+ {
+ __m128i a, b, x;
+ a = t0.sizeVec();
+ b = t1.sizeVec();
+ x = _mm_xor_si128( a, b );
+
+ a = t0.stridesVec();
+ b = t1.stridesVec();
+ x = _mm_or_si128( x, _mm_xor_si128( a, b ) );
+ return (bool)_mm_testz_si128( x, x );
+ }
+
+ // True when we can multiply two tensors of the provided shapes
+ bool canMulMat( const TensorShape& t0, const TensorShape& t1 );
+} \ No newline at end of file
diff --git a/Whisper/ML/TensorsArena.cpp b/Whisper/ML/TensorsArena.cpp
new file mode 100644
index 0000000..f1a4bed
--- /dev/null
+++ b/Whisper/ML/TensorsArena.cpp
@@ -0,0 +1,117 @@
+#include "stdafx.h"
+#include "TensorsArena.h"
+#include "../D3D/createBuffer.h"
+#include <bit>
+
+uint32_t DirectCompute::defaultNewCapacity( uint32_t current, uint32_t requested )
+{
+ if( 0 == current )
+ {
+ // When the current capacity is 0 this means it's the first resize for the pooled tensor
+ // Create tensor of the exact requested size, as most tensors on these pools are never actually resized.
+ return requested;
+ }
+ else
+ {
+ // Implement some reasonable tactics to grow an old tensor
+ const uint32_t res = std::max( 1024u, std::bit_ceil( requested ) );
+ assert( res >= requested );
+ return res;
+ }
+}
+
+using namespace DirectCompute;
+
+TensorsArena::ArenaImpl::ArenaImpl( eDataType dataType, const sArenaConfig& config ) :
+ type( dataType ),
+ pfnNewCap( nullptr != config.pfnCapInner ? config.pfnCapInner : &defaultNewCapacity )
+{
+ pool.reserve( config.initialCapOuter );
+}
+
+Tensor PooledTensor::tensor( eDataType type, const std::array<uint32_t, 4>& ne, pfnNewCapacity pfnNewCap )
+{
+ const uint32_t p1 = ne[ 0 ] * ne[ 1 ];
+ const uint32_t p2 = ne[ 2 ] * ne[ 3 ];
+ const uint32_t count = p1 * p2;
+
+ if( count > capacity )
+ {
+ views.clear();
+ const uint32_t newCap = pfnNewCap( capacity, count );
+ assert( newCap >= count );
+
+ const size_t cb = elementSize( type ) * newCap;
+ CComPtr<ID3D11Buffer> buffer;
+ check( createBuffer( eBufferUse::ReadWrite, cb, &buffer, nullptr, nullptr ) );
+ check( views.create( buffer, viewFormat( type ), newCap, true ) );
+ capacity = newCap;
+ }
+
+ TensorShape shape;
+ shape.ne = ne;
+ shape.setDenseStrides();
+ Tensor res{ shape, views };
+ res.dbgSetType( type );
+ return res;
+}
+
+Tensor TensorsArena::ArenaImpl::tensor( const std::array<uint32_t, 4>& ne )
+{
+ PooledTensor* res;
+ if( index >= pool.size() )
+ {
+ assert( index == pool.size() );
+ res = &pool.emplace_back();
+ }
+ else
+ res = &pool[ index ];
+
+ index++;
+ return res->tensor( type, ne, pfnNewCap );
+}
+
+TensorsArena::TensorsArena( const sArenaConfigs& configs ) :
+ arenas{ ArenaImpl{ eDataType::FP16, configs.fp16 }, ArenaImpl{ eDataType::FP32, configs.fp32 } }
+{
+ static_assert( 0 == (uint8_t)eDataType::FP16 );
+ static_assert( 1 == (uint8_t)eDataType::FP32 );
+}
+
+Tensor TensorsArena::tensor( eDataType type, const std::array<uint32_t, 4>& ne )
+{
+ ArenaImpl& arena = arenas[ (uint8_t)type ];
+ return arena.tensor( ne );
+}
+
+void TensorsArena::reset()
+{
+ for( ArenaImpl& a : arenas )
+ a.reset();
+}
+
+void TensorsArena::clear()
+{
+ for( ArenaImpl& a : arenas )
+ a.clear();
+}
+
+__m128i TensorsArena::ArenaImpl::getMemoryUse() const
+{
+ const size_t cbElement = elementSize( type );
+ size_t countElts = 0;
+ for( const auto& t : pool )
+ countElts += t.getCapacity();
+
+ const size_t cbVideo = cbElement * countElts;
+ const size_t cbSystem = vectorMemoryUse( pool );
+ return setr_size( cbSystem, cbVideo );
+}
+
+__m128i TensorsArena::getMemoryUse() const
+{
+ __m128i res = _mm_setzero_si128();
+ for( const auto& a : arenas )
+ res = _mm_add_epi64( res, a.getMemoryUse() );
+ return res;
+} \ No newline at end of file
diff --git a/Whisper/ML/TensorsArena.h b/Whisper/ML/TensorsArena.h
new file mode 100644
index 0000000..acfaf86
--- /dev/null
+++ b/Whisper/ML/TensorsArena.h
@@ -0,0 +1,79 @@
+#pragma once
+#include "Tensor.h"
+
+namespace DirectCompute
+{
+ using pfnNewCapacity = uint32_t( * )( uint32_t current, uint32_t requested );
+
+ uint32_t defaultNewCapacity( uint32_t current, uint32_t requested );
+
+ class PooledTensor
+ {
+ TensorGpuViews views;
+ uint32_t capacity = 0;
+ public:
+ Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne, pfnNewCapacity pfnNewCap );
+ size_t getCapacity() const { return capacity; }
+ };
+
+ __interface iTensorArena
+ {
+ Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne );
+ void reset();
+ };
+
+ class TensorsArena: public iTensorArena
+ {
+ public:
+ struct sArenaConfig
+ {
+ pfnNewCapacity pfnCapInner;
+ size_t initialCapOuter;
+ };
+
+ struct sArenaConfigs
+ {
+ sArenaConfig fp16, fp32;
+ };
+
+ TensorsArena( const sArenaConfigs& configs );
+
+ Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne ) override final;
+ void reset() override final;
+
+ void clear();
+ __m128i getMemoryUse() const;
+
+ private:
+
+ struct ArenaImpl
+ {
+ ArenaImpl( eDataType dataType, const sArenaConfig& config );
+
+ void reset()
+ {
+ index = 0;
+ }
+
+ void clear()
+ {
+ index = 0;
+ pool.clear();
+ }
+
+ Tensor tensor( const std::array<uint32_t, 4>& ne );
+ __m128i getMemoryUse() const;
+
+ private:
+
+ const eDataType type;
+ const pfnNewCapacity pfnNewCap;
+
+ std::vector<PooledTensor> pool;
+ size_t index = 0;
+ };
+
+ static constexpr size_t countTypes = 2;
+ std::array<ArenaImpl, countTypes> arenas;
+ };
+} \ No newline at end of file
diff --git a/Whisper/ML/mlStartup.cpp b/Whisper/ML/mlStartup.cpp
new file mode 100644
index 0000000..815ce52
--- /dev/null
+++ b/Whisper/ML/mlStartup.cpp
@@ -0,0 +1,27 @@
+#include "stdafx.h"
+#include "mlStartup.h"
+#include "../D3D/startup.h"
+#include "LookupTables.h"
+
+namespace
+{
+ static DirectCompute::LookupTables s_tables;
+}
+
+namespace DirectCompute
+{
+ const LookupTables& lookupTables = s_tables;
+
+ HRESULT mlStartup()
+ {
+ CHECK( d3dStartup() );
+ CHECK( s_tables.create() );
+ return S_OK;
+ }
+
+ void mlShutdown()
+ {
+ s_tables.clear();
+ d3dShutdown();
+ }
+} \ No newline at end of file
diff --git a/Whisper/ML/mlStartup.h b/Whisper/ML/mlStartup.h
new file mode 100644
index 0000000..ef5020c
--- /dev/null
+++ b/Whisper/ML/mlStartup.h
@@ -0,0 +1,8 @@
+#pragma once
+using HRESULT = long;
+
+namespace DirectCompute
+{
+ HRESULT mlStartup();
+ void mlShutdown();
+} \ No newline at end of file
diff --git a/Whisper/ML/reshapedMultiply.h b/Whisper/ML/reshapedMultiply.h
new file mode 100644
index 0000000..7cf2365
--- /dev/null
+++ b/Whisper/ML/reshapedMultiply.h
@@ -0,0 +1,10 @@
+#pragma once
+#include <stdint.h>
+
+namespace DirectCompute
+{
+ namespace ReshapedMultiply
+ {
+ constexpr uint32_t TILE_SIZE = 32;
+ }
+} \ No newline at end of file
diff --git a/Whisper/ML/tensorOpsTests.cpp b/Whisper/ML/tensorOpsTests.cpp
new file mode 100644
index 0000000..adc020e
--- /dev/null
+++ b/Whisper/ML/tensorOpsTests.cpp
@@ -0,0 +1,183 @@
+#include "stdafx.h"
+#include "tensorOpsTests.h"
+#include "MlContext.h"
+#include "TensorEx.h"
+#include "../D3D/shaders.h"
+#include "../D3D/Binder.h"
+#include "testUtils.h"
+#include "../Whisper/WhisperContext.h"
+
+void DirectCompute::testMulMat( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst, const void* tempBuffer )
+{
+ return;
+ CaptureRaii capture;
+ const size_t nb00 = src0->nb[ 0 ];
+ const size_t nb01 = src0->nb[ 1 ];
+
+ if( src0->type != GGML_TYPE_F16 )
+ return; // TODO
+
+ if( nb01 < nb00 )
+ return; // TODO
+
+ WhisperContext& ctx = WhisperContext::current();
+
+ Tensor arg0, arg1;
+ check( arg0.create( *src0, eBufferUse::Immutable, true ) );
+ check( arg1.create( *src1, eBufferUse::Immutable, true ) );
+ TensorEx res;
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ ctx.mulMat( arg0, arg1, res );
+
+ std::vector<float> tv;
+ check( res.download( tv ) );
+
+ const size_t len = tv.size();
+ computeDiff( tv.data(), (const float*)dst->data, len ).print( "testMulMat-product" );
+
+#if 0
+ dbgWriteBinaryFile( L"product-orig.bin", dst->data, len * 4 );
+ dbgWriteBinaryFile( L"product-gpu.bin", tv.data(), len * 4 );
+ __debugbreak();
+#endif
+}
+
+#if 0
+void DirectCompute::testMulMatReshape( const ggml_tensor* src1, const void* tempBuffer )
+{
+ Tensor src;
+ check( src.create( *src1, eBufferUse::Immutable, true ) );
+
+ const size_t ne10 = (uint32_t)src1->ne[ 0 ];
+ const size_t ne11 = (uint32_t)src1->ne[ 1 ];
+ const size_t ne12 = (uint32_t)src1->ne[ 2 ];
+ const size_t ne13 = (uint32_t)src1->ne[ 3 ];
+ if( 1 != ne13 )
+ throw E_UNEXPECTED;
+ const size_t tempLength = ne10 * ne11 * ne12 * ne13;
+
+ Context& ctx = Context::current();
+ const ReadWriteViews& temp = ctx.temp.fp16( tempLength );
+
+ {
+ Binder bind;
+ ctx.cb.bind();
+
+ bindShader( eComputeShader::mulMatDotReshape );
+
+ ctx.cb.update( src );
+ bind.bind( src, temp );
+ context()->Dispatch( (UINT)ne11, (UINT)ne12, 1 );
+ }
+
+ std::vector<uint16_t> reshaped;
+ check( downloadBuffer( temp, reshaped ) );
+ computeDiff( reshaped.data(), (const uint16_t*)tempBuffer, reshaped.size() ).print( "testMulMatReshape" );
+
+#if 0
+ dbgWriteBinaryFile( L"fp32.bin", src1->data, ggml_nbytes( src1 ) );
+ dbgWriteBinaryFile( L"fp16-cpu.bin", tempBuffer, reshaped.size() * 2 );
+ dbgWriteBinaryFile( L"fp16-gpu.bin", reshaped.data(), reshaped.size() * 2 );
+ __debugbreak();
+#endif
+}
+#endif
+
+void DirectCompute::computeMulMat( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst )
+{
+ CaptureRaii capture;
+ const size_t nb00 = src0->nb[ 0 ];
+ const size_t nb01 = src0->nb[ 1 ];
+
+ if( src0->type != GGML_TYPE_F16 )
+ throw E_INVALIDARG;
+ if( nb01 < nb00 )
+ throw E_INVALIDARG;
+
+ WhisperContext& ctx = WhisperContext::current();
+
+ Tensor arg0, arg1;
+ check( arg0.create( *src0, eBufferUse::Immutable, true ) );
+ check( arg1.create( *src1, eBufferUse::Immutable, true ) );
+ TensorEx res;
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ ctx.mulMat( arg0, arg1, res );
+
+ check( res.download( dst->data ) );
+}
+
+void DirectCompute::testFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, const ggml_tensor* dst )
+{
+ CaptureRaii capture;
+
+ Tensor Q, K, V;
+ TensorEx res;
+ check( Q.create( *q, eBufferUse::Immutable, true ) );
+ check( K.create( *k, eBufferUse::Immutable, true ) );
+ check( V.create( *v, eBufferUse::Immutable, true ) );
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.flashAttention( Q, K, V, res, masked );
+
+ std::vector<float> tv;
+ check( res.download( tv ) );
+
+ const size_t len = tv.size();
+ computeDiff( tv.data(), (const float*)dst->data, len ).print( "testFlashAttention" );
+}
+
+void DirectCompute::computeFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, ggml_tensor* dst )
+{
+ CaptureRaii capture;
+
+ Tensor Q, K, V;
+ TensorEx res;
+ check( Q.create( *q, eBufferUse::Immutable, true ) );
+ check( K.create( *k, eBufferUse::Immutable, true ) );
+ check( V.create( *v, eBufferUse::Immutable, true ) );
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.flashAttention( Q, K, V, res, masked );
+
+ check( res.download( dst->data ) );
+}
+
+void DirectCompute::testConvolution( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst )
+{
+ CaptureRaii capture;
+
+ Tensor arg0, arg1;
+ check( arg0.create( *src0, eBufferUse::Immutable, true ) );
+ check( arg1.create( *src1, eBufferUse::Immutable, true ) );
+ TensorEx res;
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.convolution( arg0, arg1, res );
+
+ std::vector<float> tv;
+ check( res.download( tv ) );
+
+ const size_t len = tv.size();
+ computeDiff( tv.data(), (const float*)dst->data, len ).print( "testConvolution" );
+}
+
+void DirectCompute::computeConvolution( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst )
+{
+ CaptureRaii capture;
+
+ Tensor arg0, arg1;
+ check( arg0.create( *src0, eBufferUse::Immutable, true ) );
+ check( arg1.create( *src1, eBufferUse::Immutable, true ) );
+ TensorEx res;
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.convolution( arg0, arg1, res );
+
+ res.download( dst->data );
+} \ No newline at end of file
diff --git a/Whisper/ML/tensorOpsTests.h b/Whisper/ML/tensorOpsTests.h
new file mode 100644
index 0000000..7820bba
--- /dev/null
+++ b/Whisper/ML/tensorOpsTests.h
@@ -0,0 +1,15 @@
+#pragma once
+#include "../source/ggml.h"
+
+namespace DirectCompute
+{
+ // void testMulMatReshape( const ggml_tensor* src1, const void* tempBuffer );
+ void testMulMat( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst, const void* tempBuffer );
+ void computeMulMat( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst );
+
+ void testFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, const ggml_tensor* dst );
+ void computeFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, ggml_tensor* dst );
+
+ void testConvolution( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst );
+ void computeConvolution( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst );
+} \ No newline at end of file
diff --git a/Whisper/ML/testUtils.cpp b/Whisper/ML/testUtils.cpp
new file mode 100644
index 0000000..8a20e49
--- /dev/null
+++ b/Whisper/ML/testUtils.cpp
@@ -0,0 +1,334 @@
+#include "stdafx.h"
+#include "testUtils.h"
+#include <immintrin.h>
+#include <atlfile.h>
+#include <atlpath.h>
+
+namespace
+{
+ using DirectCompute::sTensorDiff;
+
+ __forceinline __m256 load( const float* rsi )
+ {
+ return _mm256_loadu_ps( rsi );
+ }
+
+ __forceinline __m256 load( const uint16_t* rsi )
+ {
+ const __m128i iv = _mm_load_si128( ( const __m128i* )rsi );
+ return _mm256_cvtph_ps( iv );
+ }
+
+ __forceinline void loadPartial( const uint16_t* x, const uint16_t* y, size_t count, __m256& fx, __m256& fy )
+ {
+ __m128i ix, iy;
+ switch( count )
+ {
+ case 1: // load 2 bytes
+ ix = _mm_cvtsi32_si128( *x );
+ iy = _mm_cvtsi32_si128( *y );
+ break;
+ case 2: // load 4 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ iy = _mm_cvtsi32_si128( *(const int*)y );
+ break;
+ case 3: // load 6 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ iy = _mm_cvtsi32_si128( *(const int*)y );
+ ix = _mm_insert_epi16( ix, x[ 2 ], 2 );
+ iy = _mm_insert_epi16( iy, y[ 2 ], 2 );
+ break;
+ case 4: // load 8 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ break;
+ case 5: // load 10 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ ix = _mm_insert_epi16( ix, x[ 4 ], 4 );
+ iy = _mm_insert_epi16( iy, y[ 4 ], 4 );
+ break;
+ case 6: // load 12 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 );
+ break;
+ case 7: // load 14 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ iy = _mm_cvtsi64_si128( *(const int64_t*)y );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ iy = _mm_insert_epi32( iy, *(const int*)( y + 4 ), 2 );
+ ix = _mm_insert_epi16( ix, x[ 6 ], 6 );
+ iy = _mm_insert_epi16( iy, y[ 6 ], 6 );
+ break;
+ default:
+ fx = fy = _mm256_setzero_ps();
+ return;
+ }
+
+ fx = _mm256_cvtph_ps( ix );
+ fy = _mm256_cvtph_ps( iy );
+ }
+
+ inline __m128 loadFloat2( const float* rsi )
+ {
+ return _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) );
+ }
+ inline __m128 loadFloat3( const float* rsi )
+ {
+ __m128 f = loadFloat2( rsi );
+ f = _mm_insert_ps( f, _mm_load_ss( rsi + 2 ), 0x20 );
+ return f;
+ }
+ __forceinline void loadPartial( const float* x, const float* y, size_t count, __m256& fx, __m256& fy )
+ {
+ __m128 low1, high1;
+ __m128 low2, high2;
+ high1 = high2 = _mm_setzero_ps();
+ switch( count )
+ {
+ case 1:
+ low1 = _mm_load_ss( x );
+ low2 = _mm_load_ss( y );
+ break;
+ case 2:
+ low1 = loadFloat2( x );
+ low2 = loadFloat2( y );
+ break;
+ case 3:
+ low1 = loadFloat3( x );
+ low2 = loadFloat3( y );
+ break;
+ case 4:
+ low1 = _mm_loadu_ps( x );
+ low2 = _mm_loadu_ps( y );
+ break;
+ case 5:
+ low1 = _mm_loadu_ps( x );
+ low2 = _mm_loadu_ps( y );
+ high1 = _mm_load_ss( x + 4 );
+ high2 = _mm_load_ss( y + 4 );
+ break;
+ case 6:
+ low1 = _mm_loadu_ps( x );
+ low2 = _mm_loadu_ps( y );
+ high1 = loadFloat2( x + 4 );
+ high2 = loadFloat2( y + 4 );
+ break;
+ case 7: // load 14 bytes
+ low1 = _mm_loadu_ps( x );
+ low2 = _mm_loadu_ps( y );
+ high1 = loadFloat3( x + 4 );
+ high2 = loadFloat3( y + 4 );
+ break;
+ default:
+ fx = fy = _mm256_setzero_ps();
+ return;
+ }
+
+ fx = _mm256_setr_m128( low1, high1 );
+ fy = _mm256_setr_m128( low2, high2 );
+ }
+
+ __forceinline float horizontalMaximum( __m256 v )
+ {
+ __m128 s = _mm256_extractf128_ps( v, 1 );
+ s = _mm_max_ps( s, _mm256_castps256_ps128( v ) );
+ s = _mm_max_ps( s, _mm_movehl_ps( s, s ) );
+ s = _mm_max_ss( s, _mm_movehdup_ps( s ) );
+ return _mm_cvtss_f32( s );
+ }
+
+ __forceinline double horizontalSum( __m256 v )
+ {
+ __m256d d = _mm256_cvtps_pd( _mm256_extractf128_ps( v, 1 ) );
+ d = _mm256_add_pd( d, _mm256_cvtps_pd( _mm256_castps256_ps128( v ) ) );
+
+ __m128d s = _mm256_extractf128_pd( d, 1 );
+ s = _mm_add_pd( s, _mm256_castpd256_pd128( d ) );
+ s = _mm_add_sd( s, _mm_unpackhi_pd( s, s ) );
+ return _mm_cvtsd_f64( s );
+ }
+
+ __m256 maskInfNan( __m256 diff, __m256 a, __m256 b )
+ {
+ __m256i ai = _mm256_castps_si256( a );
+ __m256i bi = _mm256_castps_si256( b );
+ __m256i eqi = _mm256_cmpeq_epi32( ai, bi );
+ __m256 eq = _mm256_castsi256_ps( eqi );
+ return _mm256_andnot_ps( eq, diff );
+ }
+
+ class DiffAcc
+ {
+ __m256 maxAbs = _mm256_setzero_ps();
+ __m256 sumSquares = _mm256_setzero_ps();
+
+ public:
+
+ __forceinline void add( __m256 a, __m256 b )
+ {
+ const __m256 neg0 = _mm256_set1_ps( -0.0f );
+ __m256 diff = _mm256_sub_ps( b, a );
+ diff = maskInfNan( diff, a, b );
+ sumSquares = _mm256_fmadd_ps( diff, diff, sumSquares );
+ const __m256 absDiff = _mm256_andnot_ps( neg0, diff );
+ maxAbs = _mm256_max_ps( maxAbs, absDiff );
+ }
+
+ __forceinline sTensorDiff reduce( size_t count )
+ {
+ sTensorDiff res;
+ res.maxAbsDiff = horizontalMaximum( maxAbs );
+ res.avgDiffSquared = (float)( horizontalSum( sumSquares ) / (double)(int64_t)count );
+ res.length = count;
+ return res;
+ }
+ };
+
+ template<class E>
+ static sTensorDiff __declspec( noinline ) diffVectors( const E* a, const E* b, size_t length )
+ {
+ // const E* const aEnd = a + length;
+ const E* const aEndAligned = a + ( length / 8 ) * 8;
+ const size_t remainder = length % 8;
+
+ DiffAcc acc;
+ for( ; a < aEndAligned; a += 8, b += 8 )
+ acc.add( load( a ), load( b ) );
+
+ if( remainder != 0 )
+ {
+ __m256 va, vb;
+ loadPartial( a, b, remainder, va, vb );
+ acc.add( va, vb );
+ }
+
+ return acc.reduce( length );
+ }
+}
+
+sTensorDiff DirectCompute::computeDiff( const float* a, const float* b, size_t length )
+{
+ return diffVectors( a, b, length );
+}
+
+sTensorDiff DirectCompute::computeDiff( const uint16_t* a, const uint16_t* b, size_t length )
+{
+ return diffVectors( a, b, length );
+}
+
+void DirectCompute::sTensorDiff::print( const char* what ) const
+{
+ logDebug( u8"%s: length %zu, maxAbsDiff = %g, avgDiffSquared = %g", what, length, maxAbsDiff, avgDiffSquared );
+}
+void DirectCompute::sTensorDiff::print() const
+{
+ logDebug( u8"%zu elements, maxAbsDiff = %g, avgDiffSquared = %g", length, maxAbsDiff, avgDiffSquared );
+}
+
+HRESULT DirectCompute::dbgWriteBinaryFile( LPCTSTR fileName, const void* rsi, size_t cb )
+{
+ CPath path;
+ path.m_strPath = LR"(C:\Temp\2remove\Whisper)";
+ path.Append( fileName );
+
+ CAtlFile file;
+ CHECK( file.Create( path, GENERIC_WRITE, 0, CREATE_ALWAYS ) );
+ CHECK( file.Write( rsi, (DWORD)cb ) );
+ CHECK( file.Flush() );
+ return S_OK;
+}
+
+#include "Tensor.h"
+
+sTensorDiff DirectCompute::computeDiff( const Tensor& a, const Tensor& b )
+{
+ assert( isSameShapeAndLayout( a, b ) );
+ const eDataType dt = a.getType();
+ assert( dt == b.getType() );
+ switch( dt )
+ {
+ case eDataType::FP32:
+ {
+ std::vector<float> v1, v2;
+ a.download( v1 );
+ b.download( v2 );
+ assert( v1.size() == v2.size() );
+#if 0
+ const size_t firstZero = std::find( v2.begin(), v2.end(), 0.0f ) - v2.begin();
+
+ std::vector<float> delta;
+ delta.resize( v1.size() );
+ for( size_t i = 0; i < v1.size(); i++ )
+ delta[ i ] = std::abs( v1[ i ] - v2[ i ] );
+ const size_t maxIndex = std::max_element( delta.begin(), delta.end() ) - delta.begin();
+#endif
+ return computeDiff( v1.data(), v2.data(), v1.size() );
+ }
+ }
+ throw E_NOTIMPL;
+}
+
+using namespace DirectCompute;
+
+void PrintUniqueTensorSizes::printImpl( const std::array<uint32_t, 8>& a )
+{
+ auto pair = set.emplace( a );
+ if( !pair.second )
+ return; // was already there
+
+ const __m128i rhs = _mm_loadu_si128( ( const __m128i* ) ( &a[ 4 ] ) );
+
+ if( _mm_testz_si128( rhs, rhs ) )
+ {
+ logDebug( u8"%s: [ %i, %i, %i, %i ]", what,
+ a[ 0 ], a[ 1 ], a[ 2 ], a[ 3 ] );
+ }
+ else
+ {
+ logDebug( u8"%s: [ %i, %i, %i, %i ], [ %i, %i, %i, %i ]", what,
+ a[ 0 ], a[ 1 ], a[ 2 ], a[ 3 ], a[ 4 ], a[ 5 ], a[ 6 ], a[ 7 ] );
+ }
+}
+
+void PrintUniqueTensorSizes::print( const Tensor& lhs, const Tensor& rhs )
+{
+ std::array<uint32_t, 8> arr;
+ __m128i* const rdi = ( __m128i* )arr.data();
+ _mm_storeu_si128( rdi, lhs.sizeVec() );
+ _mm_storeu_si128( rdi + 1, rhs.sizeVec() );
+
+ printImpl( arr );
+}
+
+void PrintUniqueTensorSizes::print( const int* lhs, const int* rhs )
+{
+ std::array<uint32_t, 8> arr;
+ __m128i* const rdi = ( __m128i* )arr.data();
+ _mm_storeu_si128( rdi, load16( lhs ) );
+ _mm_storeu_si128( rdi + 1, load16( rhs ) );
+
+ printImpl( arr );
+}
+
+void PrintUniqueTensorSizes::print( const Tensor& lhs )
+{
+ std::array<uint32_t, 8> arr;
+ __m128i* const rdi = ( __m128i* )arr.data();
+ _mm_storeu_si128( rdi, lhs.sizeVec() );
+ _mm_storeu_si128( rdi + 1, _mm_setzero_si128() );
+
+ printImpl( arr );
+}
+
+#include "testUtilsC.h"
+
+void printUniqueTensorSize( const char* name, const int* lhs, const int* rhs )
+{
+ using TS = DirectCompute::PrintUniqueTensorSizes;
+ static std::unordered_map<std::string, TS> map;
+ TS& ts = map.try_emplace( name, name ).first->second;
+ ts.print( lhs, rhs );
+} \ No newline at end of file
diff --git a/Whisper/ML/testUtils.h b/Whisper/ML/testUtils.h
new file mode 100644
index 0000000..1e62e38
--- /dev/null
+++ b/Whisper/ML/testUtils.h
@@ -0,0 +1,62 @@
+#pragma once
+#include "../D3D/downloadBuffer.h"
+#include "../D3D/RenderDoc/renderDoc.h"
+#include <unordered_set>
+#include <functional>
+
+// Funfact: this code written by ChatGPT
+namespace std
+{
+ template<>
+ struct hash<array<uint32_t, 8>>
+ {
+ size_t operator()( const array<uint32_t, 8>& arr ) const
+ {
+ size_t result = 0;
+ for( uint32_t element : arr )
+ result = ( result * 31 ) ^ element;
+ return result;
+ }
+ };
+}
+
+namespace DirectCompute
+{
+ struct sTensorDiff
+ {
+ // maximum( absolute( a - b ) )
+ float maxAbsDiff;
+ // average( ( a - b )^2 )
+ float avgDiffSquared;
+ size_t length;
+
+ void print() const;
+ void print( const char* what ) const;
+ };
+
+ // Compute difference between 2 FP32 vectors
+ sTensorDiff computeDiff( const float* a, const float* b, size_t length );
+
+ // Compute difference between 2 FP16 vectors
+ sTensorDiff computeDiff( const uint16_t* a, const uint16_t* b, size_t length );
+
+ class Tensor;
+ sTensorDiff computeDiff( const Tensor& a, const Tensor& b );
+
+ HRESULT dbgWriteBinaryFile( LPCTSTR fileName, const void* rsi, size_t cb );
+
+ // Print unique sizes of the two tensors
+ class PrintUniqueTensorSizes
+ {
+ std::unordered_set<std::array<uint32_t, 8>> set;
+ const char* const what;
+ void printImpl( const std::array<uint32_t, 8>& a );
+
+ public:
+ PrintUniqueTensorSizes( const char* w ) : what( w ) { }
+
+ void print( const Tensor& lhs, const Tensor& rhs );
+ void print( const Tensor& lhs );
+ void print( const int* lhs, const int* rhs );
+ };
+} \ No newline at end of file
diff --git a/Whisper/ML/testUtilsC.h b/Whisper/ML/testUtilsC.h
new file mode 100644
index 0000000..2246e92
--- /dev/null
+++ b/Whisper/ML/testUtilsC.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+ void printUniqueTensorSize( const char* name, const int* lhs, const int* rhs );
+#ifdef __cplusplus
+}
+#endif \ No newline at end of file
diff --git a/Whisper/Readme.txt b/Whisper/Readme.txt
new file mode 100644
index 0000000..fc0e9b8
--- /dev/null
+++ b/Whisper/Readme.txt
@@ -0,0 +1,9 @@
+This C++ project builds a DLL which actually does the heavy lifting of this project.
+
+It implements the ML model, handles multimedia files with Media Foundation, captures audio (also with MF), does voice activity detection (custom code running on CPU), and a few smaller things.
+
+The code requires C++/20, and only tested with Visual Studio 2022.
+
+When running pure GPGPU model, the DLL requires SSE 4.1 instruction set.
+
+When running a hybrid model, the DLL requires AVX1, FMA3, F16C, and BMI1 instruction set extensions. \ No newline at end of file
diff --git a/Whisper/Resource.rc b/Whisper/Resource.rc
new file mode 100644
index 0000000..cb93b99
--- /dev/null
+++ b/Whisper/Resource.rc
Binary files differ
diff --git a/Whisper/Utils/CpuProfiler.cpp b/Whisper/Utils/CpuProfiler.cpp
new file mode 100644
index 0000000..6161e95
--- /dev/null
+++ b/Whisper/Utils/CpuProfiler.cpp
@@ -0,0 +1,65 @@
+#include "stdafx.h"
+#include "CpuProfiler.h"
+
+namespace
+{
+ using namespace Whisper;
+
+ inline int64_t qpcNow()
+ {
+ int64_t res;
+ QueryPerformanceCounter( (LARGE_INTEGER*)&res );
+ return res;
+ }
+
+ class CpuTimescale
+ {
+ uint64_t frequency = 0;
+ const int64_t tscStart;
+ const int64_t qpcStart;
+
+ uint64_t computeTscFrequency();
+
+ public:
+
+ CpuTimescale() :
+ tscStart( tscNow() ),
+ qpcStart( qpcNow() )
+ { }
+
+ inline uint64_t computeTicks( uint64_t tsc )
+ {
+ uint64_t freq = frequency;
+ if( freq == 0 )
+ freq = computeTscFrequency();
+
+ return makeTime( tsc, freq );
+ }
+ };
+
+ uint64_t __declspec( noinline ) CpuTimescale::computeTscFrequency()
+ {
+ int64_t tsc = tscNow();
+ int64_t qpc = qpcNow();
+ tsc -= tscStart;
+ qpc -= qpcStart;
+
+ uint64_t qpcFreq;
+ QueryPerformanceFrequency( (LARGE_INTEGER*)&qpcFreq );
+
+ // Seconds = qpc / qpcFreq
+ // ticks per second = tsc / seconds = tsc * qpcFreq / qpc
+ uint64_t res = ( (uint64_t)tsc * qpcFreq + ( (uint64_t)qpc / 2 ) - 1 ) / (uint64_t)qpc;
+ frequency = res;
+ const double GHz = (double)(int64_t)res * 1.0E-9;
+ logDebug( u8"Computed CPU base frequency: %g GHz", GHz );
+ return res;
+ }
+
+ static CpuTimescale timescale;
+}
+
+uint64_t Whisper::ticksFromTsc( uint64_t tscDiff )
+{
+ return timescale.computeTicks( tscDiff );
+} \ No newline at end of file
diff --git a/Whisper/Utils/CpuProfiler.h b/Whisper/Utils/CpuProfiler.h
new file mode 100644
index 0000000..17887f8
--- /dev/null
+++ b/Whisper/Utils/CpuProfiler.h
@@ -0,0 +1,26 @@
+#pragma once
+
+namespace Whisper
+{
+ // Get current time in CPU clock
+ // More specifically, each CPU core has a timestamp counter which runs at CPU's base frequency, regardless on the frequency scaling of that core.
+ inline int64_t tscNow()
+ {
+ return __rdtsc();
+ }
+
+ // Scale the time interval from CPU time stamp counter clock into 100-nanosecond ticks, rounding to nearest
+ uint64_t ticksFromTsc( uint64_t tscDiff );
+
+ class CpuProfiler
+ {
+ const int64_t started = tscNow();
+
+ public:
+
+ uint64_t elapsed() const
+ {
+ return ticksFromTsc( (uint64_t)( tscNow() - started ) );
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/Utils/GpuProfiler.cpp b/Whisper/Utils/GpuProfiler.cpp
new file mode 100644
index 0000000..6f19415
--- /dev/null
+++ b/Whisper/Utils/GpuProfiler.cpp
@@ -0,0 +1,374 @@
+#include "stdafx.h"
+#include "GpuProfiler.h"
+#include "GpuProfilerSimple.h"
+using namespace DirectCompute;
+
+inline void GpuProfiler::sProfilerData::reset()
+{
+ _mm_storeu_si128( ( __m128i* ) & callsPending, _mm_setzero_si128() );
+}
+
+inline void GpuProfiler::sProfilerData::addPending( int64_t time )
+{
+ callsPending++;
+ timePending += time;
+}
+
+inline void GpuProfiler::sProfilerData::dropPending()
+{
+ callsPending = 0;
+ timePending = 0;
+}
+
+inline void GpuProfiler::sProfilerData::makeTime( uint64_t freq )
+{
+ dest->count += callsPending;
+ dest->totalTicks += ::makeTime( timePending, freq );
+ callsPending = 0;
+ timePending = 0;
+}
+
+HRESULT GpuProfiler::Queue::create()
+{
+ ID3D11Device* const dev = device();
+
+ CD3D11_QUERY_DESC desc{ D3D11_QUERY_TIMESTAMP };
+ for( Entry& e : queue )
+ {
+ CHECK( dev->CreateQuery( &desc, &e.query ) );
+ e.block = nullptr;
+ e.event = eEvent::None;
+ e.shader = EmptyShader;
+ }
+ return S_OK;
+}
+
+namespace
+{
+ static uint64_t getTimestamp( ID3D11Query* query )
+ {
+ ID3D11DeviceContext* const ctx = context();
+
+ uint64_t res = 0;
+ while( true )
+ {
+ const HRESULT hr = ctx->GetData( query, &res, sizeof( uint64_t ), 0 );
+ check( hr );
+ if( S_OK == hr )
+ return res;
+#if 0
+ Sleep( 1 );
+#else
+ for( size_t i = 0; i < 1024; i++ )
+ _mm_pause();
+#endif
+ }
+ }
+
+ static D3D11_QUERY_DATA_TIMESTAMP_DISJOINT waitForDisjointData( ID3D11Query* query )
+ {
+ ID3D11DeviceContext* const ctx = context();
+ ctx->End( query );
+
+ D3D11_QUERY_DATA_TIMESTAMP_DISJOINT res;
+ while( true )
+ {
+ const HRESULT hr = ctx->GetData( query, &res, sizeof( D3D11_QUERY_DATA_TIMESTAMP_DISJOINT ), 0 );
+ check( hr );
+ if( S_OK == hr )
+ return res;
+ Sleep( 1 );
+ }
+ }
+}
+
+void GpuProfiler::Queue::Entry::join( GpuProfiler& owner )
+{
+ assert( nullptr != block );
+
+ uint64_t res = getTimestamp( query );
+#if PROFILER_COLLECT_TAGS
+ block->haveTimestamp( event, shader, tag, res, owner );
+#else
+ block->haveTimestamp( event, shader, 0, res, owner );
+#endif
+ block = nullptr;
+ event = eEvent::None;
+ shader = EmptyShader;
+}
+
+void GpuProfiler::Queue::submit( BlockState* block, eEvent evt, uint16_t shader, uint16_t tag )
+{
+ // if( evt == GpuProfiler::eEvent::Shader && shader == 0 ) __debugbreak();
+ assert( nullptr != block );
+
+ Entry& e = queue[ nextEntry ];
+ if( nullptr != e.block )
+ e.join( owner );
+
+ e.block = block;
+ e.event = evt;
+ e.shader = shader;
+#if PROFILER_COLLECT_TAGS
+ e.tag = tag;
+#endif
+ context()->End( e.query );
+ nextEntry = ( nextEntry + 1 ) % queueLength;
+}
+
+void GpuProfiler::Queue::join()
+{
+ while( true )
+ {
+ Entry& e = queue[ nextEntry ];
+ if( nullptr == e.block )
+ return;
+ e.join( owner );
+ nextEntry = ( nextEntry + 1 ) % queueLength;
+ }
+}
+
+static inline uint32_t makeTagKey( uint16_t cs, uint16_t tag )
+{
+ uint32_t r = cs;
+ r = r << 16;
+ r |= tag;
+ return r;
+}
+
+void GpuProfiler::BlockState::completePrevShader( uint64_t time, GpuProfiler& profiler )
+{
+ if( shaderStart == -1 )
+ return;
+ assert( prevShader != EmptyShader );
+ const int64_t elapsed = (int64_t)time - shaderStart;
+
+ sProfilerData* dest = nullptr;
+ auto* p = profiler.results.Lookup( prevShader );
+ if( nullptr != p )
+ dest = &p->m_value;
+ else
+ {
+ sProfilerData& res = profiler.results[ prevShader ];
+ res.dest = &profiler.dest.measure( (eComputeShader)prevShader );
+ dest = &res;
+ }
+ dest->addPending( elapsed );
+
+#if PROFILER_COLLECT_TAGS
+ if( 0 != prevShaderTag )
+ {
+ const uint32_t key = makeTagKey( prevShader, prevShaderTag );
+ auto* pt = profiler.resultsTagged.Lookup( key );
+ if( nullptr != pt )
+ dest = &pt->m_value;
+ else
+ {
+ sProfilerData& res = profiler.resultsTagged[ key ];
+ res.dest = &profiler.dest.measure( (eComputeShader)prevShader, prevShaderTag );
+ dest = &res;
+ }
+ dest->addPending( elapsed );
+ }
+#endif
+ prevShader = EmptyShader;
+ prevShaderTag = 0;
+ shaderStart = -1;
+}
+
+void GpuProfiler::BlockState::haveTimestamp( eEvent evt, uint16_t cs, uint16_t tag, uint64_t time, GpuProfiler& profiler )
+{
+ switch( evt )
+ {
+ case eEvent::BlockStart:
+ assert( -1 == timeStart );
+ assert( -1 == shaderStart );
+ assert( cs == EmptyShader );
+ timeStart = (int64_t)time;
+ if( nullptr != parentBlock )
+ parentBlock->completePrevShader( time, profiler );
+ return;
+ case eEvent::BlockEnd:
+ assert( -1 != timeStart );
+ assert( cs == EmptyShader );
+ completePrevShader( time, profiler );
+ destBlock->addPending( (int64_t)time - timeStart );
+ timeStart = -1;
+ return;
+ case eEvent::Shader:
+ assert( cs != EmptyShader );
+ // if( cs == (uint16_t)0 ) __debugbreak();
+ completePrevShader( time, profiler );
+ prevShader = cs;
+ prevShaderTag = tag;
+ shaderStart = (int64_t)time;
+ return;
+ }
+ assert( false );
+}
+
+HRESULT GpuProfiler::create( size_t maxDepth )
+{
+ CD3D11_QUERY_DESC desc{ D3D11_QUERY_TIMESTAMP_DISJOINT };
+ CHECK( device()->CreateQuery( &desc, &disjoint ) );
+ CHECK( queries.create() );
+ stack.reserve( maxDepth );
+ return S_OK;
+}
+
+void GpuProfiler::blockStart( eProfilerBlock which )
+{
+ BlockState* parentBlock;
+ if( stack.empty() )
+ {
+ context()->Begin( disjoint );
+ parentBlock = nullptr;
+ }
+ else
+ parentBlock = *stack.rbegin();
+
+ BlockState* bs = nullptr;
+ auto p = blockStates.Lookup( which );
+ if( nullptr != p )
+ bs = &p->m_value;
+ else
+ {
+ BlockState& block = blockStates[ which ];
+ block.destBlock = &results[ (uint16_t)which ];
+ block.destBlock->dest = &dest.measure( which );
+ bs = &block;
+ }
+ bs->parentBlock = parentBlock;
+ queries.submit( bs, eEvent::BlockStart );
+ stack.push_back( bs );
+}
+
+void GpuProfiler::blockEnd()
+{
+ assert( !stack.empty() );
+ BlockState* const bs = *stack.rbegin();
+ queries.submit( bs, eEvent::BlockEnd );
+ stack.pop_back();
+
+ if( !stack.empty() )
+ return;
+
+ const D3D11_QUERY_DATA_TIMESTAMP_DISJOINT dtsd = waitForDisjointData( disjoint );
+ queries.join();
+
+ if( !dtsd.Disjoint )
+ {
+ // Fortunately, these timers appear to be relatively high resolution.
+ // Specifically, on the iGPU inside Ryzen 7 5700G that frequency is 1E+8 = 100 MHz
+ // On nVidia 1080Ti, that frequency is 1E+9 = 1 GHz
+ const uint64_t freq = dtsd.Frequency;
+ resultsMakeTime( freq );
+ }
+ else
+ {
+ // Something occurred in between the query's ID3D11DeviceContext::Begin and ID3D11DeviceContext::End calls
+ // that caused the timestamp counter to become discontinuous or disjoint, such as unplugging the AC cord on a laptop, overheating, or throttling up/down due to laptop savings events.
+ // The timestamp returned by ID3D11DeviceContext::GetData for a timestamp query is only reliable if Disjoint is FALSE.
+ resultsDropPending();
+ }
+}
+
+void GpuProfiler::computeShader( eComputeShader cs )
+{
+ assert( !stack.empty() );
+ if( !profileShaders )
+ return;
+
+ BlockState* const bs = *stack.rbegin();
+#if PROFILER_COLLECT_TAGS
+ queries.submit( bs, eEvent::Shader, (uint16_t)cs, m_nextTag );
+ m_nextTag = 0;
+#else
+ queries.submit( bs, eEvent::Shader, (uint16_t)cs );
+#endif
+}
+
+void GpuProfiler::resultsDropPending()
+{
+ for( POSITION pos = results.GetStartPosition(); nullptr != pos; )
+ results.GetNextValue( pos ).dropPending();
+#if PROFILER_COLLECT_TAGS
+ for( POSITION pos = resultsTagged.GetStartPosition(); nullptr != pos; )
+ resultsTagged.GetNextValue( pos ).dropPending();
+#endif
+}
+
+void GpuProfiler::resultsMakeTime( uint64_t freq )
+{
+ for( POSITION pos = results.GetStartPosition(); nullptr != pos; )
+ results.GetNextValue( pos ).makeTime( freq );
+#if PROFILER_COLLECT_TAGS
+ for( POSITION pos = resultsTagged.GetStartPosition(); nullptr != pos; )
+ resultsTagged.GetNextValue( pos ).makeTime( freq );
+#endif
+}
+
+void GpuProfiler::resultsReset()
+{
+ for( POSITION pos = results.GetStartPosition(); nullptr != pos; )
+ results.GetNextValue( pos ).reset();
+#if PROFILER_COLLECT_TAGS
+ for( POSITION pos = resultsTagged.GetStartPosition(); nullptr != pos; )
+ resultsTagged.GetNextValue( pos ).reset();
+#endif
+}
+
+#if PROFILER_COLLECT_TAGS
+uint16_t __declspec( noinline ) GpuProfiler::setNextTag( const char* name )
+{
+ uint16_t tag = dest.makeTagId( name );
+ m_nextTag = tag;
+ return tag;
+}
+#endif
+
+HRESULT GpuProfilerSimple::create()
+{
+ ID3D11Device* const dev = device();
+
+ CD3D11_QUERY_DESC desc{ D3D11_QUERY_TIMESTAMP_DISJOINT };
+ CHECK( dev->CreateQuery( &desc, &disjoint ) );
+
+ desc.Query = D3D11_QUERY_TIMESTAMP;
+ CHECK( dev->CreateQuery( &desc, &begin ) );
+ CHECK( dev->CreateQuery( &desc, &end ) );
+
+ context()->Begin( disjoint );
+ context()->End( begin );
+ return S_OK;
+}
+
+HRESULT GpuProfilerSimple::time( uint64_t& rdi ) const
+{
+ context()->End( end );
+
+ try
+ {
+ const D3D11_QUERY_DATA_TIMESTAMP_DISJOINT dtsd = waitForDisjointData( disjoint );
+ const uint64_t t1 = getTimestamp( begin );
+ const uint64_t t2 = getTimestamp( end );
+
+ if( !dtsd.Disjoint )
+ {
+ rdi = makeTime( t2 - t1, dtsd.Frequency );
+ return S_OK;
+ }
+ else
+ {
+ // Something occurred in between the query's ID3D11DeviceContext::Begin and ID3D11DeviceContext::End calls
+ // that caused the timestamp counter to become discontinuous or disjoint, such as unplugging the AC cord on a laptop, overheating, or throttling up/down due to laptop savings events.
+ // The timestamp returned by ID3D11DeviceContext::GetData for a timestamp query is only reliable if Disjoint is FALSE.
+ rdi = -1;
+ return S_FALSE;
+ }
+ }
+ catch( HRESULT hr )
+ {
+ return hr;
+ }
+} \ No newline at end of file
diff --git a/Whisper/Utils/GpuProfiler.h b/Whisper/Utils/GpuProfiler.h
new file mode 100644
index 0000000..fbc284e
--- /dev/null
+++ b/Whisper/Utils/GpuProfiler.h
@@ -0,0 +1,187 @@
+#pragma once
+#include "../D3D/device.h"
+#include "ProfileCollection.h"
+
+namespace DirectCompute
+{
+ enum struct eProfilerBlock : uint16_t
+ {
+ LoadModel = 0x1000,
+ Run = 0x2000,
+ Encode = 0x3000,
+ EncodeLayer = 0x4000,
+ Decode = 0x5000,
+ DecodeStep = 0x6000,
+ DecodeLayer = 0x7000,
+ };
+
+ enum struct eComputeShader : uint16_t;
+
+ class GpuProfiler
+ {
+ CComPtr<ID3D11Query> disjoint;
+
+ enum struct eEvent
+ {
+ None = 0,
+ BlockStart,
+ BlockEnd,
+ Shader
+ };
+
+ struct BlockState;
+ static constexpr uint16_t EmptyShader = ~(uint16_t)0;
+
+ // A circular buffer with in-flight queries which feeds timestamps into the iTimestampSink interface
+ class Queue
+ {
+ static constexpr size_t queueLength = 32;
+
+ // Ring buffer for individual measures
+ struct Entry
+ {
+ CComPtr<ID3D11Query> query;
+ BlockState* block;
+ eEvent event;
+ uint16_t shader;
+#if PROFILER_COLLECT_TAGS
+ uint16_t tag = 0;
+#endif
+ void join( GpuProfiler& owner );
+ };
+
+ GpuProfiler& owner;
+ std::array<Entry, queueLength> queue;
+ size_t nextEntry = 0;
+
+ public:
+ Queue( GpuProfiler& gp ) : owner( gp ) {}
+
+ HRESULT create();
+
+ // Begin a next query. Eventually, this will result in the BlockState.haveTimestamp callback
+ void submit( BlockState* block, eEvent evt, uint16_t shader = EmptyShader, uint16_t tag = 0 );
+
+ // Wait for all the pending queries, and call their callbacks
+ void join();
+ };
+ Queue queries;
+
+ struct sProfilerData;
+ struct BlockState
+ {
+ int64_t timeStart = -1;
+ sProfilerData* destBlock = nullptr;
+ int64_t shaderStart = -1;
+ uint16_t prevShader = EmptyShader;
+ uint16_t prevShaderTag = 0;
+ BlockState* parentBlock = nullptr;
+ void haveTimestamp( eEvent evt, uint16_t cs, uint16_t tag, uint64_t time, GpuProfiler& profiler );
+ private:
+ void completePrevShader( uint64_t time, GpuProfiler& profiler );
+ };
+ CAtlMap<eProfilerBlock, BlockState> blockStates;
+ std::vector<BlockState*> stack;
+
+ struct sProfilerData
+ {
+ // Count of accumulated measures
+ size_t callsPending;
+ // Total time spent running all instances of that measure, expressed in GPU ticks
+ uint64_t timePending;
+
+ Whisper::ProfileCollection::Measure* dest;
+
+ inline void makeTime( uint64_t freq );
+ inline void addPending( int64_t time );
+ inline void reset();
+ inline void dropPending();
+
+ sProfilerData()
+ {
+ reset();
+ }
+ };
+
+ CAtlMap<uint16_t, sProfilerData> results;
+#if PROFILER_COLLECT_TAGS
+ CAtlMap<uint32_t, sProfilerData> resultsTagged;
+#endif
+ void resultsMakeTime( uint64_t freq );
+ void resultsDropPending();
+ void resultsReset();
+
+ void blockStart( eProfilerBlock which );
+ void blockEnd();
+
+ Whisper::ProfileCollection& dest;
+#if PROFILER_COLLECT_TAGS
+ uint16_t m_nextTag = 0;
+#endif
+ public:
+
+ GpuProfiler( Whisper::ProfileCollection& pc ) :
+ dest( pc ), queries( *this ) { }
+
+ HRESULT create( size_t maxDepth = 3 );
+
+ class BlockRaii
+ {
+ GpuProfiler* profiler;
+
+ public:
+ BlockRaii( GpuProfiler& owner, eProfilerBlock which )
+ {
+ owner.blockStart( which );
+ profiler = &owner;
+ }
+ ~BlockRaii()
+ {
+ if( nullptr != profiler )
+ {
+ profiler->blockEnd();
+ profiler = nullptr;
+ }
+ }
+ BlockRaii( BlockRaii&& that ) noexcept :
+ profiler( that.profiler )
+ {
+ that.profiler = nullptr;
+ }
+ BlockRaii( const BlockRaii& ) = delete;
+ void operator=( const BlockRaii& ) = delete;
+ void operator=( BlockRaii&& ) = delete;
+ };
+
+ BlockRaii block( eProfilerBlock which )
+ {
+ return BlockRaii{ *this, which };
+ }
+
+ void computeShader( eComputeShader cs );
+
+ bool profileShaders = false;
+ // bool profileShaders = true;
+
+ decltype( auto ) cpuBlock( Whisper::eCpuBlock block )
+ {
+ return dest.cpuBlock( block );
+ }
+ Whisper::ProfileCollection& profiler() { return dest; }
+
+ // Set tag string for the next compute shader
+ // The string should be readonly: for performance reason the implementation doesn’t copy nor compare any strings, it only keeps the pointer
+#if PROFILER_COLLECT_TAGS
+ uint16_t setNextTag( const char* name );
+#else
+ inline uint16_t setNextTag( const char* name ) { return 0; }
+#endif
+
+ void setNextTag( uint16_t tag )
+ {
+#if PROFILER_COLLECT_TAGS
+ m_nextTag = tag;
+#endif
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/Utils/GpuProfilerSimple.h b/Whisper/Utils/GpuProfilerSimple.h
new file mode 100644
index 0000000..7938b44
--- /dev/null
+++ b/Whisper/Utils/GpuProfilerSimple.h
@@ -0,0 +1,14 @@
+#pragma once
+#include "../D3D/device.h"
+
+namespace DirectCompute
+{
+ // A simple profiler which doesn't collect anything, used to measure time it took to load the model
+ class GpuProfilerSimple
+ {
+ CComPtr<ID3D11Query> disjoint, begin, end;
+ public:
+ HRESULT create();
+ HRESULT time( uint64_t& rdi ) const;
+ };
+} \ No newline at end of file
diff --git a/Whisper/Utils/Logger.cpp b/Whisper/Utils/Logger.cpp
new file mode 100644
index 0000000..1b4b233
--- /dev/null
+++ b/Whisper/Utils/Logger.cpp
@@ -0,0 +1,240 @@
+#include "stdafx.h"
+#include "Logger.h"
+#include "../API/iContext.cl.h"
+#include <cstdarg>
+#include <atlstr.h>
+
+namespace
+{
+ wchar_t* formatMessage( HRESULT hr )
+ {
+ wchar_t* err;
+ if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM,
+ NULL,
+ hr,
+ MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ),
+ (LPTSTR)&err,
+ 0,
+ nullptr ) )
+ return err;
+ return nullptr;
+ }
+
+ class Utf
+ {
+ CStringA utf8;
+ CStringW utf16;
+
+ void appendError( HRESULT hr )
+ {
+ const wchar_t* err = formatMessage( hr );
+ if( nullptr != err )
+ {
+ utf16 += err;
+ LocalFree( (HLOCAL)err );
+ utf16.TrimRight();
+ }
+ else
+ utf16.AppendFormat( L"error code %i (0x%08X)", hr, hr );
+ }
+
+ public:
+ const char* print( const char* pszFormat, std::va_list va )
+ {
+ utf8.FormatV( pszFormat, va );
+ return utf8;
+ }
+ const wchar_t* print( const wchar_t* pszFormat, std::va_list va )
+ {
+ utf16.FormatV( pszFormat, va );
+ return utf16;
+ }
+ const wchar_t* upcast( const char* message, int len )
+ {
+ int count = MultiByteToWideChar( CP_UTF8, 0, message, len, nullptr, 0 );
+ if( count == 0 )
+ return nullptr;
+ wchar_t* b = utf16.GetBufferSetLength( len + 1 );
+ count = MultiByteToWideChar( CP_UTF8, 0, message, len, b, len );
+ utf16.ReleaseBuffer( count );
+ return utf16;
+ }
+ int utf8Length() const
+ {
+ return utf8.GetLength();
+ }
+ const wchar_t* printError( HRESULT hr, const char* pszFormat, std::va_list va )
+ {
+ print( pszFormat, va );
+ upcast( utf8, utf8.GetLength() );
+ utf16 += L": ";
+ appendError( hr );
+ return utf16;
+ }
+ const char* downcast()
+ {
+ int count = WideCharToMultiByte( CP_UTF8, 0, utf16, utf16.GetLength(), nullptr, 0, nullptr, nullptr );
+ char* s = utf8.GetBufferSetLength( count + 1 );
+ count = WideCharToMultiByte( CP_UTF8, 0, utf16, utf16.GetLength(), s, count, nullptr, nullptr );
+ utf8.ReleaseBufferSetLength( count );
+ return utf8;
+ }
+ };
+ thread_local Utf ts_utf;
+ using Whisper::eLoggerFlags;
+
+ class Logger : Whisper::sLoggerSetup
+ {
+ inline bool hasFlag( eLoggerFlags bit ) const
+ {
+ return 0 != ( (uint8_t)flags & (uint8_t)bit );
+ }
+
+ bool useStdError() const
+ {
+ return hasFlag( eLoggerFlags::UseStandardError );
+ }
+
+ static void writeStdError( Whisper::eLogLevel lvl, const char* message, int len )
+ {
+ const wchar_t* w = ts_utf.upcast( message, len );
+ if( nullptr != w )
+ fwprintf( stderr, L"%s\n", w );
+ }
+
+ public:
+ Logger()
+ {
+ memset( this, 0, sizeof( Logger ) );
+ }
+
+ bool willLog( Whisper::eLogLevel lvl ) const
+ {
+ if( (uint8_t)lvl > (uint8_t)level )
+ return false;
+ if( useStdError() )
+ return true;
+ return nullptr != sink;
+ }
+
+ void message( Whisper::eLogLevel lvl, const char8_t* pszFormat, std::va_list va ) const
+ {
+ const char* s = ts_utf.print( (const char*)pszFormat, va );
+ auto pfn = sink;
+ if( nullptr != pfn )
+ pfn( context, lvl, s );
+ if( useStdError() )
+ writeStdError( lvl, s, ts_utf.utf8Length() );
+ }
+ void message( Whisper::eLogLevel lvl, const wchar_t* pszFormat, std::va_list va ) const
+ {
+ Utf& u = ts_utf;
+ const wchar_t* w = u.print( pszFormat, va );
+ auto pfn = sink;
+ if( nullptr != pfn )
+ pfn( context, lvl, u.downcast() );
+ if( useStdError() )
+ fwprintf( stderr, L"%s\n", w );
+ }
+ void message( Whisper::eLogLevel lvl, HRESULT hr, const char* pszFormat, std::va_list va ) const
+ {
+ if( hasFlag( eLoggerFlags::SkipFormatMessage ) )
+ {
+ message( lvl, (const char8_t*)pszFormat, va );
+ return;
+ }
+ Utf& u = ts_utf;
+ const wchar_t* w = ts_utf.printError( hr, (const char*)pszFormat, va );
+ auto pfn = sink;
+ if( nullptr != pfn )
+ pfn( context, lvl, u.downcast() );
+ if( useStdError() )
+ fwprintf( stderr, L"%s\n", w );
+ }
+
+ void operator=( const sLoggerSetup& rsi )
+ {
+ sink = rsi.sink;
+ context = rsi.context;
+ level = rsi.level;
+ flags = rsi.flags;
+ }
+ };
+
+ static Logger s_logger;
+}
+
+bool willLogMessage( Whisper::eLogLevel lvl )
+{
+ return s_logger.willLog( lvl );
+}
+
+using Whisper::eLogLevel;
+
+#define LOG_MESSAGE_IMPL( lvl ) \
+ if( !s_logger.willLog( lvl ) ) \
+ return; \
+ std::va_list args; \
+ va_start( args, pszFormat ); \
+ s_logger.message( lvl, pszFormat, args ); \
+ va_end( args );
+
+void logError( const char8_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( eLogLevel::Error );
+}
+void logError16( const wchar_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( eLogLevel::Error );
+}
+void logWarning( const char8_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( eLogLevel::Warning );
+}
+void logWarning16( const wchar_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( eLogLevel::Warning );
+}
+void logInfo( const char8_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( eLogLevel::Info );
+}
+void logInfo16( const wchar_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( eLogLevel::Info );
+}
+void logDebug( const char8_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( eLogLevel::Debug );
+}
+void logDebug16( const wchar_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( eLogLevel::Debug );
+}
+#undef LOG_MESSAGE_IMPL
+
+#define LOG_MESSAGE_IMPL( lvl ) \
+ if( !s_logger.willLog( lvl ) ) \
+ return; \
+ std::va_list args; \
+ va_start( args, pszFormat ); \
+ s_logger.message( lvl, hr, (const char*)pszFormat, args ); \
+ va_end( args );
+
+void logErrorHr( long hr, const char8_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( eLogLevel::Error );
+}
+void logWarningHr( long hr, const char8_t* pszFormat, ... )
+{
+ LOG_MESSAGE_IMPL( eLogLevel::Warning );
+}
+
+#undef LOG_MESSAGE_IMPL
+
+// DLL entry point
+HRESULT COMLIGHTCALL Whisper::setupLogger( const sLoggerSetup& setup )
+{
+ s_logger = setup;
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/Utils/Logger.h b/Whisper/Utils/Logger.h
new file mode 100644
index 0000000..5e6ec0d
--- /dev/null
+++ b/Whisper/Utils/Logger.h
@@ -0,0 +1,23 @@
+#pragma once
+#include "../API/loggerApi.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void logError( const char8_t* pszFormat, ... );
+void logError16( const wchar_t* pszFormat, ... );
+void logErrorHr( long hr, const char8_t* pszFormat, ... );
+void logWarning( const char8_t* pszFormat, ... );
+void logWarning16( const wchar_t* pszFormat, ... );
+void logWarningHr( long hr, const char8_t* pszFormat, ... );
+void logInfo( const char8_t* pszFormat, ... );
+void logInfo16( const wchar_t* pszFormat, ... );
+void logDebug( const char8_t* pszFormat, ... );
+void logDebug16( const wchar_t* pszFormat, ... );
+
+bool willLogMessage( Whisper::eLogLevel lvl );
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/Whisper/Utils/ProfileCollection.cpp b/Whisper/Utils/ProfileCollection.cpp
new file mode 100644
index 0000000..2fc5919
--- /dev/null
+++ b/Whisper/Utils/ProfileCollection.cpp
@@ -0,0 +1,331 @@
+#include "stdafx.h"
+#include "ProfileCollection.h"
+#include "GpuProfiler.h"
+#include "../Whisper/WhisperModel.h"
+#include "../D3D/shaderNames.h"
+using namespace Whisper;
+
+ProfileCollection::Measure& ProfileCollection::measure( DirectCompute::eProfilerBlock which )
+{
+ uint32_t key = (uint16_t)which;
+ key |= 0x20000;
+ return measures[ key ];
+}
+
+ProfileCollection::Measure& ProfileCollection::measure( DirectCompute::eComputeShader which )
+{
+ uint32_t key = (uint16_t)which;
+ key |= 0x30000;
+ return measures[ key ];
+}
+
+ProfileCollection::Measure& ProfileCollection::measure( eCpuBlock which )
+{
+ uint32_t key = (uint8_t)which;
+ key |= 0x10000;
+ CComCritSecLock<CComAutoCriticalSection> lock{ critSec };
+ return measures[ key ];
+}
+
+#if PROFILER_COLLECT_TAGS
+ProfileCollection::Measure& ProfileCollection::measure( DirectCompute::eComputeShader which, uint16_t tag )
+{
+ uint32_t key = (uint8_t)which;
+ key = key << 16;
+ key |= tag;
+ CComCritSecLock<CComAutoCriticalSection> lock{ critSec };
+ return taggedShaders[ key ];
+}
+#endif
+
+namespace
+{
+ using pfnPrintEnum = const char* ( * )( uint16_t val );
+
+ static const char* printCpuBlock( uint16_t id )
+ {
+ const eCpuBlock which = (eCpuBlock)id;
+ switch( which )
+ {
+#define V(x) case eCpuBlock::x: return #x
+ V( LoadModel );
+ V( Run );
+ V( Spectrogram );
+ V( Sample );
+ V( VAD );
+ V( Decode );
+ V( DecodeStep );
+ V( DecodeLayer );
+#undef V
+ }
+ assert( false );
+ return nullptr;
+ }
+
+ static const char* printGpuBlock( uint16_t id )
+ {
+ using DirectCompute::eProfilerBlock;
+ const eProfilerBlock which = (eProfilerBlock)id;
+
+ switch( which )
+ {
+#define V(x) case eProfilerBlock::x: return #x
+ V( LoadModel );
+ V( Run );
+ V( Encode );
+ V( EncodeLayer );
+ V( Decode );
+ V( DecodeStep );
+ V( DecodeLayer );
+#undef V
+ }
+ assert( false );
+ return nullptr;
+ }
+
+ static const char* printShader( uint16_t id )
+ {
+ return DirectCompute::computeShaderName( (DirectCompute::eComputeShader)id );
+ }
+
+ static pfnPrintEnum printSectionStart( uint16_t type )
+ {
+ switch( type )
+ {
+ case 1:
+ logInfo( u8" CPU Tasks" );
+ return &printCpuBlock;
+ case 2:
+ logInfo( u8" GPU Tasks" );
+ return &printGpuBlock;
+ case 3:
+ logInfo( u8" Compute Shaders" );
+ return &printShader;
+ default:
+ return nullptr;
+ }
+ }
+
+ struct PrintedTime
+ {
+ double value;
+ const char* unit;
+
+ PrintedTime( uint64_t ticks )
+ {
+ const double dbl = (double)(int64_t)ticks;
+ if( ticks >= 10'000'000 )
+ {
+ value = dbl / 1.0E+7;
+ unit = "seconds";
+ }
+ else if( ticks >= 10'000 )
+ {
+ value = dbl / 1.0E+4;
+ unit = "milliseconds";
+ }
+ else
+ {
+ value = dbl / 1.0E+1;
+ unit = "microseconds";
+ }
+ }
+ PrintedTime( double dbl )
+ {
+ if( dbl >= 10'000'000 )
+ {
+ value = dbl / 1.0E+7;
+ unit = "seconds";
+ }
+ else if( dbl >= 10'000 )
+ {
+ value = dbl / 1.0E+4;
+ unit = "milliseconds";
+ }
+ else
+ {
+ value = dbl / 1.0E+1;
+ unit = "microseconds";
+ }
+ }
+ };
+}
+
+void ProfileCollection::Measure::print( const char* name ) const
+{
+ PrintedTime total{ totalTicks };
+ if( 1 == count )
+ logInfo( u8"%s\t%g %s", name, total.value, total.unit );
+ else
+ {
+ PrintedTime avg = (double)totalTicks / (double)(int64_t)count;
+ logInfo( u8"%s\t%g %s, %zu calls, %g %s average", name, total.value, total.unit, count, avg.value, avg.unit );
+ }
+}
+
+#if PROFILER_COLLECT_TAGS
+struct TaggedShaderCmp
+{
+ bool operator()( uint16_t cs, uint32_t key ) const
+ {
+ return cs < key >> 16;
+ }
+ bool operator()( uint32_t key, uint16_t cs ) const
+ {
+ return key >> 16 < cs;
+ }
+};
+
+void ProfileCollection::TaggedTemp::print() const
+{
+ PrintedTime total{ ticks };
+ if( 1 == count )
+ logInfo( u8" %s\t%g %s", name, total.value, total.unit );
+ else
+ {
+ PrintedTime avg = (double)ticks / (double)(int64_t)count;
+ logInfo( u8" %s\t%g %s, %zu calls, %g %s average", name, total.value, total.unit, count, avg.value, avg.unit );
+ }
+}
+#endif
+
+void ProfileCollection::print()
+{
+ keysTemp.clear();
+ for( POSITION pos = measures.GetStartPosition(); nullptr != pos; )
+ {
+ auto* p = measures.GetNext( pos );
+ if( p->m_value.count == 0 )
+ continue;
+ keysTemp.push_back( p->m_key );
+ }
+
+ std::sort( keysTemp.begin(), keysTemp.end() );
+ auto it = std::lower_bound( keysTemp.begin(), keysTemp.end(), 0x30000u );
+ if( it != keysTemp.end() )
+ {
+ auto lambda = [ this ]( uint32_t a, uint32_t b )
+ {
+ const uint64_t ta = measures.Lookup( a )->m_value.totalTicks;
+ const uint64_t tb = measures.Lookup( b )->m_value.totalTicks;
+ return ta > tb;
+ };
+ std::stable_sort( it, keysTemp.end(), lambda );
+ }
+
+#if PROFILER_COLLECT_TAGS
+ taggedKeysTemp.clear();
+ for( POSITION pos = taggedShaders.GetStartPosition(); nullptr != pos; )
+ {
+ auto* p = taggedShaders.GetNext( pos );
+ if( p->m_value.count == 0 )
+ continue;
+ taggedKeysTemp.push_back( p->m_key );
+ }
+ std::sort( taggedKeysTemp.begin(), taggedKeysTemp.end() );
+#endif
+
+ uint16_t prevKeyType = 0;
+ pfnPrintEnum pfn = nullptr;
+ for( uint32_t k : keysTemp )
+ {
+ const uint16_t type = (uint16_t)( k >> 16 );
+ if( type != prevKeyType )
+ {
+ prevKeyType = type;
+ pfn = printSectionStart( type );
+ }
+ if( pfn == nullptr )
+ continue;
+ const auto* p = measures.Lookup( k );
+ assert( nullptr != p );
+ p->m_value.print( pfn( (uint16_t)k ) );
+
+#if PROFILER_COLLECT_TAGS
+ if( type == 3 )
+ {
+ // Compute shader
+ auto range = std::equal_range( taggedKeysTemp.begin(), taggedKeysTemp.end(), (uint16_t)k, TaggedShaderCmp{} );
+ if( range.first != range.second )
+ {
+ // We have at least 1 tag for that compute shader
+ taggedTimes.clear();
+ uint64_t totalTicks = 0;
+ size_t totalCount = 0;
+ for( auto it = range.first; it != range.second; it++ )
+ {
+ const uint32_t key = *it;
+ const uint16_t tagId = (uint16_t)key;
+ assert( 0 != tagId );
+ const auto* p = taggedShaders.Lookup( key );
+ assert( nullptr != p );
+
+ auto& rdi = taggedTimes.emplace_back();
+ rdi.ticks = p->m_value.totalTicks;
+ totalTicks += p->m_value.totalTicks;
+
+ rdi.count = p->m_value.count;
+ totalCount += p->m_value.count;
+
+ rdi.name = tagNames[ tagId ];
+ }
+
+ assert( totalCount <= p->m_value.count );
+ if( totalCount < p->m_value.count )
+ {
+ auto& rdi = taggedTimes.emplace_back();
+ rdi.ticks = p->m_value.totalTicks - totalTicks;
+ rdi.count = p->m_value.count - totalCount;
+ rdi.name = tagNames[ 0 ];
+ }
+ std::stable_sort( taggedTimes.begin(), taggedTimes.end() );
+ for( const auto& e : taggedTimes )
+ e.print();
+ }
+ }
+#endif
+ }
+}
+
+void ProfileCollection::reset()
+{
+ for( POSITION pos = measures.GetStartPosition(); nullptr != pos; )
+ measures.GetNextValue( pos ).reset();
+}
+
+ProfileCollection::ProfileCollection( const WhisperModel& model )
+{
+ const __m128i vals = model.getLoadTimes();
+
+ uint64_t s = (uint64_t)_mm_cvtsi128_si64( vals );
+ measure( eCpuBlock::LoadModel ).add( s );
+
+ s = (uint64_t)_mm_extract_epi64( vals, 1 );
+ measure( DirectCompute::eProfilerBlock::LoadModel ).add( s );
+#if PROFILER_COLLECT_TAGS
+ // Tag ID 0 means no tag at all. makeTagId() method returns 0 for nullptr name, and starts numbering with 1 for non-empoty tag names
+ // Push the tag name corresponding to ID = 0, this way we can index directly with tag IDs.
+ tagNames.push_back( "<untagged>" );
+#endif
+}
+
+uint16_t ProfileCollection::makeTagId( const char* tag )
+{
+#if PROFILER_COLLECT_TAGS
+ if( nullptr == tag )
+ return 0;
+ auto p = tagIDs.Lookup( tag );
+ if( nullptr != p )
+ return p->m_value;
+ const size_t newTag = tagIDs.GetCount() + 1;
+ if( newTag <= 0xFFFF )
+ {
+ tagIDs.SetAt( tag, (uint16_t)newTag );
+ tagNames.push_back( tag );
+ return (uint16_t)newTag;
+ }
+ throw DISP_E_OVERFLOW;
+#else
+ return 0;
+#endif
+} \ No newline at end of file
diff --git a/Whisper/Utils/ProfileCollection.h b/Whisper/Utils/ProfileCollection.h
new file mode 100644
index 0000000..732bb48
--- /dev/null
+++ b/Whisper/Utils/ProfileCollection.h
@@ -0,0 +1,112 @@
+#pragma once
+#include <atlcoll.h>
+#include "CpuProfiler.h"
+
+namespace DirectCompute
+{
+ enum struct eComputeShader : uint16_t;
+ enum struct eProfilerBlock : uint16_t;
+}
+
+namespace Whisper
+{
+ struct WhisperModel;
+
+ enum struct eCpuBlock : uint8_t
+ {
+ LoadModel,
+ Run,
+ Spectrogram,
+ Sample,
+ VAD,
+ Decode,
+ DecodeStep,
+ DecodeLayer,
+ };
+
+ class ProfileCollection
+ {
+ public:
+ ProfileCollection( const WhisperModel& model );
+
+ struct Measure
+ {
+ size_t count = 0;
+ // 100-nanosecond ticks
+ uint64_t totalTicks = 0;
+
+ void reset()
+ {
+ count = 0;
+ totalTicks = 0;
+ }
+
+ void print( const char* name ) const;
+
+ void add( uint64_t val )
+ {
+ count++;
+ totalTicks += val;
+ }
+ };
+
+ Measure& measure( DirectCompute::eProfilerBlock which );
+ Measure& measure( DirectCompute::eComputeShader which );
+ Measure& measure( eCpuBlock which );
+#if PROFILER_COLLECT_TAGS
+ Measure& measure( DirectCompute::eComputeShader which, uint16_t tag );
+#endif
+ void print();
+
+ void reset();
+
+ class CpuRaii
+ {
+ Measure& dest;
+ const int64_t tsc;
+
+ public:
+ CpuRaii( Measure& m ) : dest( m ), tsc( tscNow() )
+ { }
+
+ ~CpuRaii()
+ {
+ const int64_t elapsed = tscNow() - tsc;
+ dest.add( ticksFromTsc( elapsed ) );
+ }
+ };
+
+ decltype( auto ) cpuBlock( eCpuBlock which )
+ {
+ return CpuRaii{ measure( which ) };
+ }
+
+ uint16_t makeTagId( const char* tag );
+
+ private:
+ CAtlMap<uint32_t, Measure> measures;
+ CComAutoCriticalSection critSec;
+#if PROFILER_COLLECT_TAGS
+ CAtlMap<const char*, uint16_t> tagIDs;
+ std::vector<const char*> tagNames;
+ CAtlMap<uint32_t, Measure> taggedShaders;
+ std::vector<uint32_t> taggedKeysTemp;
+ struct TaggedTemp
+ {
+ uint64_t ticks;
+ size_t count;
+ const char* name;
+
+ bool operator<( const TaggedTemp& that ) const
+ {
+ // Flipping the comparison to sort in descending order
+ return ticks > that.ticks;
+ }
+
+ void print() const;
+ };
+ std::vector<TaggedTemp> taggedTimes;
+#endif
+ std::vector<uint32_t> keysTemp;
+ };
+} \ No newline at end of file
diff --git a/Whisper/Utils/ReadStream.h b/Whisper/Utils/ReadStream.h
new file mode 100644
index 0000000..c9da400
--- /dev/null
+++ b/Whisper/Utils/ReadStream.h
@@ -0,0 +1,37 @@
+#pragma once
+#include "../ComLightLib/streams.h"
+#include "../ComLightLib/comLightServer.h"
+#define WIN32_LEAN_AND_MEAN
+#include <atlfile.h>
+
+class ReadStream : public ComLight::ObjectRoot<ComLight::iReadStream>
+{
+ CAtlFile file;
+ // TODO: implement a buffer in this class, at least 256kb
+
+ HRESULT COMLIGHTCALL read( void* lpBuffer, int nNumberOfBytesToRead, int& lpNumberOfBytesRead ) override final
+ {
+ return file.Read( lpBuffer, (DWORD)nNumberOfBytesToRead, *(DWORD*)&lpNumberOfBytesRead );
+ }
+ HRESULT COMLIGHTCALL seek( int64_t offset, ComLight::eSeekOrigin origin ) override final
+ {
+ return file.Seek( offset, (uint8_t)origin );
+ }
+ HRESULT COMLIGHTCALL getPosition( int64_t& position ) override final
+ {
+ return file.GetPosition( *(ULONGLONG*)&position );
+ }
+ HRESULT COMLIGHTCALL getLength( int64_t& length ) override final
+ {
+ return file.GetSize( *(ULONGLONG*)&length );
+ }
+
+public:
+
+ HRESULT open( const wchar_t* path )
+ {
+ if( file )
+ return HRESULT_CODE( ERROR_ALREADY_INITIALIZED );
+ return file.Create( path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN );
+ }
+}; \ No newline at end of file
diff --git a/Whisper/Utils/Trace/TraceStructures.cpp b/Whisper/Utils/Trace/TraceStructures.cpp
new file mode 100644
index 0000000..289a534
--- /dev/null
+++ b/Whisper/Utils/Trace/TraceStructures.cpp
@@ -0,0 +1,31 @@
+#include "stdafx.h"
+#include "TraceStructures.h"
+using namespace Tracing;
+
+uint64_t sTraceItem::buffer( uint64_t off, size_t length, eDataType type )
+{
+ payloadOffset = off;
+ payloadSize = length * DirectCompute::elementSize( type );
+ *(uint64_t*)( &size[ 0 ] ) = length;
+ *(uint64_t*)( &size[ 2 ] ) = 0;
+ _mm_storeu_si128( ( __m128i* )stride.data(), _mm_setzero_si128() );
+ itemType = eItemType::Buffer;
+ dataType = type;
+ return payloadSize;
+}
+
+uint64_t sTraceItem::tensor( uint64_t off, __m128i ne, __m128i nb, eDataType type )
+{
+ payloadOffset = off;
+ _mm_storeu_si128( ( __m128i* )size.data(), ne );
+ _mm_storeu_si128( ( __m128i* )stride.data(), nb );
+ uint64_t count = 1;
+ for( uint32_t i : size )
+ if( i != 0 )
+ count *= i;
+
+ payloadSize = count * DirectCompute::elementSize( type );
+ itemType = eItemType::Tensor;
+ dataType = type;
+ return payloadSize;
+} \ No newline at end of file
diff --git a/Whisper/Utils/Trace/TraceStructures.h b/Whisper/Utils/Trace/TraceStructures.h
new file mode 100644
index 0000000..d1213bf
--- /dev/null
+++ b/Whisper/Utils/Trace/TraceStructures.h
@@ -0,0 +1,55 @@
+#pragma once
+#include <array>
+#include <emmintrin.h>
+#include "../../D3D/enums.h"
+
+namespace Tracing
+{
+ using DirectCompute::eDataType;
+
+ // File header of the trace file
+ struct sFileHeader
+ {
+ static constexpr uint32_t correctMagic = 0xE6B4A12Du; // random.org
+
+ uint32_t magic;
+ uint8_t formatVersion;
+ uint8_t zzPadding;
+ uint16_t cbItem;
+ uint32_t countItems;
+ uint32_t zzPadding2;
+ uint64_t bytesPayload;
+ uint32_t countStrings, bytesStrings;
+ };
+ // Payload data starts immediately after the header, bytesPayload bytes in total.
+ // Then `bytesStrings` with string names, first countStrings * 4 of them are offsets, then ( bytesStrings - countStrings * 4 ) bytes with the string data.
+ // The strings in the file are null-terminated.
+ // Immediately after the strings, the next `cbItem` * `countItems` bytes are actual items (tensors and vectors) saved in the trace.
+ // The format is weird because optimized for streaming.
+ // These traces can grow large, we can’t afford memory keeping the payload data in memory.
+ // Metadata is tiny compared to payload, we accumulate that in memory, and write to the end of the file when closed.
+
+ enum struct eItemType : uint8_t
+ {
+ Buffer = 1,
+ Tensor = 2,
+ };
+
+ struct sTraceItem
+ {
+ uint64_t payloadOffset;
+ uint64_t payloadSize;
+ std::array<uint32_t, 4> size;
+ std::array<uint32_t, 4> stride;
+ std::array<uint32_t, 4> formatArgs;
+ eItemType itemType;
+ eDataType dataType;
+ uint8_t countFormatArgs = 0;
+ uint8_t zzPadding = 0;
+ uint32_t stringIndex;
+
+ uint64_t buffer( uint64_t off, size_t length, eDataType type );
+
+ uint64_t tensor( uint64_t off, __m128i ne, __m128i nb, eDataType type );
+ };
+} \ No newline at end of file
diff --git a/Whisper/Utils/Trace/TraceWriter.cpp b/Whisper/Utils/Trace/TraceWriter.cpp
new file mode 100644
index 0000000..be9946b
--- /dev/null
+++ b/Whisper/Utils/Trace/TraceWriter.cpp
@@ -0,0 +1,263 @@
+#include "stdafx.h"
+#include "TraceWriter.h"
+#include <atlfile.h>
+#include <atlcoll.h>
+#include <atlstr.h>
+#include "TraceStructures.h"
+#include "../../ML/Tensor.h"
+#include "../../CPU/Tensor.h"
+#include <Shlobj.h>
+using namespace Tracing;
+
+namespace
+{
+ static HRESULT createDir( LPCTSTR pathFile )
+ {
+ LPCWSTR fn = PathFindFileName( pathFile );
+ if( fn == pathFile )
+ return E_FAIL;
+
+ const int cc = (int)( fn - pathFile );
+ CString dir{ pathFile, cc };
+ if( PathIsDirectory( dir ) )
+ return S_OK;
+ const int status = SHCreateDirectoryEx( nullptr, dir, nullptr );
+ if( 0 == status )
+ return S_OK;
+ return HRESULT_FROM_WIN32( status );
+ }
+
+ class TraceFileWriter
+ {
+ CAtlFile file;
+ // Concatenated strings, including the 0 terminators
+ std::vector<char> stringsData;
+ // Index = string ID, value = start offset into stringsData
+ std::vector<uint32_t> stringsIndex;
+ // Hash map to unduplicate these strings
+ CAtlMap<CStringA, uint32_t> stringsHash;
+
+ uint32_t addString( const CStringA& s )
+ {
+ auto p = stringsHash.Lookup( s );
+ if( p != nullptr )
+ return p->m_value;
+
+ const uint32_t off = (uint32_t)stringsData.size();
+ const char* rsi = s;
+ stringsData.insert( stringsData.end(), rsi, rsi + s.GetLength() + 1 );
+ stringsIndex.push_back( off );
+
+ const uint32_t newId = (uint32_t)stringsHash.GetCount();
+ stringsHash.SetAt( s, newId );
+ return newId;
+ }
+
+ void addString( sTraceItem& rdi, const ItemName& name )
+ {
+ rdi.countFormatArgs = name.countArgs;
+ rdi.stringIndex = addString( name.pointer );
+ rdi.formatArgs = name.args;
+ }
+
+ std::vector<sTraceItem> items;
+ uint64_t offset = 0;
+
+ public:
+
+ HRESULT create( LPCTSTR path )
+ {
+ CHECK( createDir( path ) );
+ CHECK( file.Create( path, GENERIC_WRITE, 0, CREATE_ALWAYS ) );
+
+ constexpr uint64_t cbHeader = sizeof( sFileHeader );
+ CHECK( file.SetSize( cbHeader ) );
+ CHECK( file.Seek( 0, SEEK_END ) );
+ offset = 0;
+
+ return S_OK;
+ }
+
+ HRESULT buffer( const ItemName& name, const void* rsi, size_t length, eDataType dt )
+ {
+ sTraceItem& rdi = items.emplace_back();
+ const uint64_t cb = rdi.buffer( offset, length, dt );
+ addString( rdi, name );
+ assert( cb <= UINT_MAX );
+ CHECK( file.Write( rsi, (DWORD)cb ) );
+ offset += cb;
+ return S_OK;
+ }
+
+ HRESULT tensor( const ItemName& name, const void* rsi, __m128i size, __m128i strides, eDataType dt )
+ {
+ sTraceItem& rdi = items.emplace_back();
+ const uint64_t cb = rdi.tensor( offset, size, strides, dt );
+ addString( rdi, name );
+ assert( cb <= UINT_MAX );
+ CHECK( file.Write( rsi, (DWORD)cb ) );
+ offset += cb;
+ return S_OK;
+ }
+
+ HRESULT close()
+ {
+ if( !file )
+ return S_FALSE;
+
+ const uint32_t cbStringsData = (uint32_t)stringsData.size();
+ const uint32_t cbStringsIndex = (uint32_t)( stringsIndex.size() * 4 );
+ if( !stringsIndex.empty() )
+ CHECK( file.Write( stringsIndex.data(), cbStringsIndex ) );
+ if( !stringsData.empty() )
+ CHECK( file.Write( stringsData.data(), cbStringsData ) );
+
+ const uint32_t cbItems = (uint32_t)items.size() * (uint32_t)sizeof( sTraceItem );
+ if( !items.empty() )
+ CHECK( file.Write( items.data(), cbItems ) );
+ CHECK( file.Seek( 0, FILE_BEGIN ) );
+
+ sFileHeader header;
+ memset( &header, 0, sizeof( header ) );
+ header.magic = header.correctMagic;
+ header.cbItem = sizeof( sTraceItem );
+ header.countItems = (uint32_t)items.size();
+ header.bytesPayload = offset;
+ header.countStrings = (uint32_t)stringsIndex.size();
+ header.bytesStrings = cbStringsData + cbStringsIndex;
+ CHECK( file.Write( &header, sizeof( header ) ) );
+ CHECK( file.Flush() );
+ file.Close();
+
+ return S_OK;
+ }
+ };
+
+ class TraceWriter : public iTraceWriter
+ {
+ TraceFileWriter file;
+
+ HRESULT buffer( const ItemName& name, const void* rsi, size_t length, eDataType dt ) override final
+ {
+ return file.buffer( name, rsi, length, dt );
+ }
+
+ HRESULT tensor( const ItemName& name, const void* rsi, __m128i size, __m128i strides, eDataType dt ) override final
+ {
+ return file.tensor( name, rsi, size, strides, dt );
+ }
+
+ public:
+
+ TraceWriter( LPCTSTR path )
+ {
+ check( file.create( path ) );
+ }
+
+ ~TraceWriter()
+ {
+ check( file.close() );
+ }
+ };
+}
+
+std::unique_ptr<iTraceWriter> iTraceWriter::create( LPCTSTR path )
+{
+ return std::make_unique<TraceWriter>( path );
+}
+
+namespace
+{
+ static std::vector<float> tempFp32;
+ static std::vector<uint16_t> tempFp16;
+
+ template<class E>
+ inline const void* ptr( const std::vector<E>& vec )
+ {
+ return vec.empty() ? nullptr : vec.data();
+ }
+}
+
+HRESULT iTraceWriter::tensor( const ItemName& name, const DirectCompute::Tensor& source )
+{
+ const __m128i size = source.sizeVec();
+ const __m128i strides = source.stridesVec();
+ const eDataType dt = source.getType();
+ if( dt == eDataType::FP32 )
+ {
+ source.download( tempFp32 );
+ return tensor( name, ptr( tempFp32 ), size, strides, eDataType::FP32 );
+ }
+ else if( dt == eDataType::FP16 )
+ {
+ source.download( tempFp16 );
+ return tensor( name, ptr( tempFp16 ), size, strides, eDataType::FP16 );
+ }
+ return E_NOTIMPL;
+}
+
+HRESULT iTraceWriter::tensor( const ItemName& name, const CpuCompute::Tensor& source )
+{
+ const __m128i size = source.sizeVec();
+ const __m128i strides = source.stridesVec();
+ const eDataType dt = source.type();
+
+ if( dt == eDataType::FP32 )
+ return tensor( name, source.fp32(), size, strides, eDataType::FP32 );
+ else if( dt == eDataType::FP16 )
+ return tensor( name, source.fp16(), size, strides, eDataType::FP16 );
+ else
+ return E_NOTIMPL;
+}
+
+#if BUILD_BOTH_VERSIONS
+#include "../../source/ggml.h"
+HRESULT __declspec( noinline ) iTraceWriter::tensor( const ItemName& name, const ggml_tensor& source )
+{
+ __m128i size = load16( source.ne );
+ __m128i strides = _mm_setr_epi32(
+ (int)(uint32_t)source.nb[ 0 ],
+ (int)(uint32_t)source.nb[ 1 ],
+ (int)(uint32_t)source.nb[ 2 ],
+ (int)(uint32_t)source.nb[ 3 ] );
+
+ const __m128i ones = _mm_set1_epi32( 1 );
+ switch( source.n_dims )
+ {
+ case 0:
+ size = ones;
+ break;
+ case 1:
+ size = _mm_blend_epi16( size, ones, 0b11111100 );
+ break;
+ case 2:
+ size = _mm_blend_epi16( size, ones, 0b11110000 );
+ break;
+ case 3:
+ size = _mm_blend_epi16( size, ones, 0b11000000 );
+ break;
+ case 4:
+ break;
+ default:
+ return E_INVALIDARG;
+ }
+
+ const ggml_type dt = source.type;
+ switch( dt )
+ {
+ case GGML_TYPE_F16:
+ strides = _mm_srli_epi32( strides, 1 );
+ return tensor( name, source.data, size, strides, eDataType::FP16 );
+ case GGML_TYPE_F32:
+ strides = _mm_srli_epi32( strides, 2 );
+ return tensor( name, source.data, size, strides, eDataType::FP32 );
+ default:
+ return E_NOTIMPL;
+}
+}
+#else
+HRESULT iTraceWriter::tensor( const ItemName& name, const ggml_tensor& source )
+{
+ return E_NOTIMPL;
+}
+#endif \ No newline at end of file
diff --git a/Whisper/Utils/Trace/TraceWriter.h b/Whisper/Utils/Trace/TraceWriter.h
new file mode 100644
index 0000000..0514c5a
--- /dev/null
+++ b/Whisper/Utils/Trace/TraceWriter.h
@@ -0,0 +1,70 @@
+#pragma once
+#include <memory>
+#include "../../D3D/enums.h"
+
+namespace DirectCompute
+{
+ class Tensor;
+}
+namespace CpuCompute
+{
+ class Tensor;
+}
+
+struct ggml_tensor;
+
+namespace Tracing
+{
+ using DirectCompute::eDataType;
+
+ struct ItemName
+ {
+ const char* pointer;
+ std::array<uint32_t, 4> args;
+ uint8_t countArgs;
+
+ ItemName( const char* str )
+ {
+ pointer = str;
+ _mm_storeu_si128( ( __m128i* )args.data(), _mm_setzero_si128() );
+ countArgs = 0;
+ }
+ ItemName( const char* str, int a0 )
+ {
+ pointer = str;
+ __m128i v = _mm_cvtsi32_si128( a0 );
+ _mm_storeu_si128( ( __m128i* )args.data(), v );
+ countArgs = 1;
+ }
+ ItemName( const char* str, uint32_t a0 )
+ {
+ pointer = str;
+ __m128i v = _mm_cvtsi32_si128( (int)a0 );
+ _mm_storeu_si128( ( __m128i* )args.data(), v );
+ countArgs = 1;
+ }
+ ItemName( const char* str, size_t a0 )
+ {
+ pointer = str;
+ __m128i v = _mm_cvtsi32_si128( (int)a0 );
+ _mm_storeu_si128( ( __m128i* )args.data(), v );
+ countArgs = 1;
+ }
+ };
+
+ class iTraceWriter
+ {
+ public:
+ virtual ~iTraceWriter() {}
+
+ static std::unique_ptr<iTraceWriter> create( LPCTSTR path );
+
+ virtual HRESULT buffer( const ItemName& name, const void* rsi, size_t length, eDataType dt ) = 0;
+
+ virtual HRESULT tensor( const ItemName& name, const void* rsi, __m128i size, __m128i strides, eDataType dt ) = 0;
+
+ HRESULT tensor( const ItemName& name, const DirectCompute::Tensor& tensor );
+ HRESULT tensor( const ItemName& name, const CpuCompute::Tensor& tensor );
+ HRESULT tensor( const ItemName& name, const ggml_tensor& tensor );
+ };
+} \ No newline at end of file
diff --git a/Whisper/Utils/Trace/tracing.cpp b/Whisper/Utils/Trace/tracing.cpp
new file mode 100644
index 0000000..976f517
--- /dev/null
+++ b/Whisper/Utils/Trace/tracing.cpp
@@ -0,0 +1,60 @@
+#include "stdafx.h"
+#include "tracing.h"
+#include "../../source/ggml.h"
+
+#if SAVE_DEBUG_TRACE
+namespace Tracing
+{
+ std::unique_ptr<iTraceWriter> s_writer;
+
+ static BOOL __stdcall consoleHandler( DWORD dwCtrlType )
+ {
+ if( dwCtrlType == CTRL_C_EVENT )
+ s_writer = nullptr;
+
+ // Return TRUE if handled this message, further handler functions won't be called.
+ // Return FALSE to pass this message to further handlers until default handler calls ExitProcess().
+ return FALSE;
+ }
+
+ void traceCreate( LPCTSTR path )
+ {
+ s_writer = iTraceWriter::create( path );
+ SetConsoleCtrlHandler( &consoleHandler, TRUE );
+ }
+
+ void traceClose()
+ {
+ s_writer = nullptr;
+ }
+
+ iTraceWriter* getWriter()
+ {
+ return s_writer.get();
+ }
+
+ using Pair = std::pair<ItemName, ggml_tensor>;
+ static std::vector<Pair> delayed;
+
+ void delayTensor( const ItemName& name, const ggml_tensor* tensor )
+ {
+ delayed.emplace_back( name, *tensor );
+ }
+
+ HRESULT writeDelayedTensors()
+ {
+ if( delayed.empty() )
+ return S_FALSE;
+ iTraceWriter* w = getWriter();
+ if( nullptr == w )
+ {
+ delayed.clear();
+ return S_FALSE;
+ }
+ for( const Pair& p : delayed )
+ w->tensor( p.first, p.second );
+ delayed.clear();
+ return S_OK;
+ }
+}
+#endif \ No newline at end of file
diff --git a/Whisper/Utils/Trace/tracing.h b/Whisper/Utils/Trace/tracing.h
new file mode 100644
index 0000000..66a7ac4
--- /dev/null
+++ b/Whisper/Utils/Trace/tracing.h
@@ -0,0 +1,67 @@
+#pragma once
+#include "TraceWriter.h"
+
+namespace Tracing
+{
+#if SAVE_DEBUG_TRACE
+ void traceCreate( LPCTSTR path );
+ void traceClose();
+
+ iTraceWriter* getWriter();
+
+ inline HRESULT tensor( const ItemName& name, const DirectCompute::Tensor& tensor )
+ {
+ iTraceWriter* w = getWriter();
+ if( w )
+ return w->tensor( name, tensor );
+ return S_FALSE;
+ }
+ inline HRESULT tensor( const ItemName& name, const CpuCompute::Tensor& tensor )
+ {
+ iTraceWriter* w = getWriter();
+ if( w )
+ return w->tensor( name, tensor );
+ return S_FALSE;
+ }
+
+ inline HRESULT tensor( const ItemName& name, const ggml_tensor* tensor )
+ {
+ iTraceWriter* w = getWriter();
+ if( w )
+ return w->tensor( name, *tensor );
+ return S_FALSE;
+ }
+
+ void delayTensor( const ItemName& name, const ggml_tensor* tensor );
+ HRESULT writeDelayedTensors();
+
+ inline HRESULT buffer( const ItemName& name, const void* rsi, size_t length, eDataType dt )
+ {
+ iTraceWriter* w = getWriter();
+ if( w )
+ return w->buffer( name, rsi, length, dt );
+ return S_FALSE;
+ }
+
+ inline HRESULT vector( const ItemName& name, const std::vector<float>& vec )
+ {
+ const float* rsi = vec.empty() ? nullptr : vec.data();
+ return buffer( name, rsi, vec.size(), eDataType::FP32 );
+ }
+ inline HRESULT vector( const ItemName& name, const float* rsi, size_t length )
+ {
+ return buffer( name, rsi, length, eDataType::FP32 );
+ }
+#else
+ inline void traceCreate( LPCTSTR path ) { }
+ inline void traceClose() { }
+ inline HRESULT tensor( const ItemName& name, const DirectCompute::Tensor& tensor ) { return S_FALSE; }
+ inline HRESULT tensor( const ItemName& name, const CpuCompute::Tensor& tensor ) { return S_FALSE; }
+ inline HRESULT tensor( const ItemName& name, const ggml_tensor* tensor ) { return S_FALSE; }
+ inline HRESULT buffer( const ItemName& name, const void* rsi, size_t length, eDataType dt ) { return S_FALSE; }
+ inline HRESULT vector( const ItemName& name, const std::vector<float>& vec ) { return S_FALSE; }
+ inline void delayTensor( const ItemName& name, const ggml_tensor* tensor ) { }
+ inline HRESULT writeDelayedTensors() { return S_FALSE; }
+ inline HRESULT vector( const ItemName& name, const float* rsi, size_t length ) { }
+#endif
+} \ No newline at end of file
diff --git a/Whisper/Utils/miscUtils.cpp b/Whisper/Utils/miscUtils.cpp
new file mode 100644
index 0000000..c3f7dd1
--- /dev/null
+++ b/Whisper/Utils/miscUtils.cpp
@@ -0,0 +1,33 @@
+#include "stdafx.h"
+#include "miscUtils.h"
+
+void setCurrentThreadName( const char* threadName )
+{
+ const DWORD dwThreadID = GetCurrentThreadId();
+
+ // https://stackoverflow.com/a/10364541/126995
+#pragma pack(push,8)
+ typedef struct tagTHREADNAME_INFO
+ {
+ DWORD dwType; // Must be 0x1000.
+ LPCSTR szName; // Pointer to name (in user addr space).
+ DWORD dwThreadID; // Thread ID (-1=caller thread).
+ DWORD dwFlags; // Reserved for future use, must be zero.
+ } THREADNAME_INFO;
+#pragma pack(pop)
+
+ THREADNAME_INFO info;
+ info.dwType = 0x1000;
+ info.szName = threadName;
+ info.dwThreadID = dwThreadID;
+ info.dwFlags = 0;
+
+ constexpr DWORD MS_VC_EXCEPTION = 0x406D1388;
+ __try
+ {
+ RaiseException( MS_VC_EXCEPTION, 0, sizeof( info ) / sizeof( ULONG_PTR ), (ULONG_PTR*)&info );
+ }
+ __except( EXCEPTION_EXECUTE_HANDLER )
+ {
+ }
+} \ No newline at end of file
diff --git a/Whisper/Utils/miscUtils.h b/Whisper/Utils/miscUtils.h
new file mode 100644
index 0000000..d665cbc
--- /dev/null
+++ b/Whisper/Utils/miscUtils.h
@@ -0,0 +1,81 @@
+#pragma once
+
+#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; }
+#define CHECK_LOG( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) { logErrorHr(__hr, u8"%s failed", #hr ); return __hr; } }
+
+inline void check( HRESULT hr )
+{
+ if( SUCCEEDED( hr ) )
+ return;
+ throw hr;
+}
+
+inline __m128i load16( const int* rsi )
+{
+ return _mm_loadu_si128( ( const __m128i* )rsi );
+}
+inline __m128i load16( const uint32_t* rsi )
+{
+ return _mm_loadu_si128( ( const __m128i* )rsi );
+}
+inline __m128i load( const std::array<uint32_t, 4>& arr )
+{
+ return load16( arr.data() );
+}
+inline void store16( void* rdi, __m128i v )
+{
+ _mm_storeu_si128( ( __m128i* )rdi, v );
+}
+inline void store12( void* rdi, __m128i v )
+{
+ _mm_storel_epi64( ( __m128i* )rdi, v );
+ ( (int*)rdi )[ 2 ] = _mm_extract_epi32( v, 2 );
+}
+inline void store( std::array<uint32_t, 4>& arr, __m128i v )
+{
+ store16( arr.data(), v );
+}
+inline bool vectorEqual( __m128i a, __m128i b )
+{
+ __m128i xx = _mm_xor_si128( a, b );
+ return (bool)_mm_testz_si128( xx, xx );
+}
+
+inline __m128i setLow_size( size_t low )
+{
+ return _mm_cvtsi64_si128( (int64_t)low );
+}
+inline __m128i setr_size( size_t low, size_t high )
+{
+ __m128i v = setLow_size( low );
+ v = _mm_insert_epi64( v, (int64_t)high, 1 );
+ return v;
+}
+inline __m128i setHigh_size( size_t high )
+{
+ __m128i v = _mm_setzero_si128();
+ v = _mm_insert_epi64( v, (int64_t)high, 1 );
+ return v;
+}
+
+void setCurrentThreadName( const char* name );
+
+inline HRESULT getLastHr()
+{
+ return HRESULT_FROM_WIN32( GetLastError() );
+}
+
+// Scale time in seconds from unsigned 64 bit rational number ( mul / div ) into 100-nanosecond ticks
+// These 100-nanosecond ticks are used in NTFS, FILETIME, .NET standard library, media foundation, and quite a few other places
+inline uint64_t makeTime( uint64_t mul, uint64_t div )
+{
+ mul *= 10'000'000;
+ mul += ( ( div / 2 ) - 1 );
+ return mul / div;
+}
+
+template<class E>
+inline size_t vectorMemoryUse( const std::vector<E>& vec )
+{
+ return sizeof( E ) * vec.capacity();
+} \ No newline at end of file
diff --git a/Whisper/Utils/parallelFor.cpp b/Whisper/Utils/parallelFor.cpp
new file mode 100644
index 0000000..c2b324b
--- /dev/null
+++ b/Whisper/Utils/parallelFor.cpp
@@ -0,0 +1,144 @@
+#include "stdafx.h"
+#include "parallelFor.h"
+
+namespace
+{
+ class alignas( 64 ) ParallelForContext
+ {
+ volatile long threadIndex;
+ volatile HRESULT status;
+
+ alignas( 64 ) void* const context;
+ const Whisper::pfnParallelForCallback pfn;
+
+ static void __stdcall callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work );
+
+ public:
+
+ ParallelForContext( void* ctx, Whisper::pfnParallelForCallback pfn );
+
+ PTP_WORK createWork();
+
+ HRESULT getStatus() const;
+ };
+
+ ParallelForContext::ParallelForContext( void* ctx, Whisper::pfnParallelForCallback callback ) :
+ threadIndex( 1 ),
+ status( S_FALSE ),
+ context( ctx ),
+ pfn( callback )
+ { }
+
+ PTP_WORK ParallelForContext::createWork()
+ {
+ return CreateThreadpoolWork( &callbackStatic, this, nullptr );
+ }
+
+ void __stdcall ParallelForContext::callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work )
+ {
+ ParallelForContext& context = *(ParallelForContext*)pv;
+ int ith = InterlockedIncrement( &context.threadIndex );
+ ith--;
+ const HRESULT hr = context.pfn( ith, context.context );
+ if( SUCCEEDED( hr ) )
+ return;
+ InterlockedCompareExchange( &context.status, hr, S_FALSE );
+ }
+
+ HRESULT ParallelForContext::getStatus() const
+ {
+ const HRESULT hr = status;
+ if( SUCCEEDED( hr ) )
+ return S_OK;
+ return hr;
+ }
+}
+
+namespace Whisper
+{
+ HRESULT parallelFor( pfnParallelForCallback pfn, int threadsCount, void* ctx )
+ {
+ if( threadsCount < 1 )
+ return E_BOUNDS;
+ if( threadsCount == 1 )
+ return pfn( 0, ctx );
+
+ ParallelForContext context{ ctx, pfn };
+
+ PTP_WORK const pw = context.createWork();
+ if( nullptr == pw )
+ return getLastHr();
+
+ for( int i = 1; i < threadsCount; i++ )
+ SubmitThreadpoolWork( pw );
+
+ const HRESULT hr0 = pfn( 0, ctx );
+
+ WaitForThreadpoolWorkCallbacks( pw, FALSE );
+ CloseThreadpoolWork( pw );
+
+ if( FAILED( hr0 ) )
+ return hr0;
+ return context.getStatus();
+ }
+}
+
+using namespace Whisper;
+
+ThreadPoolWork::~ThreadPoolWork()
+{
+ if( nullptr != work )
+ {
+ CloseThreadpoolWork( work );
+ work = nullptr;
+ }
+}
+
+HRESULT ThreadPoolWork::create()
+{
+ if( nullptr == work )
+ {
+ work = CreateThreadpoolWork( &callbackStatic, this, nullptr );
+ if( nullptr != work )
+ return S_OK;
+ return getLastHr();
+ }
+ return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
+}
+
+HRESULT ThreadPoolWork::parallelFor( int threadsCount ) noexcept
+{
+ if( nullptr != work )
+ {
+ if( threadsCount <= 1 )
+ return threadPoolCallback( 0 );
+
+ threadIndex = 1;
+ status = S_FALSE;
+ for( int i = 1; i < threadsCount; i++ )
+ SubmitThreadpoolWork( work );
+
+ const HRESULT hr0 = threadPoolCallback( 0 );
+
+ WaitForThreadpoolWorkCallbacks( work, FALSE );
+
+ if( FAILED( hr0 ) )
+ return hr0;
+ if( SUCCEEDED( status ) )
+ return S_OK;
+ return status;
+ }
+
+ return OLE_E_BLANK;
+}
+
+void __stdcall ThreadPoolWork::callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work )
+{
+ ThreadPoolWork* tpw = (ThreadPoolWork*)pv;
+ int ith = InterlockedIncrement( &tpw->threadIndex );
+ ith--;
+ const HRESULT hr = tpw->threadPoolCallback( ith );
+ if( SUCCEEDED( hr ) )
+ return;
+ InterlockedCompareExchange( &tpw->status, hr, S_FALSE );
+} \ No newline at end of file
diff --git a/Whisper/Utils/parallelFor.h b/Whisper/Utils/parallelFor.h
new file mode 100644
index 0000000..15cd603
--- /dev/null
+++ b/Whisper/Utils/parallelFor.h
@@ -0,0 +1,38 @@
+#pragma once
+
+namespace Whisper
+{
+ // A callback to offload to the thread pool
+ using pfnParallelForCallback = HRESULT( * )( int ith, void* ctx ) noexcept;
+
+ // A simple parallel for implementation; Windows includes a decent thread pool since Vista (2006)
+ HRESULT parallelFor( pfnParallelForCallback pfn, int threadsCount, void* ctx );
+
+ // Use this version when you wanna use the thread pool repeatedly, for the same work.
+ // This class caches native work handle, saving a couple of WinAPI calls.
+ class alignas( 64 ) ThreadPoolWork
+ {
+ PTP_WORK work = nullptr;
+
+ // We want these volatile fields in another cache line from the rest of the data of this class.
+ // threadIndex field is concurrently modified by different CPU cores, and these cache coherency protocols are slow.
+ // OTOH, work and callback fields of this class only change when created / destroyed, that cache line is shared by CPU cores without any performance penalty.
+ alignas( 64 ) volatile long threadIndex = 0;
+ volatile HRESULT status = E_UNEXPECTED;
+
+ static void __stdcall callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work );
+
+ protected:
+ virtual HRESULT threadPoolCallback( int ith ) noexcept = 0;
+
+ public:
+ ThreadPoolWork() = default;
+ ThreadPoolWork( const ThreadPoolWork& ) = delete;
+
+ ~ThreadPoolWork();
+
+ HRESULT create();
+
+ HRESULT parallelFor( int threadsCount ) noexcept;
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper.vcxproj b/Whisper/Whisper.vcxproj
new file mode 100644
index 0000000..f270440
--- /dev/null
+++ b/Whisper/Whisper.vcxproj
@@ -0,0 +1,347 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <VCProjectVersion>16.0</VCProjectVersion>
+ <Keyword>Win32Proj</Keyword>
+ <ProjectGuid>{701df8c8-e4a5-43ec-9c6b-747bbf4d8e71}</ProjectGuid>
+ <RootNamespace>Whisper</RootNamespace>
+ <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>DynamicLibrary</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>DynamicLibrary</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="Shared">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <IncludePath>$(ProjectDir);$(IncludePath)</IncludePath>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <IncludePath>$(ProjectDir);$(IncludePath)</IncludePath>
+ </PropertyGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>_DEBUG;WHISPER_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ <PrecompiledHeader>Use</PrecompiledHeader>
+ <MultiProcessorCompilation>true</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Windows</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ <EnableUAC>false</EnableUAC>
+ <ModuleDefinitionFile>whisper.def</ModuleDefinitionFile>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NDEBUG;WHISPER_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ <PrecompiledHeader>Use</PrecompiledHeader>
+ <MultiProcessorCompilation>true</MultiProcessorCompilation>
+ <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
+ </ClCompile>
+ <Link>
+ <SubSystem>Windows</SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ <EnableUAC>false</EnableUAC>
+ <ModuleDefinitionFile>whisper.def</ModuleDefinitionFile>
+ <LinkTimeCodeGeneration>UseLinkTimeCodeGeneration</LinkTimeCodeGeneration>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemGroup>
+ <ProjectReference Include="..\ComLightLib\ComLightLib.vcxproj">
+ <Project>{52f486e7-830c-45d8-be47-e76b5aab2772}</Project>
+ </ProjectReference>
+ </ItemGroup>
+ <ItemGroup>
+ <ClCompile Include="CPU\BufferAllocator.cpp" />
+ <ClCompile Include="CPU\DecoderTensors.cpp" />
+ <ClCompile Include="CPU\mulMatImpl.avx2.cpp">
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions2</EnableEnhancedInstructionSet>
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions2</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <ClCompile Include="CPU\mulMatImpl.panel.cpp">
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <ClCompile Include="CPU\TensorCpu.cpp">
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <ClCompile Include="CPU\HybridLoader.cpp" />
+ <ClCompile Include="Hybrid\HybridContext.cpp" />
+ <ClCompile Include="CPU\ParallelForRunner.cpp">
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <ClCompile Include="CPU\LargeBuffer.cpp" />
+ <ClCompile Include="CPU\simdUtils.cpp">
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <ClCompile Include="CPU\mulMat.cpp">
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <ClCompile Include="CPU\MlContextCpu.cpp">
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <ClCompile Include="CPU\KvTensorsCpu.cpp" />
+ <ClCompile Include="Hybrid\KeyValueDownloader.cpp" />
+ <ClCompile Include="CPU\mulMatImpl.cpp">
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ <EnableEnhancedInstructionSet Condition="'$(Configuration)|$(Platform)'=='Release|x64'">AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <ClCompile Include="ML\Reshaper.cpp" />
+ <ClCompile Include="Utils\Logger.cpp" />
+ <ClCompile Include="MF\AudioCapture.cpp" />
+ <ClCompile Include="Utils\miscUtils.cpp" />
+ <ClCompile Include="Whisper\voiceActivityDetection.cpp" />
+ <ClCompile Include="Whisper\ContextImpl.capture.cpp" />
+ <ClCompile Include="Whisper\MelStreamer.cpp" />
+ <ClCompile Include="Whisper\melSpectrogram.cpp" />
+ <ClCompile Include="modelFactory.cpp" />
+ <ClCompile Include="MF\AudioBuffer.cpp" />
+ <ClCompile Include="MF\PcmReader.cpp" />
+ <ClCompile Include="Utils\Trace\tracing.cpp" />
+ <ClCompile Include="Utils\Trace\TraceStructures.cpp" />
+ <ClCompile Include="Utils\Trace\TraceWriter.cpp" />
+ <ClCompile Include="source.compat\convertThings.cpp" />
+ <ClCompile Include="D3D\shaderNames.cpp" />
+ <ClCompile Include="MF\mfUtils.cpp" />
+ <ClCompile Include="MF\loadAudioFile.cpp" />
+ <ClCompile Include="MF\MediaFoundation.cpp" />
+ <ClCompile Include="MF\mfStartup.cpp" />
+ <ClCompile Include="source.compat\ggmlMsvc.c">
+ <PrecompiledHeader>NotUsing</PrecompiledHeader>
+ <EnableEnhancedInstructionSet>AdvancedVectorExtensions</EnableEnhancedInstructionSet>
+ </ClCompile>
+ <ClCompile Include="Whisper\ContextImpl.misc.cpp" />
+ <ClCompile Include="Utils\ProfileCollection.cpp" />
+ <ClCompile Include="Utils\CpuProfiler.cpp" />
+ <ClCompile Include="D3D\enums.cpp" />
+ <ClCompile Include="Utils\GpuProfiler.cpp" />
+ <ClCompile Include="ML\TensorsArena.cpp" />
+ <ClCompile Include="Whisper\Languages.cpp" />
+ <ClCompile Include="Whisper\ContextImpl.cpp" />
+ <ClCompile Include="Whisper\ModelImpl.cpp" />
+ <ClCompile Include="Utils\parallelFor.cpp" />
+ <ClCompile Include="Whisper\Spectrogram.cpp" />
+ <ClCompile Include="Whisper\WhisperModel.cpp" />
+ <ClCompile Include="Whisper\Vocabulary.cpp" />
+ <ClCompile Include="Whisper\DecoderResultBuffer.cpp" />
+ <ClCompile Include="Whisper\DecoderInputBuffers.cpp" />
+ <ClCompile Include="ML\mlStartup.cpp" />
+ <ClCompile Include="Whisper\KeyValueBuffers.cpp" />
+ <ClCompile Include="D3D\Binder.cpp" />
+ <ClCompile Include="ML\LookupTables.cpp" />
+ <ClCompile Include="ML\LookupTablesData.cpp" />
+ <ClCompile Include="ML\Context.ops.cpp" />
+ <ClCompile Include="ML\MlContext.dbg.cpp" />
+ <ClCompile Include="ML\MlContext.cpp" />
+ <ClCompile Include="ML\ConstantBuffer.cpp" />
+ <ClCompile Include="D3D\createBuffer.cpp" />
+ <ClCompile Include="D3D\device.cpp" />
+ <ClCompile Include="D3D\MappedResource.cpp" />
+ <ClCompile Include="Whisper\WhisperContext.cpp" />
+ <ClCompile Include="Whisper\ModelBuffers.cpp" />
+ <ClCompile Include="D3D\shaders.cpp" />
+ <ClCompile Include="ML\TensorShape.cpp" />
+ <ClCompile Include="DllMain.cpp" />
+ <ClCompile Include="D3D\downloadBuffer.cpp" />
+ <ClCompile Include="D3D\RenderDoc\renderDoc.cpp" />
+ <ClCompile Include="Whisper\MelInputTensor.cpp" />
+ <ClCompile Include="source\ggml.c">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="source\whisper.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="D3D\startup.cpp" />
+ <ClCompile Include="ML\TempBuffers.cpp" />
+ <ClCompile Include="ML\tensorOpsTests.cpp" />
+ <ClCompile Include="stdafx.cpp">
+ <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
+ <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
+ </ClCompile>
+ <ClCompile Include="ML\testUtils.cpp" />
+ <ClCompile Include="ML\Tensor.cpp" />
+ <ClCompile Include="ML\TensorGpuViews.cpp" />
+ <ClCompile Include="ML\TensorEx.cpp" />
+ <ClCompile Include="whisperCom.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="API\iContext.h" />
+ <ClInclude Include="API\iMediaFoundation.h" />
+ <ClInclude Include="API\iTranscribeResult.h" />
+ <ClInclude Include="API\loggerApi.h" />
+ <ClInclude Include="API\MfStructs.h" />
+ <ClInclude Include="API\iContext.cl.h" />
+ <ClInclude Include="API\iMediaFoundation.cl.h" />
+ <ClInclude Include="API\iTranscribeResult.cl.h" />
+ <ClInclude Include="API\sLanguageList.h" />
+ <ClInclude Include="API\sLoadModelCallbacks.h" />
+ <ClInclude Include="API\SpecialTokens.h" />
+ <ClInclude Include="API\sFullParams.h" />
+ <ClInclude Include="API\whisperComLight.h" />
+ <ClInclude Include="API\whisperWindows.h" />
+ <ClInclude Include="CPU\BufferAllocator.h" />
+ <ClInclude Include="CPU\mulMatUtils.hpp" />
+ <ClInclude Include="CPU\Tensor.h" />
+ <ClInclude Include="CPU\DecoderTensors.h" />
+ <ClInclude Include="CPU\HybridLoader.h" />
+ <ClInclude Include="Hybrid\HybridContext.h" />
+ <ClInclude Include="CPU\ParallelForRunner.h" />
+ <ClInclude Include="CPU\LargeBuffer.h" />
+ <ClInclude Include="CPU\simdUtils.h" />
+ <ClInclude Include="CPU\MlContext.h" />
+ <ClInclude Include="CPU\KvTensors.h" />
+ <ClInclude Include="Hybrid\KeyValueDownloader.h" />
+ <ClInclude Include="ML\reshapedMultiply.h" />
+ <ClInclude Include="ML\testUtilsC.h" />
+ <ClInclude Include="CPU\mulMat.h" />
+ <ClInclude Include="CPU\mulMatImpl.h" />
+ <ClInclude Include="ML\Reshaper.h" />
+ <ClInclude Include="Utils\Logger.h" />
+ <ClInclude Include="MF\AudioCapture.h" />
+ <ClInclude Include="Whisper\sModelParams.h" />
+ <ClInclude Include="Whisper\voiceActivityDetection.h" />
+ <ClInclude Include="Whisper\MelStreamer.h" />
+ <ClInclude Include="Whisper\melSpectrogram.h" />
+ <ClInclude Include="modelFactory.h" />
+ <ClInclude Include="MF\AudioBuffer.h" />
+ <ClInclude Include="MF\PcmReader.h" />
+ <ClInclude Include="Utils\miscUtils.h" />
+ <ClInclude Include="Utils\Trace\tracing.h" />
+ <ClInclude Include="Utils\Trace\TraceStructures.h" />
+ <ClInclude Include="Utils\Trace\TraceWriter.h" />
+ <ClInclude Include="source.compat\convertThings.h" />
+ <ClInclude Include="MF\mfUtils.h" />
+ <ClInclude Include="MF\loadAudioFile.h" />
+ <ClInclude Include="MF\mfStartup.h" />
+ <ClInclude Include="API\TranscribeStructs.h" />
+ <ClInclude Include="resource.h" />
+ <ClInclude Include="Utils\ReadStream.h" />
+ <ClInclude Include="Whisper\audioConstants.h" />
+ <ClInclude Include="Whisper\iSpectrogram.h" />
+ <ClInclude Include="Whisper\sTokenData.h" />
+ <ClInclude Include="Whisper\TranscribeResult.h" />
+ <ClInclude Include="Utils\ProfileCollection.h" />
+ <ClInclude Include="Utils\CpuProfiler.h" />
+ <ClInclude Include="Utils\GpuProfiler.h" />
+ <ClInclude Include="ML\TensorsArena.h" />
+ <ClInclude Include="Utils\GpuProfilerSimple.h" />
+ <ClInclude Include="Whisper\Languages.h" />
+ <ClInclude Include="Whisper\ContextImpl.h" />
+ <ClInclude Include="Whisper\ModelImpl.h" />
+ <ClInclude Include="Utils\parallelFor.h" />
+ <ClInclude Include="Whisper\Spectrogram.h" />
+ <ClInclude Include="Whisper\loaderUtils.h" />
+ <ClInclude Include="Whisper\WhisperModel.h" />
+ <ClInclude Include="Whisper\Vocabulary.h" />
+ <ClInclude Include="Whisper\DecoderResultBuffer.h" />
+ <ClInclude Include="Whisper\DecoderInputBuffers.h" />
+ <ClInclude Include="ML\mlStartup.h" />
+ <ClInclude Include="Whisper\KeyValueBuffers.h" />
+ <ClInclude Include="D3D\Binder.h" />
+ <ClInclude Include="ML\LookupTables.h" />
+ <ClInclude Include="ML\LookupTablesData.h" />
+ <ClInclude Include="ML\MlContext.h" />
+ <ClInclude Include="ML\ConstantBuffer.h" />
+ <ClInclude Include="D3D\device.h" />
+ <ClInclude Include="D3D\createBuffer.h" />
+ <ClInclude Include="D3D\enums.h" />
+ <ClInclude Include="D3D\MappedResource.h" />
+ <ClInclude Include="Whisper\sEncodeParams.h" />
+ <ClInclude Include="Whisper\WhisperContext.h" />
+ <ClInclude Include="Whisper\ModelBuffers.h" />
+ <ClInclude Include="Whisper\ModelLoader.h" />
+ <ClInclude Include="D3D\RenderDoc\renderdoc_app.h" />
+ <ClInclude Include="D3D\shaderNames.h" />
+ <ClInclude Include="D3D\shaders.h" />
+ <ClInclude Include="D3D\downloadBuffer.h" />
+ <ClInclude Include="D3D\RenderDoc\renderDoc.h" />
+ <ClInclude Include="Whisper\MelInputTensor.h" />
+ <ClInclude Include="ML\TensorShape.h" />
+ <ClInclude Include="source\ggml.h" />
+ <ClInclude Include="source\whisper.h" />
+ <ClInclude Include="D3D\startup.h" />
+ <ClInclude Include="ML\TempBuffers.h" />
+ <ClInclude Include="ML\tensorOpsTests.h" />
+ <ClInclude Include="stdafx.h" />
+ <ClInclude Include="ML\testUtils.h" />
+ <ClInclude Include="ML\Tensor.h" />
+ <ClInclude Include="ML\TensorGpuViews.h" />
+ <ClInclude Include="ML\TensorEx.h" />
+ </ItemGroup>
+ <ItemGroup>
+ <None Include="D3D\shaderData-Debug.inl" />
+ <None Include="D3D\shaderData-Release.inl" />
+ <None Include="CPU\mulMat.kernel.hpp" />
+ <None Include="source\LICENSE" />
+ <None Include="whisper.def" />
+ <None Include="Whisper\languageCodez.inl" />
+ <None Include="Whisper\languageCodez.tsv" />
+ </ItemGroup>
+ <ItemGroup>
+ <Natvis Include="misc.natvis" />
+ </ItemGroup>
+ <ItemGroup>
+ <ResourceCompile Include="Resource.rc" />
+ </ItemGroup>
+ <ItemGroup>
+ <Text Include="API\Readme.txt" />
+ <Text Include="CPU\Readme.txt" />
+ <Text Include="Hybrid\Readme.txt" />
+ <Text Include="Readme.txt" />
+ <Text Include="source.compat\Readme.txt" />
+ <Text Include="source\Readme.txt" />
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/Whisper/Whisper.vcxproj.filters b/Whisper/Whisper.vcxproj.filters
new file mode 100644
index 0000000..193fbe7
--- /dev/null
+++ b/Whisper/Whisper.vcxproj.filters
@@ -0,0 +1,214 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <ClCompile Include="DllMain.cpp" />
+ <ClCompile Include="whisperCom.cpp" />
+ <ClCompile Include="source\ggml.c" />
+ <ClCompile Include="source\whisper.cpp" />
+ <ClCompile Include="D3D\device.cpp" />
+ <ClCompile Include="ML\ConstantBuffer.cpp" />
+ <ClCompile Include="D3D\MappedResource.cpp" />
+ <ClCompile Include="D3D\shaders.cpp" />
+ <ClCompile Include="D3D\startup.cpp" />
+ <ClCompile Include="D3D\createBuffer.cpp" />
+ <ClCompile Include="ML\TempBuffers.cpp" />
+ <ClCompile Include="ML\MlContext.cpp" />
+ <ClCompile Include="ML\tensorOpsTests.cpp" />
+ <ClCompile Include="D3D\Binder.cpp" />
+ <ClCompile Include="stdafx.cpp" />
+ <ClCompile Include="ML\testUtils.cpp" />
+ <ClCompile Include="D3D\downloadBuffer.cpp" />
+ <ClCompile Include="D3D\RenderDoc\renderDoc.cpp" />
+ <ClCompile Include="Whisper\MelInputTensor.cpp" />
+ <ClCompile Include="Whisper\ModelBuffers.cpp" />
+ <ClCompile Include="ML\Context.ops.cpp" />
+ <ClCompile Include="ML\TensorShape.cpp" />
+ <ClCompile Include="ML\Tensor.cpp" />
+ <ClCompile Include="ML\TensorGpuViews.cpp" />
+ <ClCompile Include="ML\TensorEx.cpp" />
+ <ClCompile Include="Whisper\WhisperContext.cpp" />
+ <ClCompile Include="ML\MlContext.dbg.cpp" />
+ <ClCompile Include="ML\LookupTablesData.cpp" />
+ <ClCompile Include="ML\LookupTables.cpp" />
+ <ClCompile Include="Whisper\KeyValueBuffers.cpp" />
+ <ClCompile Include="ML\mlStartup.cpp" />
+ <ClCompile Include="Whisper\DecoderInputBuffers.cpp" />
+ <ClCompile Include="Whisper\DecoderResultBuffer.cpp" />
+ <ClCompile Include="Whisper\Vocabulary.cpp" />
+ <ClCompile Include="Whisper\WhisperModel.cpp" />
+ <ClCompile Include="Whisper\Spectrogram.cpp" />
+ <ClCompile Include="Utils\parallelFor.cpp" />
+ <ClCompile Include="Whisper\ModelImpl.cpp" />
+ <ClCompile Include="Whisper\ContextImpl.cpp" />
+ <ClCompile Include="Whisper\Languages.cpp" />
+ <ClCompile Include="ML\TensorsArena.cpp" />
+ <ClCompile Include="D3D\enums.cpp" />
+ <ClCompile Include="Utils\GpuProfiler.cpp" />
+ <ClCompile Include="Utils\CpuProfiler.cpp" />
+ <ClCompile Include="Utils\ProfileCollection.cpp" />
+ <ClCompile Include="D3D\shaderNames.cpp" />
+ <ClCompile Include="MF\mfStartup.cpp" />
+ <ClCompile Include="MF\MediaFoundation.cpp" />
+ <ClCompile Include="MF\loadAudioFile.cpp" />
+ <ClCompile Include="MF\mfUtils.cpp" />
+ <ClCompile Include="source.compat\convertThings.cpp" />
+ <ClCompile Include="source.compat\ggmlMsvc.c" />
+ <ClCompile Include="Whisper\ContextImpl.misc.cpp" />
+ <ClCompile Include="Utils\Trace\TraceWriter.cpp" />
+ <ClCompile Include="Utils\Trace\TraceStructures.cpp" />
+ <ClCompile Include="Utils\Trace\tracing.cpp" />
+ <ClCompile Include="MF\AudioBuffer.cpp" />
+ <ClCompile Include="modelFactory.cpp" />
+ <ClCompile Include="MF\PcmReader.cpp" />
+ <ClCompile Include="Whisper\melSpectrogram.cpp" />
+ <ClCompile Include="Whisper\MelStreamer.cpp" />
+ <ClCompile Include="Utils\miscUtils.cpp" />
+ <ClCompile Include="MF\AudioCapture.cpp" />
+ <ClCompile Include="Utils\Logger.cpp" />
+ <ClCompile Include="Whisper\ContextImpl.capture.cpp" />
+ <ClCompile Include="Whisper\voiceActivityDetection.cpp" />
+ <ClCompile Include="CPU\LargeBuffer.cpp" />
+ <ClCompile Include="CPU\ParallelForRunner.cpp" />
+ <ClCompile Include="CPU\simdUtils.cpp" />
+ <ClCompile Include="CPU\mulMat.cpp" />
+ <ClCompile Include="CPU\TensorCpu.cpp" />
+ <ClCompile Include="CPU\MlContextCpu.cpp" />
+ <ClCompile Include="CPU\BufferAllocator.cpp" />
+ <ClCompile Include="CPU\HybridLoader.cpp" />
+ <ClCompile Include="CPU\DecoderTensors.cpp" />
+ <ClCompile Include="Hybrid\HybridContext.cpp" />
+ <ClCompile Include="CPU\KvTensorsCpu.cpp" />
+ <ClCompile Include="Hybrid\KeyValueDownloader.cpp" />
+ <ClCompile Include="CPU\mulMatImpl.cpp" />
+ <ClCompile Include="CPU\mulMatImpl.avx2.cpp" />
+ <ClCompile Include="CPU\mulMatImpl.panel.cpp" />
+ <ClCompile Include="ML\Reshaper.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="source\ggml.h" />
+ <ClInclude Include="source\whisper.h" />
+ <ClInclude Include="API\iContext.cl.h" />
+ <ClInclude Include="API\sFullParams.h" />
+ <ClInclude Include="D3D\device.h" />
+ <ClInclude Include="ML\ConstantBuffer.h" />
+ <ClInclude Include="D3D\MappedResource.h" />
+ <ClInclude Include="D3D\shaderNames.h" />
+ <ClInclude Include="D3D\shaders.h" />
+ <ClInclude Include="D3D\startup.h" />
+ <ClInclude Include="D3D\createBuffer.h" />
+ <ClInclude Include="ML\TempBuffers.h" />
+ <ClInclude Include="ML\MlContext.h" />
+ <ClInclude Include="ML\tensorOpsTests.h" />
+ <ClInclude Include="D3D\Binder.h" />
+ <ClInclude Include="stdafx.h" />
+ <ClInclude Include="ML\testUtils.h" />
+ <ClInclude Include="D3D\downloadBuffer.h" />
+ <ClInclude Include="D3D\RenderDoc\renderdoc_app.h" />
+ <ClInclude Include="D3D\RenderDoc\renderDoc.h" />
+ <ClInclude Include="Whisper\MelInputTensor.h" />
+ <ClInclude Include="Whisper\ModelBuffers.h" />
+ <ClInclude Include="Whisper\ModelLoader.h" />
+ <ClInclude Include="ML\TensorShape.h" />
+ <ClInclude Include="ML\Tensor.h" />
+ <ClInclude Include="ML\TensorGpuViews.h" />
+ <ClInclude Include="D3D\enums.h" />
+ <ClInclude Include="ML\TensorEx.h" />
+ <ClInclude Include="Whisper\WhisperContext.h" />
+ <ClInclude Include="ML\LookupTablesData.h" />
+ <ClInclude Include="ML\LookupTables.h" />
+ <ClInclude Include="Whisper\sEncodeParams.h" />
+ <ClInclude Include="Whisper\KeyValueBuffers.h" />
+ <ClInclude Include="ML\mlStartup.h" />
+ <ClInclude Include="Whisper\DecoderInputBuffers.h" />
+ <ClInclude Include="Whisper\DecoderResultBuffer.h" />
+ <ClInclude Include="Whisper\Vocabulary.h" />
+ <ClInclude Include="Whisper\WhisperModel.h" />
+ <ClInclude Include="Whisper\loaderUtils.h" />
+ <ClInclude Include="Whisper\Spectrogram.h" />
+ <ClInclude Include="Utils\parallelFor.h" />
+ <ClInclude Include="Whisper\ModelImpl.h" />
+ <ClInclude Include="Whisper\ContextImpl.h" />
+ <ClInclude Include="Whisper\Languages.h" />
+ <ClInclude Include="ML\TensorsArena.h" />
+ <ClInclude Include="Utils\GpuProfiler.h" />
+ <ClInclude Include="Utils\GpuProfilerSimple.h" />
+ <ClInclude Include="Utils\CpuProfiler.h" />
+ <ClInclude Include="Utils\ProfileCollection.h" />
+ <ClInclude Include="MF\mfStartup.h" />
+ <ClInclude Include="API\iMediaFoundation.cl.h" />
+ <ClInclude Include="MF\loadAudioFile.h" />
+ <ClInclude Include="MF\mfUtils.h" />
+ <ClInclude Include="API\TranscribeStructs.h" />
+ <ClInclude Include="API\iTranscribeResult.cl.h" />
+ <ClInclude Include="Whisper\TranscribeResult.h" />
+ <ClInclude Include="Utils\ReadStream.h" />
+ <ClInclude Include="API\SpecialTokens.h" />
+ <ClInclude Include="Whisper\sTokenData.h" />
+ <ClInclude Include="resource.h" />
+ <ClInclude Include="source.compat\convertThings.h" />
+ <ClInclude Include="Utils\Trace\TraceWriter.h" />
+ <ClInclude Include="Utils\Trace\TraceStructures.h" />
+ <ClInclude Include="Utils\Trace\tracing.h" />
+ <ClInclude Include="Utils\miscUtils.h" />
+ <ClInclude Include="MF\AudioBuffer.h" />
+ <ClInclude Include="modelFactory.h" />
+ <ClInclude Include="Whisper\iSpectrogram.h" />
+ <ClInclude Include="MF\PcmReader.h" />
+ <ClInclude Include="Whisper\audioConstants.h" />
+ <ClInclude Include="Whisper\melSpectrogram.h" />
+ <ClInclude Include="Whisper\MelStreamer.h" />
+ <ClInclude Include="API\MfStructs.h" />
+ <ClInclude Include="MF\AudioCapture.h" />
+ <ClInclude Include="API\loggerApi.h" />
+ <ClInclude Include="Utils\Logger.h" />
+ <ClInclude Include="Whisper\voiceActivityDetection.h" />
+ <ClInclude Include="CPU\LargeBuffer.h" />
+ <ClInclude Include="API\iContext.h" />
+ <ClInclude Include="API\iMediaFoundation.h" />
+ <ClInclude Include="API\iTranscribeResult.h" />
+ <ClInclude Include="API\whisperComLight.h" />
+ <ClInclude Include="API\whisperWindows.h" />
+ <ClInclude Include="API\sLanguageList.h" />
+ <ClInclude Include="CPU\ParallelForRunner.h" />
+ <ClInclude Include="CPU\simdUtils.h" />
+ <ClInclude Include="ML\testUtilsC.h" />
+ <ClInclude Include="CPU\mulMat.h" />
+ <ClInclude Include="CPU\Tensor.h" />
+ <ClInclude Include="CPU\MlContext.h" />
+ <ClInclude Include="CPU\BufferAllocator.h" />
+ <ClInclude Include="CPU\DecoderTensors.h" />
+ <ClInclude Include="CPU\HybridLoader.h" />
+ <ClInclude Include="Whisper\sModelParams.h" />
+ <ClInclude Include="Hybrid\HybridContext.h" />
+ <ClInclude Include="CPU\KvTensors.h" />
+ <ClInclude Include="Hybrid\KeyValueDownloader.h" />
+ <ClInclude Include="CPU\mulMatUtils.hpp" />
+ <ClInclude Include="CPU\mulMatImpl.h" />
+ <ClInclude Include="API\sLoadModelCallbacks.h" />
+ <ClInclude Include="ML\Reshaper.h" />
+ <ClInclude Include="ML\reshapedMultiply.h" />
+ </ItemGroup>
+ <ItemGroup>
+ <None Include="whisper.def" />
+ <None Include="D3D\shaderData-Debug.inl" />
+ <None Include="D3D\shaderData-Release.inl" />
+ <None Include="Whisper\languageCodez.inl" />
+ <None Include="Whisper\languageCodez.tsv" />
+ <None Include="CPU\mulMat.kernel.hpp" />
+ <None Include="source\LICENSE" />
+ </ItemGroup>
+ <ItemGroup>
+ <Natvis Include="misc.natvis" />
+ </ItemGroup>
+ <ItemGroup>
+ <ResourceCompile Include="Resource.rc" />
+ </ItemGroup>
+ <ItemGroup>
+ <Text Include="Readme.txt" />
+ <Text Include="API\Readme.txt" />
+ <Text Include="source.compat\Readme.txt" />
+ <Text Include="source\Readme.txt" />
+ <Text Include="Hybrid\Readme.txt" />
+ <Text Include="CPU\Readme.txt" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/Whisper/Whisper/ContextImpl.capture.cpp b/Whisper/Whisper/ContextImpl.capture.cpp
new file mode 100644
index 0000000..86dc0d2
--- /dev/null
+++ b/Whisper/Whisper/ContextImpl.capture.cpp
@@ -0,0 +1,418 @@
+#include "stdafx.h"
+#include "ContextImpl.h"
+#include "../API/iMediaFoundation.cl.h"
+#include "../MF/AudioBuffer.h"
+#include "../MF/mfUtils.h"
+#include <mfidl.h>
+#include <mfapi.h>
+#include <mfreadwrite.h>
+#include "voiceActivityDetection.h"
+
+namespace
+{
+ using namespace Whisper;
+
+ class TranscribeBuffer : public ComLight::ObjectRoot<iAudioBuffer>
+ {
+ // ==== iAudioBuffer ====
+ uint32_t COMLIGHTCALL countSamples() const override final
+ {
+ return (uint32_t)pcm.mono.size();
+ }
+ const float* COMLIGHTCALL getPcmMono() const override final
+ {
+ if( !pcm.mono.empty() )
+ return pcm.mono.data();
+ return nullptr;
+ }
+ const float* COMLIGHTCALL getPcmStereo() const override final
+ {
+ if( !pcm.stereo.empty() )
+ return pcm.stereo.data();
+ return nullptr;
+ }
+ HRESULT COMLIGHTCALL getTime( int64_t& rdi ) const override final
+ {
+ rdi = MFllMulDiv( currentOffset, 10'000'000, SAMPLE_RATE, 0 );
+ return S_OK;
+ }
+ public:
+ AudioBuffer pcm;
+ int64_t currentOffset = 0;
+ };
+
+ class TranscribeBufferObj : public ComLight::Object<TranscribeBuffer>
+ {
+ uint32_t Release() override final
+ {
+ return RefCounter::implRelease();
+ }
+ };
+
+ struct CaptureParams
+ {
+ uint32_t minDuration, maxDuration, dropStartSilence, pauseDuration;
+ uint32_t flags;
+
+ CaptureParams( const sCaptureParams& cp )
+ {
+ // Convert these floats from seconds to samples
+ __m128 floats = _mm_loadu_ps( &cp.minDuration );
+ floats = _mm_mul_ps( floats, _mm_set1_ps( (float)SAMPLE_RATE ) );
+ floats = _mm_round_ps( floats, _MM_FROUND_NINT );
+ __m128i ints = _mm_cvtps_epi32( floats );
+ store16( &minDuration, ints );
+
+ flags = cp.flags;
+ }
+ };
+
+ class Capture
+ {
+ CComPtr<IMFSourceReader> reader;
+ const CaptureParams captureParams;
+ const sCaptureCallbacks callbacks;
+ // Count of channels delivered from the source reader
+ uint8_t readerChannels = 0;
+ volatile char stateFlags = 0;
+
+ PTP_WORK work = nullptr;
+ volatile HRESULT workStatus = S_OK;
+
+ TranscribeBufferObj buffer;
+ CComAutoCriticalSection critSec;
+ AudioBuffer pcm;
+ AudioBuffer::pfnAppendSamples pfnAppendSamples = nullptr;
+ int64_t pcmStartTime = 0;
+ int64_t nextSampleTime = 0;
+ VAD vad;
+ sFullParams fullParams;
+ ProfileCollection& profiler;
+ iContext* const whisperContext;
+
+ HRESULT setStateFlag( eCaptureStatus newBit ) noexcept
+ {
+ const uint8_t bit = (uint8_t)newBit;
+ const uint8_t oldVal = (uint8_t)InterlockedOr8( &stateFlags, (char)bit );
+ if( nullptr == callbacks.captureStatus )
+ return S_OK; // no callbacks
+ if( 0 != ( oldVal & bit ) )
+ return S_OK; // The bit was already set
+ return callbacks.captureStatus( callbacks.pv, (eCaptureStatus)( oldVal | bit ) );
+ }
+
+ HRESULT clearStateFlag( eCaptureStatus clearBit ) noexcept
+ {
+ const uint8_t bit = (uint8_t)clearBit;
+ const uint8_t mask = ~bit;
+ const uint8_t oldVal = (uint8_t)InterlockedAnd8( &stateFlags, (char)mask );
+ if( nullptr == callbacks.captureStatus )
+ return S_OK; // no callbacks
+ if( 0 == ( oldVal & bit ) )
+ return S_OK; // The bit wasn't there
+ return callbacks.captureStatus( callbacks.pv, (eCaptureStatus)( oldVal & mask ) );
+ }
+
+ bool hasStateFlag( eCaptureStatus testBit ) const
+ {
+ const uint8_t bit = (uint8_t)testBit;
+ return 0 != ( (uint8_t)stateFlags & bit );
+ }
+
+ HRESULT workCallback();
+ static void __stdcall callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work );
+
+ HRESULT readSample( bool discard );
+
+ // Run voice detection on the data in pcm.mono vector.
+ // When not detected, return 0. When detected, return last frame index where it is detected.
+ size_t detectVoice();
+
+ HRESULT postPoolWork()
+ {
+ assert( workStatus == S_OK );
+ CHECK( setStateFlag( eCaptureStatus::Transcribing ) );
+
+ workStatus = S_FALSE;
+ buffer.currentOffset = pcmStartTime;
+ buffer.pcm.mono = pcm.mono;
+ buffer.pcm.stereo = pcm.stereo;
+ SubmitThreadpoolWork( work );
+ pcmStartTime = nextSampleTime;
+ pcm.clear();
+ vad.clear();
+ return S_OK;
+ }
+
+ public:
+ Capture( const sCaptureCallbacks& cb, const iAudioCapture* ac, const sFullParams& sfp, iContext* wc, ProfileCollection& pc ) :
+ callbacks( cb ),
+ captureParams( ac->getParams() ),
+ fullParams( sfp ), whisperContext( wc ), profiler( pc )
+ {
+ }
+
+ ~Capture()
+ {
+ if( workStatus == S_FALSE && nullptr != work )
+ WaitForThreadpoolWorkCallbacks( work, FALSE );
+
+ if( nullptr != work )
+ {
+ CloseThreadpoolWork( work );
+ work = nullptr;
+ }
+ }
+
+ HRESULT startup( const iAudioCapture* ac );
+
+ HRESULT checkCancel() noexcept
+ {
+ if( nullptr == callbacks.shouldCancel )
+ return S_FALSE;
+ return callbacks.shouldCancel( callbacks.pv );
+ }
+
+ HRESULT run();
+ };
+
+ HRESULT Capture::startup( const iAudioCapture* ac )
+ {
+ // Initialize the MF source reader
+ CHECK( ac->getReader( &reader ) );
+ work = CreateThreadpoolWork( &callbackStatic, this, nullptr );
+ if( nullptr == work )
+ return HRESULT_FROM_WIN32( GetLastError() );
+
+ // Set up media type, and figure out sample handler
+ CHECK( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) );
+ CHECK( reader->SetStreamSelection( MF_SOURCE_READER_FIRST_AUDIO_STREAM, TRUE ) );
+
+ CComPtr<IMFMediaType> mtNative;
+ CHECK( reader->GetNativeMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, MF_SOURCE_READER_CURRENT_TYPE_INDEX, &mtNative ) );
+ UINT32 numChannels;
+ CHECK( mtNative->GetUINT32( MF_MT_AUDIO_NUM_CHANNELS, &numChannels ) );
+
+ const bool sourceMono = numChannels < 2;
+ const bool wantStereo = 0 != ( captureParams.flags & (uint32_t)eCaptureFlags::Stereo );
+ pfnAppendSamples = AudioBuffer::appendSamplesFunc( sourceMono, wantStereo );
+
+ CComPtr<IMFMediaType> mt;
+ this->readerChannels = ( !sourceMono && wantStereo ) ? 2 : 1;
+ CHECK( createMediaType( !sourceMono, &mt ) );
+ CHECK( reader->SetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, nullptr, mt ) );
+
+ CHECK( setStateFlag( eCaptureStatus::Listening ) );
+ return S_OK;
+ }
+
+ HRESULT Capture::run()
+ {
+ HRESULT hr;
+ if( hasStateFlag( eCaptureStatus::Stalled ) )
+ {
+ hr = workStatus;
+ CHECK( hr );
+ if( S_OK != hr )
+ {
+ // Still stalled, discard the upcoming sample
+ return readSample( true );
+ }
+ else
+ {
+ // The postponed task has completed by now, no longer stalled
+ // Move the current PCM buffer to the transcribe thread
+ CHECK( clearStateFlag( eCaptureStatus::Stalled ) );
+ return postPoolWork();
+ }
+ }
+
+ const size_t oldSamples = pcm.mono.size();
+ CHECK( readSample( false ) );
+ const size_t newSamples = pcm.mono.size();
+
+ const size_t lastVoiceFrame = detectVoice();
+ if( lastVoiceFrame == 0 )
+ {
+ // No voice is detected in the entire buffered audio
+ clearStateFlag( eCaptureStatus::Voice );
+ if( newSamples < captureParams.dropStartSilence )
+ return S_OK;
+
+ pcm.clear();
+ vad.clear();
+ pcmStartTime = nextSampleTime;
+ return S_OK;
+ }
+
+ const bool newFrameVoice = lastVoiceFrame + captureParams.pauseDuration >= oldSamples;
+
+ if( newFrameVoice )
+ {
+ // A voice is detected in the buffer, and it was fairly recently
+ setStateFlag( eCaptureStatus::Voice );
+ if( newSamples < captureParams.maxDuration )
+ return S_OK; // While voice is continuously detected, we allow to grow the buffer up to `maxDuration` time
+ }
+ else
+ {
+ // A voice is detected in the buffer, but it was a while ago
+ clearStateFlag( eCaptureStatus::Voice );
+ if( newSamples < captureParams.minDuration )
+ return S_OK; // When detected pause in the voice, we fire the transcribe task right away.
+ }
+
+ // Hopefully, we have enough captured PCM data to run the ASR model.
+ // Check the background task status first.
+ hr = workStatus;
+ CHECK( hr );
+ if( hr == S_OK )
+ return postPoolWork();
+
+ // workStatus = S_FALSE means the previous task has not finished yet.
+ // We don't want concurrency here because it's not implemented, and will simply crash.
+ // The "Stalled" flag which will cause capture to drop further samples.
+ setStateFlag( eCaptureStatus::Stalled );
+ return S_OK;
+ }
+
+ HRESULT Capture::readSample( bool discard )
+ {
+ while( true )
+ {
+ DWORD dwFlags = 0;
+ CComPtr<IMFSample> sample;
+
+ // Read the next sample
+ HRESULT hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample );
+ if( FAILED( hr ) )
+ {
+ logErrorHr( hr, u8"IMFSourceReader.ReadSample" );
+ return hr;
+ }
+
+ if( dwFlags & MF_SOURCE_READERF_CURRENTMEDIATYPECHANGED )
+ {
+ logError( u8"Media type changes ain’t supported by the library." );
+ return E_UNEXPECTED;
+ }
+
+ if( dwFlags & MF_SOURCE_READERF_ENDOFSTREAM )
+ return E_EOF;
+
+ if( !sample )
+ continue;
+
+ // Get a pointer to the audio data in the sample.
+ CComPtr<IMFMediaBuffer> buffer;
+ hr = sample->ConvertToContiguousBuffer( &buffer );
+ if( FAILED( hr ) )
+ return hr;
+
+ const float* pAudioData = nullptr;
+ DWORD cbBuffer;
+ hr = buffer->Lock( (BYTE**)&pAudioData, nullptr, &cbBuffer );
+ if( FAILED( hr ) )
+ return hr;
+
+ try
+ {
+ assert( 0 == ( cbBuffer % sizeof( float ) ) );
+ const size_t countFloats = cbBuffer / sizeof( float );
+ if( !discard )
+ {
+ const size_t prevSize = pcm.mono.size();
+ ( pcm.*pfnAppendSamples )( pAudioData, countFloats );
+ const size_t newSize = pcm.mono.size();
+ this->nextSampleTime += ( newSize - prevSize );
+ }
+ else
+ {
+ this->nextSampleTime += countFloats / readerChannels;
+ }
+ }
+ catch( const std::bad_alloc& )
+ {
+ buffer->Unlock();
+ return E_OUTOFMEMORY;
+ }
+
+ // Unlock the buffer
+ hr = buffer->Unlock();
+ if( FAILED( hr ) )
+ return hr;
+
+ return S_OK;
+ }
+ }
+
+ HRESULT Capture::workCallback()
+ {
+ CHECK( whisperContext->runFull( fullParams, &buffer ) );
+ CHECK( clearStateFlag( eCaptureStatus::Transcribing ) );
+ return S_OK;
+ }
+
+ void __stdcall Capture::callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work )
+ {
+ Capture* pThis = (Capture*)pv;
+ HRESULT status = E_UNEXPECTED;
+ try
+ {
+ status = pThis->workCallback();
+ }
+ catch( HRESULT hr )
+ {
+ status = hr;
+ }
+ catch( const std::bad_alloc& )
+ {
+ status = E_OUTOFMEMORY;
+ }
+ catch( const std::exception& )
+ {
+ status = E_FAIL;
+ }
+ assert( S_OK == status || FAILED( status ) );
+ pThis->workStatus = status;
+ }
+
+ size_t Capture::detectVoice()
+ {
+ auto pf = profiler.cpuBlock( eCpuBlock::VAD );
+ return vad.detect( pcm.mono.data(), pcm.mono.size() );
+ }
+}
+
+HRESULT COMLIGHTCALL ContextImpl::runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader )
+{
+ if( nullptr == reader )
+ return E_POINTER;
+
+ // Validate a few things
+ {
+ const auto& cp = reader->getParams();
+ if( cp.minDuration < 0.125f || cp.minDuration > 30.0f )
+ {
+ logError( u8"%s parameter %g is out of range", "minDuration", cp.minDuration );
+ return E_INVALIDARG;
+ }
+ if( cp.maxDuration < 0.125f || cp.maxDuration > 30.0f )
+ {
+ logError( u8"%s parameter %g is out of range", "maxDuration", cp.maxDuration );
+ return E_INVALIDARG;
+ }
+ }
+
+ Capture capture{ callbacks, reader, params, this, profiler };
+ CHECK( capture.startup( reader ) );
+
+ while( true )
+ {
+ HRESULT hr = capture.checkCancel();
+ CHECK( hr );
+ if( hr == S_OK )
+ return S_OK;
+ CHECK( capture.run() );
+ }
+} \ No newline at end of file
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp
new file mode 100644
index 0000000..a8e16f5
--- /dev/null
+++ b/Whisper/Whisper/ContextImpl.cpp
@@ -0,0 +1,528 @@
+#include "stdafx.h"
+#include "ContextImpl.h"
+#include "Languages.h"
+#include "../Utils/Trace/tracing.h"
+using namespace Whisper;
+
+ContextImpl::ContextImpl( const WhisperModel& modelData, iModel* modelPointer ) :
+ model( modelData ),
+ modelPtr( modelPointer ),
+ context( modelData, profiler ),
+ profiler( modelData )
+{ }
+
+#define WHISPER_CHUNK_SIZE 30
+
+HRESULT ContextImpl::encode( iSpectrogram& mel, int seek )
+{
+ // whisper_encode
+ using namespace DirectCompute;
+
+ sEncodeParams ep;
+ ep.n_ctx = ( exp_n_audio_ctx > 0 ) ? exp_n_audio_ctx : model.parameters.n_audio_ctx;
+ ep.n_mels = model.parameters.n_mels;
+ ep.mel_offset = seek;
+ ep.layersCount = model.parameters.n_audio_layer;
+ ep.n_state = model.parameters.n_audio_state;
+ ep.n_head = model.parameters.n_audio_head;
+ ep.n_audio_ctx = model.parameters.n_audio_ctx;
+ ep.n_text_state = model.parameters.n_text_state;
+ ep.n_text_layer = model.parameters.n_text_layer;
+ ep.n_text_ctx = model.parameters.n_text_ctx;
+ try
+ {
+ auto cur = context.encode( mel, ep );
+ Tracing::tensor( "encode-out", cur );
+ return S_OK;
+ }
+ catch( HRESULT hr )
+ {
+ return hr;
+ }
+}
+
+HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int threads )
+{
+ // whisper_decode
+ using namespace DirectCompute;
+ sDecodeParams dp;
+ dp.n_state = model.parameters.n_audio_state;
+ dp.n_head = model.parameters.n_audio_head;
+ dp.n_ctx = model.parameters.n_text_ctx;
+ dp.n_past = n_past;
+ dp.M = exp_n_audio_ctx > 0 ? exp_n_audio_ctx : model.parameters.n_audio_ctx;
+ dp.n_text_layer = model.parameters.n_text_layer;
+ dp.n_vocab = model.parameters.n_vocab;
+
+ try
+ {
+ context.decode( tokens, (int)length, dp, probs, threads );
+ return S_OK;
+ }
+ catch( HRESULT hr )
+ {
+ return hr;
+ }
+}
+
+// the most basic sampling scheme - select the top token
+sTokenData ContextImpl::sampleBest( const float* probs, bool force_timestamp, bool is_initial )
+{
+ // whisper_sample_best
+ const Vocabulary& vocab = model.vocab;
+ sTokenData result = { 0 };
+
+ size_t n_logits = vocab.size();
+
+ probs_id.clear();
+ probs_id.reserve( n_logits );
+
+ for( size_t i = 0; i < n_logits; i++ )
+ probs_id.emplace_back( probs[ i ], (int)i );
+
+ {
+ double sum_ts = 0.0;
+ double max_ts = -1.0;
+ double max_tx = -1.0;
+
+ for( int i = 0; i < vocab.token_beg; i++ )
+ max_tx = std::max( max_tx, probs_id[ i ].first );
+
+ const int i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
+ const int i1 = is_initial ? vocab.token_beg + 101 : (int)n_logits;
+
+ // the initial timestamp cannot be larger than 100
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
+ if( is_initial )
+ {
+ for( int i = i0; i < n_logits; i++ )
+ probs_id[ i ].first = -INFINITY;
+ }
+
+ for( int i = vocab.token_beg; i < i1; i++ )
+ {
+ sum_ts += probs_id[ i ].first;
+ if( probs_id[ i ].first > max_ts )
+ {
+ max_ts = probs_id[ i ].first;
+ result.tid = probs_id[ i ].second;
+ }
+ }
+
+ // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
+ // timestamp token
+ if( sum_ts > max_tx || force_timestamp )
+ {
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
+ for( int i = 0; i < vocab.token_beg; i++ )
+ probs_id[ i ].first = -INFINITY;
+ }
+
+ result.pt = (float)( max_ts / ( sum_ts + 1e-10 ) );
+ result.ptsum = (float)sum_ts;
+ }
+
+ // find the top K tokens
+ const int top_k = 4;
+
+ std::partial_sort(
+ probs_id.begin(),
+ probs_id.begin() + top_k, probs_id.end(),
+ []( const std::pair<double, Vocabulary::id>& a, const std::pair<double, Vocabulary::id>& b ) {
+ return a.first > b.first;
+ } );
+
+ probs_id.resize( top_k );
+
+ //printf("\n");
+ //for (int i = 0; i < (int) probs_id.size(); i++) {
+ // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
+ //}
+
+ int res = 0;
+ while( ( probs_id[ res ].second == vocab.token_sot ||
+ probs_id[ res ].second == vocab.token_solm ||
+ probs_id[ res ].second == vocab.token_not ) &&
+ res < (int)probs_id.size() - 1 )
+ {
+ res++;
+ }
+
+ result.id = probs_id[ res ].second;
+ result.p = (float)probs_id[ res ].first;
+
+ return result;
+}
+
+sTokenData ContextImpl::sampleBest()
+{
+ const int n_vocab = model.vocab.n_vocab;
+ return sampleBest( probs.data() + ( probs.size() - n_vocab ), false, false );
+}
+
+sTokenData ContextImpl::sampleTimestamp( bool initial )
+{
+ const int n_vocab = model.vocab.n_vocab;
+ return sampleBest( probs.data() + ( probs.size() - n_vocab ), true, initial );
+}
+
+void ContextImpl::expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum )
+{
+ // whisper_exp_compute_token_level_timestamps
+ throw E_NOTIMPL;
+}
+
+static std::string to_timestamp( int64_t t, bool comma = false )
+{
+ int64_t msec = t * 10;
+ int64_t hr = msec / ( 1000 * 60 * 60 );
+ msec = msec - hr * ( 1000 * 60 * 60 );
+ int64_t min = msec / ( 1000 * 60 );
+ msec = msec - min * ( 1000 * 60 );
+ int64_t sec = msec / 1000;
+ msec = msec - sec * 1000;
+
+ char buf[ 32 ];
+ snprintf( buf, sizeof( buf ), "%02d:%02d:%02d%s%03d", (int)hr, (int)min, (int)sec, comma ? "," : ".", (int)msec );
+
+ return std::string( buf );
+}
+
+HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const sProgressSink& progress, iSpectrogram& mel )
+{
+ // Ported from whisper_full() function
+ result_all.clear();
+ if( params.flag( eFullParamsFlags::SpeedupAudio ) )
+ {
+ logError( u8"GPU model doesn't implement the SpeedupAudio flag" );
+ return E_NOTIMPL;
+ }
+
+ const int seek_start = params.offset_ms / 10;
+ const int seek_end = seek_start + ( params.duration_ms == 0 ? (int)mel.getLength() : params.duration_ms / 10 );
+
+ // if length of spectrogram is less than 1s (100 samples), then return
+ // basically don't process anything that is less than 1s
+ // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
+ if( seek_end < 100 + seek_start )
+ return S_FALSE;
+
+ // the accumulated text context so far
+ if( params.flag( eFullParamsFlags::NoContext ) )
+ prompt_past.clear();
+
+ // prepend the prompt tokens to the prompt_past
+ if( params.prompt_tokens && params.prompt_n_tokens > 0 )
+ {
+ // parse tokens from the pointer
+ for( int i = 0; i < params.prompt_n_tokens; i++ )
+ prompt_past.push_back( params.prompt_tokens[ i ] );
+ std::rotate( prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end() );
+ }
+
+ // overwrite audio_ctx
+ exp_n_audio_ctx = params.audio_ctx;
+
+ // these tokens determine the task that will be performed
+ std::vector<whisper_token> prompt_init = { model.vocab.token_sot };
+ if( model.vocab.is_multilingual() )
+ {
+ int langId = lookupLanguageId( params.language );
+ if( langId < 0 )
+ {
+ char lang[ 5 ];
+ *(uint32_t*)( &lang[ 0 ] ) = params.language;
+ lang[ 4 ] = '\0';
+ logError( u8"%s: unknown language '%s'", __func__, lang );
+ return E_INVALIDARG;
+ }
+
+ prompt_init.push_back( model.vocab.token_sot + 1 + langId );
+ if( params.flag( eFullParamsFlags::Translate ) )
+ prompt_init.push_back( model.vocab.token_translate );
+ else
+ prompt_init.push_back( model.vocab.token_transcribe );
+ }
+
+ // int progress_prev = 0;
+ // int progress_step = 5;
+
+ std::vector<sTokenData> tokens_cur;
+ tokens_cur.reserve( model.parameters.n_text_ctx );
+ std::vector<whisper_token> prompt;
+ prompt.reserve( model.parameters.n_text_ctx );
+
+ // main loop
+ int seek = seek_start;
+ auto prof = context.completeProfiler();
+ while( true )
+ {
+ if( nullptr != progress.pfn )
+ {
+ const int pos = seek - seek_start;
+ const int total = seek_end - seek_start;
+ const double percentage = (double)pos / (double)total;
+ CHECK( progress.pfn( percentage, this, progress.pv ) );
+ }
+ /*
+ const int progress_cur = ( 100 * ( seek - seek_start ) ) / ( seek_end - seek_start );
+ while( progress_cur >= progress_prev + progress_step )
+ {
+ progress_prev += progress_step;
+ if( params.flag( eFullParamsFlags::PrintProgress ) )
+ logInfo( u8"%s: progress = %3d%%", __func__, progress_prev );
+ }
+ */
+
+ if( seek + 100 >= seek_end )
+ break;
+
+ if( nullptr != params.encoder_begin_callback )
+ {
+ HRESULT hr = params.encoder_begin_callback( this, params.encoder_begin_callback_user_data );
+ if( FAILED( hr ) )
+ return hr;
+ if( hr != S_OK )
+ break;
+ }
+
+ // encode audio features starting at offset seek
+ CHECK( encode( mel, seek ) );
+
+ int n_past = 0;
+ prompt.clear();
+
+ // if we have already generated some text, use it as a prompt to condition the next generation
+ if( !prompt_past.empty() )
+ {
+ int n_take = std::min( std::min( params.n_max_text_ctx, model.parameters.n_text_ctx / 2 ), int( prompt_past.size() ) );
+
+ prompt = { model.vocab.token_prev };
+ prompt.insert( prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end() );
+
+ prompt_past.clear();
+ prompt_past.insert( prompt_past.end(), prompt.begin() + 1, prompt.end() );
+ }
+
+ prompt.insert( prompt.end(), prompt_init.begin(), prompt_init.end() );
+
+ int seek_delta = 100 * WHISPER_CHUNK_SIZE;
+
+ // print the prompt
+ //printf("\n\n");
+ //for (int i = 0; i < prompt.size(); i++) {
+ // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
+ //}
+ //printf("\n\n");
+
+ // the accumulated transcription in the current iteration
+ int result_len = 0;
+ tokens_cur.clear();
+
+ bool failed = false;
+ bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
+
+ {
+ auto prof = context.decodeProfiler();
+ for( int i = 0, n_max = model.parameters.n_text_ctx / 2 - 4; i < n_max; i++ )
+ {
+ CHECK( decode( prompt.data(), prompt.size(), n_past, params.cpuThreads ) );
+
+ n_past += (int)prompt.size();
+ prompt.clear();
+
+ // very basic greedy sampling strategy:
+ //
+ // - always take the most probable token
+ //
+ // more sophisticated sampling strategies could be implemented here, but we keep it simple
+ // feel free to experiment!
+ //
+ {
+ auto p = profiler.cpuBlock( eCpuBlock::Sample );
+ const sTokenData token = ( i == 0 ) ? sampleTimestamp( true ) : sampleBest();
+
+ // timestamp token - update sliding window
+ if( token.id > model.vocab.token_beg )
+ {
+ const int seek_delta_new = 2 * ( token.id - model.vocab.token_beg );
+
+ // do not allow to go back in time
+ if( has_ts && seek_delta > seek_delta_new && result_len < i )
+ break;
+
+ seek_delta = seek_delta_new;
+ result_len = i + 1;
+ has_ts = true;
+ }
+
+ // add it to the context
+ prompt.push_back( token.id );
+ tokens_cur.push_back( token );
+
+ //{
+ // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
+ // printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
+ //}
+
+ // end of segment
+ if( token.id == model.vocab.token_eot || // end of text token
+ ( params.max_tokens > 0 && i >= params.max_tokens ) || // max tokens per segment reached
+ ( has_ts && seek + seek_delta + 100 >= seek_end ) // end of audio reached
+ )
+ {
+ if( result_len == 0 )
+ {
+ if( seek + seek_delta + 100 >= seek_end )
+ result_len = i + 1;
+ else
+ {
+ failed = true;
+ break;
+ }
+ }
+
+ if( params.flag( eFullParamsFlags::SingleSegment ) )
+ {
+ result_len = i + 1;
+ seek_delta = 100 * WHISPER_CHUNK_SIZE;
+ }
+
+ break;
+ }
+ }
+
+ // sometimes, the decoding can get stuck in a repetition loop
+ // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
+ // the sliding window by 1 second
+ if( i == n_max - 1 && ( result_len == 0 || seek_delta < 100 * WHISPER_CHUNK_SIZE / 2 ) )
+ {
+ failed = true;
+ break;
+ }
+ }
+ }
+
+ if( failed )
+ {
+ logError( u8"%s: failed to generate timestamp token - skipping one second", __func__ );
+ seek += 100;
+ continue;
+ }
+
+ // shrink down to result_len
+ tokens_cur.resize( result_len );
+
+ for( const auto& r : tokens_cur )
+ prompt_past.push_back( r.id );
+
+ // store the text from this iteration
+ if( !tokens_cur.empty() )
+ {
+ int i0 = 0;
+ int t0 = seek + 2 * ( tokens_cur.front().tid - model.vocab.token_beg );
+ std::string text = "";
+
+ for( int i = 0; i < (int)tokens_cur.size(); i++ )
+ {
+ //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
+ // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
+ // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
+ if( params.flag( eFullParamsFlags::PrintSpecial ) || tokens_cur[ i ].id < model.vocab.token_eot )
+ text += model.vocab.string( tokens_cur[ i ].id );
+
+ if( tokens_cur[ i ].id > model.vocab.token_beg && !params.flag( eFullParamsFlags::SingleSegment ) )
+ {
+ const int t1 = seek + 2 * ( tokens_cur[ i ].tid - model.vocab.token_beg );
+ if( !text.empty() )
+ {
+ const bool speedUp = params.flag( eFullParamsFlags::SpeedupAudio );
+ const int tt0 = speedUp ? 2 * t0 : t0;
+ const int tt1 = speedUp ? 2 * t1 : t1;
+
+ if( params.flag( eFullParamsFlags::PrintRealtime ) )
+ {
+ if( params.flag( eFullParamsFlags::PrintTimestamps ) )
+ printf( "[%s --> %s] %s\n", to_timestamp( tt0 ).c_str(), to_timestamp( tt1 ).c_str(), text.c_str() );
+ else
+ {
+ printf( "%s", text.c_str() );
+ fflush( stdout );
+ }
+ }
+
+ result_all.push_back( { tt0, tt1, text, {} } );
+ for( int j = i0; j <= i; j++ )
+ result_all.back().tokens.push_back( tokens_cur[ j ] );
+
+ int n_new = 1;
+
+ if( params.flag( eFullParamsFlags::TokenTimestamps ) )
+ {
+ expComputeTokenLevelTimestamps( (int)result_all.size() - 1, params.thold_pt, params.thold_ptsum );
+ if( params.max_len > 0 )
+ n_new = wrapSegment( params.max_len );
+ }
+ if( nullptr != params.new_segment_callback )
+ {
+ HRESULT hr = params.new_segment_callback( this, n_new, params.new_segment_callback_user_data );
+ if( FAILED( hr ) )
+ return hr;
+ }
+ }
+ text = "";
+ while( i < (int)tokens_cur.size() && tokens_cur[ i ].id > model.vocab.token_beg )
+ i++;
+ i--;
+ t0 = t1;
+ i0 = i + 1;
+ }
+ }
+
+ if( !text.empty() )
+ {
+ const int t1 = seek + seek_delta;
+
+ const bool speedUp = params.flag( eFullParamsFlags::SpeedupAudio );
+ const int tt0 = speedUp ? 2 * t0 : t0;
+ const int tt1 = speedUp ? 2 * t1 : t1;
+
+ if( params.flag( eFullParamsFlags::PrintRealtime ) )
+ {
+ if( params.flag( eFullParamsFlags::PrintTimestamps ) )
+ printf( "[%s --> %s] %s\n", to_timestamp( tt0 ).c_str(), to_timestamp( tt1 ).c_str(), text.c_str() );
+ else
+ {
+ printf( "%s", text.c_str() );
+ fflush( stdout );
+ }
+ }
+
+ result_all.push_back( { tt0, tt1, text, {} } );
+ for( int j = i0; j < (int)tokens_cur.size(); j++ )
+ result_all.back().tokens.push_back( tokens_cur[ j ] );
+
+ int n_new = 1;
+ if( params.flag( eFullParamsFlags::TokenTimestamps ) )
+ {
+ expComputeTokenLevelTimestamps( (int)result_all.size() - 1, params.thold_pt, params.thold_ptsum );
+ if( params.max_len > 0 )
+ n_new = wrapSegment( params.max_len );
+ }
+ if( nullptr != params.new_segment_callback )
+ {
+ HRESULT hr = params.new_segment_callback( this, n_new, params.new_segment_callback_user_data );
+ if( FAILED( hr ) )
+ return hr;
+ }
+ }
+ }
+ seek += seek_delta;
+ }
+
+ if( nullptr != progress.pfn )
+ {
+ CHECK( progress.pfn( 1.0, this, progress.pv ) );
+ }
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h
new file mode 100644
index 0000000..971f629
--- /dev/null
+++ b/Whisper/Whisper/ContextImpl.h
@@ -0,0 +1,75 @@
+#pragma once
+#include "../API/iContext.cl.h"
+#include "../ComLightLib/comLightServer.h"
+#include "WhisperContext.h"
+#include "Spectrogram.h"
+#include "TranscribeResult.h"
+#include "sTokenData.h"
+
+namespace Whisper
+{
+ class ContextImpl : public ComLight::ObjectRoot<iContext>
+ {
+ const WhisperModel& model;
+ ComLight::CComPtr<iModel> modelPtr;
+ DirectCompute::WhisperContext context;
+ Spectrogram spectrogram;
+ int64_t mediaTimeOffset = 0;
+ ProfileCollection profiler;
+
+ HRESULT COMLIGHTCALL getModel( iModel** pp ) override final;
+ HRESULT COMLIGHTCALL timingsPrint() override final;
+ HRESULT COMLIGHTCALL timingsReset() override final;
+ HRESULT COMLIGHTCALL fullDefaultParams( eSamplingStrategy strategy, sFullParams* rdi ) override final;
+ HRESULT COMLIGHTCALL runFullImpl( const sFullParams& params, const sProgressSink& progress, iSpectrogram& mel );
+ HRESULT COMLIGHTCALL runFull( const sFullParams& params, const iAudioBuffer* buffer ) override final;
+ HRESULT COMLIGHTCALL runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ) override final;
+ HRESULT COMLIGHTCALL runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ) override final;
+
+ struct Segment
+ {
+ int64_t t0;
+ int64_t t1;
+ std::string text;
+ std::vector<sTokenData> tokens;
+ size_t memoryUsage() const;
+ };
+ std::vector<Segment> result_all;
+
+ std::vector<whisper_token> prompt_past;
+
+ // [EXPERIMENTAL] token-level timestamps data
+ int64_t t_beg = 0;
+ int64_t t_last = 0;
+ whisper_token tid_last = 0;
+ std::vector<float> energy; // PCM signal energy
+
+ // [EXPERIMENTAL] speed-up techniques
+ int32_t exp_n_audio_ctx = 0; // 0 - use default
+
+ HRESULT encode( iSpectrogram& mel, int seek );
+ HRESULT decode( const int* tokens, size_t length, int n_past, int threads );
+ sTokenData sampleBest( const float* probs, bool force_timestamp, bool is_initial );
+ sTokenData sampleBest();
+ sTokenData sampleTimestamp( bool initial );
+ int wrapSegment( int max_len );
+ void expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum );
+
+ std::vector<float> probs;
+ std::vector<std::pair<double, Vocabulary::id>> probs_id;
+
+ mutable TranscribeResultStatic results;
+
+ HRESULT COMLIGHTCALL makeResults( eResultFlags flags, TranscribeResult& res ) const noexcept;
+
+ HRESULT COMLIGHTCALL getResults( eResultFlags flags, iTranscribeResult** pp ) const noexcept override final;
+
+ int defaultThreadsCount() const;
+
+ __m128i getMemoryUse() const;
+
+ public:
+
+ ContextImpl( const WhisperModel& modelData, iModel* modelPointer );
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/ContextImpl.misc.cpp b/Whisper/Whisper/ContextImpl.misc.cpp
new file mode 100644
index 0000000..98d2164
--- /dev/null
+++ b/Whisper/Whisper/ContextImpl.misc.cpp
@@ -0,0 +1,408 @@
+#include "stdafx.h"
+#include "ContextImpl.h"
+#include <mfapi.h>
+#include "MelStreamer.h"
+#include "../API/iMediaFoundation.cl.h"
+#include "../Utils/Trace/tracing.h"
+using namespace Whisper;
+
+static int getCpuCoresCount()
+{
+ DWORD bufferSize = 0;
+ GetLogicalProcessorInformation( NULL, &bufferSize );
+
+ // The SYSTEM_LOGICAL_PROCESSOR_INFORMATION structure has a uint64_t field
+ // Ideally need to align by 8 bytes, and that's why uint64_t type for the storage
+ std::unique_ptr<uint64_t[]> buffer = std::make_unique<uint64_t[]>( ( bufferSize + 7 ) / 8 );
+
+ SYSTEM_LOGICAL_PROCESSOR_INFORMATION* ptr = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION*)buffer.get();
+ if( !GetLogicalProcessorInformation( ptr, &bufferSize ) )
+ {
+ HRESULT hr = getLastHr();
+ logWarningHr( hr, u8"GetLogicalProcessorInformation" );
+ return 0;
+ }
+
+ DWORD byteOffset = 0;
+ int physicalCores = 0;
+ while( byteOffset < bufferSize )
+ {
+ if( ptr->Relationship == RelationProcessorCore )
+ physicalCores++;
+ byteOffset += sizeof( SYSTEM_LOGICAL_PROCESSOR_INFORMATION );
+ ptr++;
+ }
+ return physicalCores;
+}
+
+int ContextImpl::defaultThreadsCount() const
+{
+#if BUILD_HYBRID_VERSION
+ const bool isHybrid = !model.hybridTensors.layers.empty();
+#else
+ constexpr bool isHybrid = false;
+#endif
+
+ SYSTEM_INFO si;
+ GetSystemInfo( &si );
+ const int hardwareThreads = (int)si.dwNumberOfProcessors;
+
+ if( !isHybrid )
+ return std::min( hardwareThreads, 4 );
+
+ // It seems the CPU decoder in the hybrid context doesn’t scale well with count of hardware threads, but it does scale with count of physical cores.
+ int cores = getCpuCoresCount();
+ if( cores > 1 )
+ return cores;
+
+ return hardwareThreads;
+}
+
+HRESULT COMLIGHTCALL ContextImpl::fullDefaultParams( eSamplingStrategy strategy, sFullParams* rdi )
+{
+ // whisper_full_default_params
+ if( nullptr == rdi )
+ return E_POINTER;
+ memset( rdi, 0, sizeof( sFullParams ) );
+
+ rdi->strategy = strategy;
+ rdi->cpuThreads = defaultThreadsCount();
+ rdi->n_max_text_ctx = 16384;
+ rdi->flags = eFullParamsFlags::PrintProgress | eFullParamsFlags::PrintTimestamps;
+ rdi->thold_pt = 0.01f;
+ rdi->thold_ptsum = 0.01f;
+ rdi->language = makeLanguageKey( "en" );
+
+ switch( strategy )
+ {
+ case eSamplingStrategy::Greedy:
+ rdi->beam_search.n_past = -1;
+ rdi->beam_search.beam_width = -1;
+ rdi->beam_search.n_best = -1;
+ break;
+ case eSamplingStrategy::BeamSearch:
+ rdi->greedy.n_past = -1;
+ rdi->beam_search.beam_width = 10;
+ rdi->beam_search.n_best = 5;
+ break;
+ default:
+ logError( u8"Unknown sampling strategy %i", (int)strategy );
+ return E_INVALIDARG;
+ }
+ return S_OK;
+}
+
+HRESULT COMLIGHTCALL ContextImpl::getModel( iModel** pp )
+{
+ if( nullptr == pp )
+ return E_POINTER;
+ if( !modelPtr )
+ return OLE_E_BLANK;
+ *pp = modelPtr;
+ modelPtr->AddRef();
+ return S_OK;
+}
+
+size_t ContextImpl::Segment::memoryUsage() const
+{
+ return text.capacity() + vectorMemoryUse( tokens );
+}
+
+__m128i ContextImpl::getMemoryUse() const
+{
+ // Misc. system RAM
+ size_t cb = vectorMemoryUse( result_all );
+ for( const auto& r : result_all )
+ cb += r.memoryUsage();
+ cb += vectorMemoryUse( prompt_past );
+ cb += vectorMemoryUse( energy );
+ cb += vectorMemoryUse( probs );
+ cb += vectorMemoryUse( probs_id );
+ cb += vectorMemoryUse( results.segments );
+ cb += vectorMemoryUse( results.tokens );
+ cb += spectrogram.memoryUsage();
+
+ __m128i res = setLow_size( cb );
+ // Add all the VRAM in the temporary buffers
+ res = _mm_add_epi64( res, context.getMemoryUse() );
+ return res;
+}
+
+namespace
+{
+ struct PrintedSize
+ {
+ double val;
+ const char* unit;
+ PrintedSize( int64_t cb )
+ {
+ if( cb < ( 1 << 10 ) )
+ {
+ val = (double)cb;
+ unit = "bytes";
+ }
+ else if( cb < ( 1 << 20 ) )
+ {
+ val = (double)cb * ( 1.0 / ( 1 << 10 ) );
+ unit = "KB";
+ }
+ else if( cb < ( 1 << 30 ) )
+ {
+ val = (double)cb * ( 1.0 / ( 1 << 20 ) );
+ unit = "MB";
+ }
+ else
+ {
+ val = (double)cb * ( 1.0 / ( 1 << 30 ) );
+ unit = "GB";
+ }
+ }
+ };
+
+ static void __declspec( noinline ) logMemoryUse( const char* what, __m128i cb )
+ {
+ PrintedSize sys{ _mm_cvtsi128_si64( cb ) };
+ PrintedSize vram{ _mm_extract_epi64( cb, 1 ) };
+ logInfo( u8"%s\t%g %s RAM, %g %s VRAM", what, sys.val, sys.unit, vram.val, vram.unit );
+ }
+}
+
+HRESULT COMLIGHTCALL ContextImpl::timingsPrint()
+{
+ profiler.print();
+
+ const __m128i memModel = model.getMemoryUse();
+ const __m128i memContext = getMemoryUse();
+ logInfo( u8" Memory Usage" );
+ logMemoryUse( "Model", memModel );
+ logMemoryUse( "Context", memContext );
+ logMemoryUse( "Total", _mm_add_epi64( memModel, memContext ) );
+ return S_OK;
+}
+
+HRESULT COMLIGHTCALL ContextImpl::timingsReset()
+{
+ profiler.reset();
+ return S_OK;
+}
+
+HRESULT COMLIGHTCALL ContextImpl::getResults( eResultFlags flags, iTranscribeResult** pp ) const noexcept
+{
+ if( nullptr == pp )
+ return E_POINTER;
+
+ if( flags & eResultFlags::NewObject )
+ {
+ ComLight::CComPtr<ComLight::Object<TranscribeResult>> obj;
+ CHECK( ComLight::Object<TranscribeResult>::create( obj ) );
+ CHECK( makeResults( flags, *obj ) );
+ obj.detach( pp );
+ return S_OK;
+ }
+ else
+ {
+ CHECK( makeResults( flags, results ) );
+ iTranscribeResult* res = &results;
+ res->AddRef();
+ *pp = res;
+ return S_OK;
+ }
+}
+
+inline int64_t scaleTime( int64_t wisperTicks )
+{
+ return MFllMulDiv( wisperTicks, 10'000'000, 100, 0 );
+}
+
+HRESULT COMLIGHTCALL ContextImpl::makeResults( eResultFlags flags, TranscribeResult& res ) const noexcept
+{
+ const size_t segments = result_all.size();
+ // Resize both vectors
+ try
+ {
+ res.segments.resize( segments );
+ if( flags & eResultFlags::Tokens )
+ {
+ size_t tc = 0;
+ for( const auto& s : result_all )
+ tc += s.tokens.size();
+ res.tokens.resize( tc );
+ }
+ else
+ res.tokens.clear();
+ }
+ catch( const std::bad_alloc& )
+ {
+ return E_OUTOFMEMORY;
+ }
+
+ const Vocabulary::id tokenEot = model.vocab.token_eot;
+
+ size_t tokensSoFar = 0;
+ for( size_t i = 0; i < segments; i++ )
+ {
+ sSegment& rdi = res.segments[ i ];
+ const auto& rsi = result_all[ i ];
+ rdi.text = rsi.text.c_str();
+ if( flags & eResultFlags::Timestamps )
+ {
+ // Offset the time relative to the start of the media
+ rdi.time.begin = scaleTime( rsi.t0 ) + mediaTimeOffset;
+ rdi.time.end = scaleTime( rsi.t1 ) + mediaTimeOffset;
+ }
+ else
+ store16( &rdi.time, _mm_setzero_si128() );
+
+ rdi.firstToken = (uint32_t)tokensSoFar;
+ const size_t tc = rsi.tokens.size();
+ rdi.countTokens = (uint32_t)tc;
+
+ if( flags & eResultFlags::Tokens )
+ {
+ for( size_t i = 0; i < tc; i++ )
+ {
+ sToken& rdi = res.tokens[ tokensSoFar + i ];
+ const auto& src = rsi.tokens[ i ];
+ rdi.text = model.vocab.string( src.id );
+
+ if( flags & eResultFlags::Timestamps )
+ {
+ // Offset the time relative to the start of the media
+ rdi.time.begin = scaleTime( src.t0 ) + mediaTimeOffset;
+ rdi.time.end = scaleTime( src.t1 ) + mediaTimeOffset;
+ }
+ else
+ store16( &rdi.time, _mm_setzero_si128() );
+
+ // Copy 4 floats with unaligned load and store instructions
+ _mm_storeu_ps( &rdi.probability, _mm_loadu_ps( &src.p ) );
+
+ rdi.id = src.id;
+
+ uint32_t flags = 0;
+ if( src.id >= tokenEot )
+ flags |= (uint32_t)eTokenFlags::Special;
+ rdi.flags = (eTokenFlags)flags;
+ }
+ }
+ tokensSoFar += tc;
+ }
+ return S_OK;
+}
+
+int ContextImpl::wrapSegment( int max_len )
+{
+ // whisper_wrap_segment
+ auto segment = result_all.back();
+ int res = 1;
+ int acc = 0;
+ std::string text;
+ const int tokenEot = model.vocab.token_eot;
+
+ for( int i = 0; i < (int)segment.tokens.size(); i++ )
+ {
+ const auto& token = segment.tokens[ i ];
+ if( token.id >= tokenEot )
+ continue;
+
+ const char* txt = model.vocab.string( token.id );
+ const int cur = (int)strlen( txt );
+
+ if( acc + cur > max_len && i > 0 )
+ {
+ // split here
+ result_all.back().text = std::move( text );
+ result_all.back().t1 = token.t0;
+ result_all.back().tokens.resize( i );
+
+ result_all.push_back( {} );
+ result_all.back().t0 = token.t0;
+ result_all.back().t1 = segment.t1;
+
+ // add tokens [i, end] to the new segment
+ result_all.back().tokens.insert( result_all.back().tokens.end(), segment.tokens.begin() + i, segment.tokens.end() );
+
+ acc = 0;
+ text = "";
+
+ segment = result_all.back();
+ i = -1;
+
+ res++;
+ }
+ else
+ {
+ acc += cur;
+ text += txt;
+ }
+ }
+
+ result_all.back().text = std::move( text );
+ return res;
+}
+
+HRESULT COMLIGHTCALL ContextImpl::runFull( const sFullParams& params, const iAudioBuffer* buffer )
+{
+#if SAVE_DEBUG_TRACE
+ Tracing::vector( "runFull.pcm.in", buffer->getPcmMono(), buffer->countSamples() );
+#endif
+ CHECK( buffer->getTime( mediaTimeOffset ) );
+
+ auto profCompleteCpu = profiler.cpuBlock( eCpuBlock::Run );
+ {
+ auto p = profiler.cpuBlock( eCpuBlock::Spectrogram );
+ CHECK( spectrogram.pcmToMel( buffer, model.filters, params.cpuThreads ) );
+ }
+
+ if( params.flag( eFullParamsFlags::TokenTimestamps ) )
+ {
+ t_beg = 0;
+ t_last = 0;
+ tid_last = 0;
+ computeSignalEnergy( energy, buffer, 32 );
+ }
+
+ try
+ {
+ sProgressSink progressSink{ nullptr, nullptr };
+ return runFullImpl( params, progressSink, spectrogram );
+ }
+ catch( HRESULT hr )
+ {
+ return hr;
+ }
+}
+
+HRESULT COMLIGHTCALL ContextImpl::runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader )
+{
+ if( params.flag( eFullParamsFlags::TokenTimestamps ) )
+ {
+ logError( u8"eFullParamsFlags.TokenTimestamps flag is not supported in streaming mode" );
+ return E_NOTIMPL;
+ }
+
+ mediaTimeOffset = 0;
+ auto profCompleteCpu = profiler.cpuBlock( eCpuBlock::Run );
+
+ CComPtr<IMFSourceReader> mfReader;
+ CHECK( reader->getReader( &mfReader ) );
+ const bool stereo = reader->requestedStereo() == S_OK;
+
+ try
+ {
+ if( params.cpuThreads > 1 )
+ {
+ MelStreamerThread mel{ model.filters, profiler, mfReader, params.cpuThreads, stereo };
+ return runFullImpl( params, progress, mel );
+ }
+ else
+ {
+ MelStreamerSimple mel{ model.filters, profiler, mfReader, stereo };
+ return runFullImpl( params, progress, mel );
+ }
+ }
+ catch( HRESULT hr )
+ {
+ return hr;
+ }
+} \ No newline at end of file
diff --git a/Whisper/Whisper/DecoderInputBuffers.cpp b/Whisper/Whisper/DecoderInputBuffers.cpp
new file mode 100644
index 0000000..68d3cec
--- /dev/null
+++ b/Whisper/Whisper/DecoderInputBuffers.cpp
@@ -0,0 +1,66 @@
+#include "stdafx.h"
+#include "DecoderInputBuffers.h"
+#include "../D3D/createBuffer.h"
+#include "../D3D/MappedResource.h"
+using namespace DirectCompute;
+
+void DecoderInputBuffers::resize( uint32_t size )
+{
+ if( 0 == size )
+ throw E_INVALIDARG;
+
+ if( size <= m_capacity )
+ {
+ m_size = size;
+ return;
+ }
+
+ embd = nullptr;
+
+ // Round up by 256, mostly for lulz
+ const uint32_t newCapacity = ( size + 0xFFu ) & ( ~( 0xFFu ) );
+ const size_t totalBytes = (size_t)4 * newCapacity;
+
+ check( createBuffer( eBufferUse::Dynamic, totalBytes, &embd, nullptr, nullptr ) );
+
+ m_capacity = newCapacity;
+ m_size = size;
+}
+
+namespace
+{
+ static Tensor createView( ID3D11Buffer* buffer, uint32_t length )
+ {
+ Tensor res;
+
+ TensorGpuViews& views = res;
+ check( views.create( buffer, DXGI_FORMAT_R32_UINT, length, false ) );
+
+ res.ne = { length, 1, 1, 1 };
+ res.setDenseStrides();
+ return res;
+ }
+}
+
+Tensor DecoderInputBuffers::embedding( const int* rsi ) const
+{
+ if( nullptr == embd || m_size == 0 )
+ throw OLE_E_BLANK;
+
+ // Upload the data
+ {
+ MappedResource mapped;
+ check( mapped.map( embd, false ) );
+ int* const rdi = (int*)mapped.data();
+ memcpy( rdi, rsi, m_size * (size_t)4 );
+ }
+
+ return createView( embd, m_size );
+}
+
+void DecoderInputBuffers::clear()
+{
+ embd = nullptr;
+ m_size = 0;
+ m_capacity = 0;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/DecoderInputBuffers.h b/Whisper/Whisper/DecoderInputBuffers.h
new file mode 100644
index 0000000..9ce8f75
--- /dev/null
+++ b/Whisper/Whisper/DecoderInputBuffers.h
@@ -0,0 +1,29 @@
+#pragma once
+#include "../ML/Tensor.h"
+
+namespace DirectCompute
+{
+ // Two dynamic buffers
+ class DecoderInputBuffers
+ {
+ CComPtr<ID3D11Buffer> embd;
+ uint32_t m_size = 0;
+ uint32_t m_capacity = 0;
+
+ public:
+
+ void resize( uint32_t size );
+
+ // Create 1D tensor with R32_UINT elements, upload the source data
+ Tensor embedding( const int* rsi ) const;
+
+ void clear();
+
+ __m128i getMemoryUse() const
+ {
+ size_t i = m_capacity;
+ i *= sizeof( uint32_t );
+ return _mm_set_epi64x( (int64_t)i, 0 );
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/DecoderResultBuffer.cpp b/Whisper/Whisper/DecoderResultBuffer.cpp
new file mode 100644
index 0000000..3a26d51
--- /dev/null
+++ b/Whisper/Whisper/DecoderResultBuffer.cpp
@@ -0,0 +1,48 @@
+#include "stdafx.h"
+#include "DecoderResultBuffer.h"
+#include "../D3D/MappedResource.h"
+using namespace DirectCompute;
+
+void DecoderResultBuffer::copyFromVram( const Tensor& rsi )
+{
+ ID3D11ShaderResourceView* srv = rsi;
+ if( nullptr == srv )
+ throw OLE_E_BLANK;
+ if( !rsi.isContinuous() )
+ throw E_INVALIDARG;
+
+ const uint32_t len = rsi.countElements();
+ if( len > m_capacity )
+ {
+ buffer = nullptr;
+ CD3D11_BUFFER_DESC desc{ len * 4, 0, D3D11_USAGE_STAGING, D3D11_CPU_ACCESS_READ };
+ check( device()->CreateBuffer( &desc, nullptr, &buffer ) );
+ m_capacity = len;
+ }
+
+ CComPtr<ID3D11Resource> source;
+ srv->GetResource( &source );
+ // Coordinates of a box are in bytes for buffers
+ D3D11_BOX box;
+ store16( &box, _mm_setr_epi32( 0, 0, 0, (int)( len * 4 ) ) );
+ *(uint64_t*)&box.bottom = 0x100000001ull;
+ context()->CopySubresourceRegion( buffer, 0, 0, 0, 0, source, 0, &box );
+ m_size = len;
+}
+
+void DecoderResultBuffer::copyToVector( std::vector<float>& vec ) const
+{
+ vec.resize( m_size );
+ if( vec.empty() )
+ throw OLE_E_BLANK;
+
+ MappedResource mapped;
+ check( mapped.map( buffer, true ) );
+ memcpy( vec.data(), mapped.data(), (size_t)4 * m_size );
+}
+
+void DecoderResultBuffer::clear()
+{
+ buffer = nullptr;
+ m_size = m_capacity = 0;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/DecoderResultBuffer.h b/Whisper/Whisper/DecoderResultBuffer.h
new file mode 100644
index 0000000..471395f
--- /dev/null
+++ b/Whisper/Whisper/DecoderResultBuffer.h
@@ -0,0 +1,29 @@
+#pragma once
+#include "../ML/Tensor.h"
+
+namespace DirectCompute
+{
+ class DecoderResultBuffer
+ {
+ CComPtr<ID3D11Buffer> buffer;
+ uint32_t m_size = 0;
+ uint32_t m_capacity = 0;
+
+ public:
+
+ void copyFromVram( const Tensor& rsi );
+
+ void copyToVector( std::vector<float>& vec ) const;
+
+ uint32_t size() const
+ {
+ return m_size;
+ }
+ void clear();
+
+ __m128i getMemoryUse() const
+ {
+ return bufferMemoryUsage( buffer );
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/KeyValueBuffers.cpp b/Whisper/Whisper/KeyValueBuffers.cpp
new file mode 100644
index 0000000..b932fdb
--- /dev/null
+++ b/Whisper/Whisper/KeyValueBuffers.cpp
@@ -0,0 +1,42 @@
+#include "stdafx.h"
+#include "KeyValueBuffers.h"
+#include "../D3D/createBuffer.h"
+using namespace DirectCompute;
+
+void AttentionBuffer::resize( uint32_t size )
+{
+ if( size <= m_size )
+ return;
+
+ buffer = nullptr;
+ check( createBuffer( eBufferUse::ReadWrite, (size_t)2 * size, &buffer, nullptr, nullptr ) );
+ m_size = size;
+}
+
+Tensor AttentionBuffer::view( uint32_t length, uint32_t offset ) const
+{
+ if( length + offset > m_size )
+ throw E_BOUNDS;
+ if( 0 == length )
+ throw E_INVALIDARG;
+
+ CComPtr<ID3D11ShaderResourceView> srv;
+ CComPtr<ID3D11UnorderedAccessView> uav;
+
+ CD3D11_SHADER_RESOURCE_VIEW_DESC srvDesc{ D3D11_SRV_DIMENSION_BUFFER, DXGI_FORMAT_R16_FLOAT, offset, length };
+ check( device()->CreateShaderResourceView( buffer, &srvDesc, &srv ) );
+
+ CD3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc{ D3D11_UAV_DIMENSION_BUFFER, DXGI_FORMAT_R16_FLOAT, offset, length };
+ check( device()->CreateUnorderedAccessView( buffer, &uavDesc, &uav ) );
+
+ TensorShape shape;
+ shape.ne = { length, 1, 1, 1 };
+ shape.setDenseStrides();
+ return Tensor( shape, srv, uav );
+}
+
+void KeyValueBuffers::resize( uint32_t size )
+{
+ keys.resize( size );
+ values.resize( size );
+} \ No newline at end of file
diff --git a/Whisper/Whisper/KeyValueBuffers.h b/Whisper/Whisper/KeyValueBuffers.h
new file mode 100644
index 0000000..9c737be
--- /dev/null
+++ b/Whisper/Whisper/KeyValueBuffers.h
@@ -0,0 +1,50 @@
+#pragma once
+#include "../ML/Tensor.h"
+
+namespace DirectCompute
+{
+ // FP16 buffer for self-attention and cross-attention layers
+ class AttentionBuffer
+ {
+ CComPtr<ID3D11Buffer> buffer;
+ uint32_t m_size = 0;
+
+ public:
+ // Create buffer for the specified count of elements
+ void resize( uint32_t size );
+
+ // Create an 1D tensor which references a slice of that buffer
+ Tensor view( uint32_t length, uint32_t offset ) const;
+
+ void clear()
+ {
+ buffer = nullptr;
+ m_size = 0;
+ }
+
+ ID3D11Buffer* getBuffer() const { return buffer; }
+
+ uint32_t getSize() const { return m_size; }
+ };
+
+ struct KeyValueBuffers
+ {
+ AttentionBuffer keys, values;
+
+ void resize( uint32_t size );
+
+ void clear()
+ {
+ keys.clear();
+ values.clear();
+ }
+
+ __m128i getMemoryUse() const
+ {
+ size_t i = keys.getSize();
+ i += values.getSize();
+ i *= sizeof( uint16_t );
+ return setHigh_size( (int64_t)i ); // They both are in VRAM
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/Languages.cpp b/Whisper/Whisper/Languages.cpp
new file mode 100644
index 0000000..2081d6a
--- /dev/null
+++ b/Whisper/Whisper/Languages.cpp
@@ -0,0 +1,122 @@
+#include "stdafx.h"
+#include "Languages.h"
+#include <atlcoll.h>
+#include "../API/iContext.cl.h"
+
+namespace
+{
+ // These structures are compiled into the DLL, in read only data section
+ using Lang = Whisper::sLanguageEntry;
+
+ static const Lang s_languageData[] =
+ {
+#include "languageCodez.inl"
+ };
+
+ using Whisper::makeLanguageKey;
+
+ // Values for the hash map
+ struct sLanguage
+ {
+ int id;
+ const char* name;
+ };
+
+ class LanguageIDs
+ {
+ CAtlMap<uint32_t, sLanguage> map;
+
+ void add( const char* code, int id, const char* name )
+ {
+ assert( strlen( code ) <= 4 );
+ const uint16_t key = makeLanguageKey( code );
+ map.SetAt( key, sLanguage{ id, name } );
+ }
+
+ public:
+
+ LanguageIDs() :
+ map( 103u, 0.75f, 0.25f, 2.25f, 99 )
+ {
+ for( const Lang& e : s_languageData )
+ map.SetAt( e.key, sLanguage{ e.id, e.name } );
+ };
+
+ int lookupId( const char* code ) const
+ {
+ const uint32_t key = makeLanguageKey( code );
+ auto p = map.Lookup( key );
+ return ( nullptr != p ) ? p->m_value.id : -1;
+ }
+
+ int lookupKey( uint32_t key ) const
+ {
+ auto p = map.Lookup( key );
+ return ( nullptr != p ) ? p->m_value.id : -1;
+ }
+
+ const char* lookupName( const char* code ) const
+ {
+ const uint32_t key = makeLanguageKey( code );
+ auto p = map.Lookup( key );
+ return ( nullptr != p ) ? p->m_value.name : nullptr;
+ }
+ };
+
+ static const LanguageIDs g_table;
+}
+
+namespace Whisper
+{
+ int lookupLanguageId( const char* code )
+ {
+ return g_table.lookupId( code );
+ }
+ int lookupLanguageId( uint32_t key )
+ {
+ return g_table.lookupKey( key );
+ }
+ const char* lookupLanguageName( const char* code )
+ {
+ return g_table.lookupName( code );
+ }
+ int COMLIGHTCALL getLanguageId( const char* lang )
+ {
+ return lookupLanguageId( lang );
+ }
+
+ uint32_t COMLIGHTCALL findLanguageKeyW( const wchar_t* lang )
+ {
+ uint32_t key = 0;
+ uint32_t shift = 0;
+ for( size_t i = 0; i < 4; i++, lang++, shift += 8 )
+ {
+ const wchar_t c = *lang;
+ if( c == L'\0' )
+ break;
+ if( c >= 0x80 )
+ return UINT_MAX;
+ uint32_t u32 = (uint8_t)c;
+ u32 = u32 << shift;
+ key |= u32;
+ }
+ if( g_table.lookupKey( key ) >= 0 )
+ return key;
+ return UINT_MAX;
+ }
+
+ uint32_t COMLIGHTCALL findLanguageKeyA( const char* lang )
+ {
+ const uint32_t key = makeLanguageKey( lang );
+ if( g_table.lookupKey( key ) >= 0 )
+ return key;
+ return UINT_MAX;
+ }
+
+ HRESULT COMLIGHTCALL getSupportedLanguages( sLanguageList& rdi )
+ {
+ rdi.length = sizeof( s_languageData ) / sizeof( s_languageData[ 0 ] );
+ rdi.pointer = s_languageData;
+ return S_OK;
+ }
+} \ No newline at end of file
diff --git a/Whisper/Whisper/Languages.h b/Whisper/Whisper/Languages.h
new file mode 100644
index 0000000..bb9e599
--- /dev/null
+++ b/Whisper/Whisper/Languages.h
@@ -0,0 +1,12 @@
+#pragma once
+#include "../../ComLightLib/comLightCommon.h"
+
+namespace Whisper
+{
+ int lookupLanguageId( const char* code );
+ int lookupLanguageId( uint32_t key );
+
+ const char* lookupLanguageName( const char* code );
+
+ int COMLIGHTCALL getLanguageId( const char* lang );
+} \ No newline at end of file
diff --git a/Whisper/Whisper/MelInputTensor.cpp b/Whisper/Whisper/MelInputTensor.cpp
new file mode 100644
index 0000000..c2c5e43
--- /dev/null
+++ b/Whisper/Whisper/MelInputTensor.cpp
@@ -0,0 +1,63 @@
+#include "stdafx.h"
+#include "MelInputTensor.h"
+#include "../D3D/MappedResource.h"
+#include "../D3D/createBuffer.h"
+#include <mfapi.h> // MFCopyImage
+using namespace DirectCompute;
+
+HRESULT MelInputTensor::create( Whisper::iSpectrogram& spectrogram, const sEncodeParams& encParams )
+{
+ // Ported from the initial portion of whisper_encode() function
+ const size_t ne0 = encParams.n_ctx * 2;
+ const size_t ne1 = encParams.n_mels;
+ const size_t totalElts = ne0 * ne1;
+ const size_t totalBytes = totalElts * 4;
+
+ if( capacity < (uint32_t)totalElts )
+ {
+ // The old buffer is too small: drop the old one, and create a larger buffer with SRV
+ buffer = nullptr;
+ TensorGpuViews::clear();
+
+ CHECK( createBuffer( eBufferUse::Dynamic, totalBytes, &buffer, nullptr, nullptr ) );
+ CHECK( TensorGpuViews::create( buffer, DXGI_FORMAT_R32_FLOAT, totalElts, false ) );
+
+ capacity = (uint32_t)totalElts;
+ }
+
+ // Upload data to VRAM using D3D11_MAP_WRITE_DISCARD, that's why we made a dynamic buffer
+ {
+ // Ported from whisper_encode() function
+ MappedResource mapped;
+ CHECK( mapped.map( buffer, false ) );
+ float* const dst = (float*)mapped.data();
+ memset( dst, 0, totalBytes );
+
+ const size_t n_len = spectrogram.getLength();
+ const size_t i0 = std::min( (size_t)encParams.mel_offset, n_len );
+ const size_t i1 = std::min( (size_t)encParams.mel_offset + 2 * encParams.n_ctx, n_len );
+
+ // Whisper::MelBufferRaii sourceBuffer{ spectrogram, i0, i1 - i0 };
+ constexpr DWORD n_mel = Whisper::N_MEL;
+ const size_t rowBytes = ( i1 - i0 ) * 4;
+ /*
+ for( size_t j = 0; j < n_mel; j++ )
+ {
+ float* rdi = dst + j * 2 * encParams.n_ctx;
+ const float* rsi = sourceBuffer[ j ];
+ memcpy( rdi, rsi, rowBytes );
+ } */
+
+ Whisper::MelBufferRaii sourceBuffer;
+ CHECK( sourceBuffer.make( spectrogram, i0, i1 - i0 ) );
+ CHECK( MFCopyImage(
+ (BYTE*)dst, (LONG)( 2 * encParams.n_ctx * sizeof( float ) ),
+ sourceBuffer.bytePtr(), sourceBuffer.strideBytes(),
+ (DWORD)rowBytes, n_mel ) );
+ }
+
+ // Shape the tensor
+ ne = { 2 * encParams.n_ctx, encParams.n_mels, 1, 1 };
+ TensorShape::setDenseStrides();
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/MelInputTensor.h b/Whisper/Whisper/MelInputTensor.h
new file mode 100644
index 0000000..3923ad6
--- /dev/null
+++ b/Whisper/Whisper/MelInputTensor.h
@@ -0,0 +1,22 @@
+#pragma once
+#include "../ML/TensorEx.h"
+#include "sEncodeParams.h"
+#include "iSpectrogram.h"
+
+namespace DirectCompute
+{
+ // Input tensor in VRAM, in a dynamic FP32 buffer
+ class MelInputTensor : public TensorEx
+ {
+ uint32_t capacity;
+
+ public:
+
+ HRESULT create( Whisper::iSpectrogram& spectrogram, const sEncodeParams& encParams );
+
+ __m128i getMemoryUse() const
+ {
+ return setHigh_size( (size_t)capacity * 4 );
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/MelStreamer.cpp b/Whisper/Whisper/MelStreamer.cpp
new file mode 100644
index 0000000..54268f2
--- /dev/null
+++ b/Whisper/Whisper/MelStreamer.cpp
@@ -0,0 +1,493 @@
+#include "stdafx.h"
+#include "MelStreamer.h"
+#include "../Utils/parallelFor.h"
+using namespace Whisper;
+
+MelStreamer::MelStreamer( const Filters& filters, ProfileCollection& prof, IMFSourceReader* source, bool stereo ) :
+ reader( source, stereo ),
+ melContext( filters ),
+ profiler( prof )
+{ }
+
+void MelStreamer::dropOldChunks( size_t off )
+{
+ const bool stereo = reader.outputsStereo();
+ for( size_t i = streamStartOffset; i < off; i++ )
+ {
+ queuePcmMono.pop_front();
+ queueMel.pop_front();
+ if( stereo )
+ queuePcmStereo.pop_front();
+ }
+ streamStartOffset = off;
+}
+
+HRESULT MelStreamer::ensurePcmChunks( size_t len )
+{
+ if( readerEof )
+ return queuePcmMono.empty() ? E_EOF : S_FALSE;
+
+ const bool loadStereo = reader.outputsStereo();
+
+ const size_t neededChunks = len + FFT_SIZE / FFT_STEP;
+ while( true )
+ {
+ if( queuePcmMono.size() >= neededChunks )
+ return S_OK;
+
+ PcmMonoChunk& mono = queuePcmMono.emplace_back();
+ PcmStereoChunk* stereo = loadStereo ? &queuePcmStereo.emplace_back() : nullptr;
+ HRESULT hr = reader.readChunk( mono, stereo );
+ if( SUCCEEDED( hr ) )
+ continue;
+
+ queuePcmMono.pop_back();
+ if( loadStereo )
+ queuePcmStereo.pop_back();
+
+ if( hr == E_EOF )
+ {
+ readerEof = true;
+ return S_FALSE;
+ }
+
+ return hr;
+ }
+}
+
+size_t MelStreamer::serializePcm( size_t startOffset )
+{
+ const ptrdiff_t chunks = (ptrdiff_t)queuePcmMono.size() - (ptrdiff_t)startOffset;
+ assert( chunks > 0 );
+
+ tempPcm.resize( chunks * FFT_STEP );
+ float* rdi = tempPcm.data();
+
+ for( auto it = queuePcmMono.begin() + startOffset; it != queuePcmMono.end(); it++ )
+ {
+ memcpy( rdi, it->mono.data(), FFT_STEP * 4 );
+ rdi += FFT_STEP;
+ }
+ return chunks;
+}
+
+namespace
+{
+ __forceinline __m128 transpose4x80( __m128 vmax, const float* c0, const float* c1, const float* c2, const float* c3, float* rdi, size_t stride )
+ {
+ const float* const c0End = c0 + 80;
+ for( ; c0 < c0End; c0 += 4, c1 += 4, c2 += 4, c3 += 4, rdi += stride * 4 )
+ {
+ __m128 r0 = _mm_loadu_ps( c0 );
+ __m128 r1 = _mm_loadu_ps( c1 );
+ __m128 r2 = _mm_loadu_ps( c2 );
+ __m128 r3 = _mm_loadu_ps( c3 );
+
+ __m128 ax01 = _mm_max_ps( r0, r1 );
+ __m128 ax02 = _mm_max_ps( r2, r3 );
+ __m128 ax = _mm_max_ps( ax01, ax02 );
+ vmax = _mm_max_ps( vmax, ax );
+
+ _MM_TRANSPOSE4_PS( r0, r1, r2, r3 );
+
+ _mm_storeu_ps( rdi, r0 );
+ _mm_storeu_ps( rdi + stride, r1 );
+ _mm_storeu_ps( rdi + stride * 2, r2 );
+ _mm_storeu_ps( rdi + stride * 3, r3 );
+ }
+ return vmax;
+ }
+
+ __forceinline __m128 transpose80( __m128 vmax, const float* c0, float* rdi, size_t stride )
+ {
+ const float* const c0End = c0 + 80;
+ for( ; c0 < c0End; c0 += 4, rdi += stride * 4 )
+ {
+ __m128 r0 = _mm_loadu_ps( c0 );
+ vmax = _mm_max_ps( vmax, r0 );
+
+ _mm_store_ss( rdi, r0 );
+ *(int*)( rdi + stride ) = _mm_extract_ps( r0, 1 );
+ *(int*)( rdi + stride * 2 ) = _mm_extract_ps( r0, 2 );
+ *(int*)( rdi + stride * 3 ) = _mm_extract_ps( r0, 3 );
+ }
+ return vmax;
+ }
+
+ __forceinline float horizontalMaximum( __m128 v )
+ {
+ v = _mm_max_ps( v, _mm_movehl_ps( v, v ) );
+ v = _mm_max_ss( v, _mm_movehdup_ps( v ) );
+ return _mm_cvtss_f32( v );
+ }
+}
+
+void MelStreamer::makeTransposedBuffer( size_t off, size_t len )
+{
+ // Resize the output
+ assert( len <= queueMel.size() );
+ outputMel.resize( len * N_MEL ); // N_MEL = 80
+
+ // First pass, copy transposed MEL data, and compute the maximum
+ const size_t lengthAligned = ( len / 4 ) * 4;
+ __m128 vMax = _mm_set1_ps( 1e-20f );
+ float* rdi = outputMel.data();
+
+ size_t i;
+ for( i = 0; i < lengthAligned; i += 4, rdi += 4 )
+ {
+ vMax = transpose4x80( vMax,
+ queueMel[ i ].data(),
+ queueMel[ i + 1 ].data(),
+ queueMel[ i + 2 ].data(),
+ queueMel[ i + 3 ].data(),
+ rdi, len );
+ }
+ for( ; i < len; i++, rdi++ )
+ vMax = transpose80( vMax, queueMel[ i ].data(), rdi, len );
+
+ // Second pass, clamping and normalization
+ float mmax;
+ const size_t bufferEnd = off + len;
+ if( lastBufferEnd != bufferEnd )
+ {
+ // Store maximum value in this class, along with the end sample index
+ mmax = horizontalMaximum( vMax );
+ lastBufferEnd = bufferEnd;
+ lastBufferMax = mmax;
+ }
+ else
+ {
+ // We're probably at the and of the stream, the caller asked for a smalled slice of the samples with the same end as the last time.
+ // Discard the computed maximum value, and instead use the number stored in this class
+ mmax = lastBufferMax;
+ }
+
+ mmax -= 8.0f;
+ vMax = _mm_set1_ps( mmax );
+
+ rdi = outputMel.data();
+ float* const rdiEnd = rdi + outputMel.size();
+ const __m128 add = _mm_set1_ps( 4 );
+ const __m128 mul = _mm_set1_ps( 1.0f / 4.0f );
+ for( ; rdi < rdiEnd; rdi += 4 )
+ {
+ __m128 v = _mm_loadu_ps( rdi );
+ v = _mm_max_ps( v, vMax );
+ v = _mm_add_ps( v, add );
+ v = _mm_mul_ps( v, mul );
+ _mm_storeu_ps( rdi, v );
+ }
+}
+
+HRESULT MelStreamerSimple::makeBuffer( size_t off, size_t len, const float** buffer, size_t& stride ) noexcept
+{
+ if( off < streamStartOffset )
+ {
+ logError( u8"MelStreamer doesn't support backwards seeks" );
+ return E_UNEXPECTED;
+ }
+
+ if( off > streamStartOffset )
+ {
+ // The model wants to advance forward, drop now irrelevant chunks of data
+ dropOldChunks( off );
+ }
+
+ // Compute all these MEL chunks
+ const size_t availableMel = queueMel.size();
+ if( availableMel < len )
+ {
+ CHECK( ensurePcmChunks( len ) );
+
+ const size_t pcmChunks = serializePcm( availableMel );
+ const size_t missingMelChunks = len - availableMel;
+ size_t i;
+ const size_t loop1 = std::min( missingMelChunks, pcmChunks );
+ {
+ auto profilerBlock = profiler.cpuBlock( eCpuBlock::Spectrogram );
+ for( i = 0; i < loop1; i++ )
+ {
+ // if( readerEof && i + 1 == loop1 ) __debugbreak();
+ auto& arr = queueMel.emplace_back();
+ const float* sourcePcm = tempPcm.data() + i * FFT_STEP;
+ size_t availableChunks = pcmChunks - i;
+ size_t availableFloats = availableChunks * FFT_STEP;
+ melContext.fft( arr, sourcePcm, availableFloats );
+ }
+ }
+ for( ; i < missingMelChunks; i++ )
+ {
+ assert( readerEof );
+ auto& arr = queueMel.emplace_back();
+ memset( arr.data(), 0, N_MEL * 4 );
+ }
+ }
+
+ // Produce the result
+ makeTransposedBuffer( off, len );
+ stride = len;
+ *buffer = outputMel.data();
+ return S_OK;
+}
+
+MelStreamerThread::MelStreamerThread( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, int countThreads, bool stereo ) :
+ MelStreamer( filters, profiler, source, stereo ),
+ workerThreads( countThreads )
+{
+ if( workerThreads > 1 )
+ {
+ check( ThreadPoolWork::create() );
+ melContextsWorkers.reserve( workerThreads - 1 );
+ for( int i = 1; i < workerThreads; i++ )
+ melContextsWorkers.emplace_back( filters );
+ }
+
+ InitializeConditionVariable( &wakeMain );
+ InitializeConditionVariable( &wakeBackground );
+ threadStatus = eThreadStatus::NotStarted;
+ const HANDLE h = CreateThread( nullptr, 0, &threadProcStatic, this, 0, nullptr );
+ if( nullptr == h )
+ throw HRESULT_FROM_WIN32( GetLastError() );
+ threadHandle.Attach( h );
+}
+
+using Lock = CComCritSecLock<CComAutoCriticalSection>;
+
+constexpr ptrdiff_t prebufferChunks = 3000 * 2;
+constexpr ptrdiff_t chunksPerWakeup = 512;
+constexpr ptrdiff_t minChunksPerThread = 64;
+
+HRESULT MelStreamerThread::threadMain()
+{
+ pendingChunks.reserve( chunksPerWakeup );
+
+ EnterCriticalSection( &m_cs.m_sec );
+ threadStatus = eThreadStatus::Working;
+
+ while( true )
+ {
+ if( shuttingDown )
+ {
+ LeaveCriticalSection( &m_cs.m_sec );
+ return S_FALSE;
+ }
+
+ // Count of available MEL chunks
+ const ptrdiff_t availableMel = queueMel.size();
+ if( availableMel >= prebufferChunks )
+ {
+ threadStatus = eThreadStatus::Idle;
+ SleepConditionVariableCS( &wakeBackground, &m_cs.m_sec, INFINITE );
+ threadStatus = eThreadStatus::Working;
+ continue;
+ }
+ // Count of MEL chunks remaining in the whole stream
+ // availableMel of them are already on the queue
+ const ptrdiff_t remainingMel = (ptrdiff_t)getLength() - (ptrdiff_t)streamStartOffset;
+ LeaveCriticalSection( &m_cs.m_sec );
+
+ const ptrdiff_t missingChunks = prebufferChunks - availableMel;
+ ptrdiff_t chunks = std::min( missingChunks, chunksPerWakeup );
+ chunks = std::min( chunks, remainingMel - availableMel );
+ if( chunks <= 0 )
+ return S_OK; // This thread has produced all chunks of the stream
+
+ CHECK( ensurePcmChunks( availableMel + chunks ) );
+ const size_t pcmChunks = serializePcm( availableMel );
+ if( 0 == pcmChunks )
+ return S_OK;
+
+ pendingChunks.clear();
+
+ chunks = std::min( chunks, (ptrdiff_t)pcmChunks );
+ {
+ auto profilerBlock = profiler.cpuBlock( eCpuBlock::Spectrogram );
+
+ if( this->workerThreads <= 1 || chunks < minChunksPerThread * 2 )
+ {
+ // Thread pool disabled with a setting, or not enough work for the thread pool
+ for( ptrdiff_t i = 0; i < chunks; i++ )
+ {
+ MelChunk& arr = pendingChunks.emplace_back();
+ const float* sourcePcm = tempPcm.data() + i * FFT_STEP;
+ size_t availableChunks = pcmChunks - i;
+ size_t availableFloats = availableChunks * FFT_STEP;
+ melContext.fft( arr, sourcePcm, availableFloats );
+ }
+ }
+ else
+ {
+ // Use thread pool for these FFTs
+ pendingChunks.resize( chunks );
+ int nth = (int)( ( chunks + minChunksPerThread - 1 ) / minChunksPerThread );
+ nth = std::min( nth, this->workerThreads );
+ assert( nth > 1 );
+ this->fftChunks = (int)chunks;
+ this->fftThreads = nth;
+ CHECK( ThreadPoolWork::parallelFor( nth ) );
+ }
+ }
+
+ EnterCriticalSection( &m_cs.m_sec );
+ if( shuttingDown )
+ {
+ LeaveCriticalSection( &m_cs.m_sec );
+ return S_FALSE;
+ }
+
+ for( const auto& a : pendingChunks )
+ queueMel.push_back( a );
+
+ LeaveCriticalSection( &m_cs.m_sec );
+
+ WakeAllConditionVariable( &wakeMain );
+ pendingChunks.clear();
+
+ EnterCriticalSection( &m_cs.m_sec );
+ }
+}
+
+HRESULT MelStreamerThread::threadPoolCallback( int ith ) noexcept
+{
+ SpectrogramContext& ctx = ( 0 != ith ) ? melContextsWorkers[ ith - 1 ] : melContext;
+
+ // Figure out the slice of the chunks to generate in this thread
+ const int nth = this->fftThreads;
+ const int chunks = this->fftChunks;
+ const int i0 = ( ith * chunks ) / nth;
+ const int i1 = ( ( ith + 1 ) * chunks ) / nth;
+
+ // Run these FFTs
+ const size_t pcmChunks = tempPcm.size() / FFT_STEP;
+ for( int i = i0; i < i1; i++ )
+ {
+ MelChunk& arr = pendingChunks[ i ];
+ const float* sourcePcm = tempPcm.data() + i * FFT_STEP;
+ size_t availableChunks = pcmChunks - i;
+ size_t availableFloats = availableChunks * FFT_STEP;
+ ctx.fft( arr, sourcePcm, availableFloats );
+ }
+ return S_OK;
+}
+
+HRESULT MelStreamerThread::run() noexcept
+{
+ HRESULT status;
+ try
+ {
+ status = threadMain();
+ }
+ catch( HRESULT hr )
+ {
+ status = hr;
+ }
+ catch( const std::bad_alloc& )
+ {
+ status = E_OUTOFMEMORY;
+ }
+ catch( const std::exception& )
+ {
+ status = E_FAIL;
+ }
+
+ {
+ Lock lk( m_cs );
+ threadStatus = SUCCEEDED( status ) ? eThreadStatus::Completed : eThreadStatus::Failed;
+ }
+
+ // Especially when things fail, we want to wake the main thread up, so it's aware of the situation.
+ WakeAllConditionVariable( &wakeMain );
+ return status;
+}
+
+DWORD __stdcall MelStreamerThread::threadProcStatic( void* lpParameter )
+{
+ setCurrentThreadName( "Whisper.dll MEL Streamer Thread" );
+ MelStreamerThread* p = (MelStreamerThread*)lpParameter;
+ return (DWORD)p->run();
+}
+
+HRESULT MelStreamerThread::makeBuffer( size_t off, size_t len, const float** buffer, size_t& stride ) noexcept
+{
+ bool wakeThread = false;
+
+ {
+ Lock lock( m_cs );
+ if( off < streamStartOffset )
+ {
+ logError( u8"MelStreamer doesn't support backwards seeks" );
+ return E_UNEXPECTED;
+ }
+
+ if( off > streamStartOffset )
+ {
+ // The model wants to advance forward, drop now irrelevant chunks of data
+ dropOldChunks( off );
+ wakeThread = ( threadStatus == eThreadStatus::Working || threadStatus == eThreadStatus::Idle );
+ }
+
+ while( true )
+ {
+ const size_t availableMel = queueMel.size();
+ if( availableMel >= len )
+ break;
+
+ const eThreadStatus ts = threadStatus;
+ if( ts == eThreadStatus::Working || ts == eThreadStatus::Idle )
+ {
+ WakeAllConditionVariable( &wakeBackground );
+ SleepConditionVariableCS( &wakeMain, &m_cs.m_sec, INFINITE );
+ continue;
+ }
+ if( ts == eThreadStatus::Failed )
+ {
+ DWORD code;
+ if( GetExitCodeThread( threadHandle, &code ) )
+ return (HRESULT)code;
+ else
+ return HRESULT_FROM_WIN32( GetLastError() );
+ }
+ assert( ts == eThreadStatus::Completed );
+ break;
+ }
+
+ if( queueMel.size() < len )
+ {
+ assert( readerEof || threadStatus == eThreadStatus::Failed );
+ while( queueMel.size() < len )
+ {
+ auto& arr = queueMel.emplace_back();
+ memset( arr.data(), 0, N_MEL * 4 );
+ }
+ }
+
+ // Produce the result
+ makeTransposedBuffer( off, len );
+
+ } // Unlock the critical section
+
+ stride = len;
+ *buffer = outputMel.data();
+ if( wakeThread )
+ WakeAllConditionVariable( &wakeBackground );
+ return S_OK;
+}
+
+MelStreamerThread::~MelStreamerThread()
+{
+ if( !threadHandle )
+ return;
+
+ {
+ Lock lock( m_cs );
+ if( threadStatus != eThreadStatus::Working )
+ return;
+ shuttingDown = true;
+ }
+
+ DWORD res = WaitForSingleObject( threadHandle, 100 );
+ if( res == WAIT_OBJECT_0 )
+ return;
+ // TODO: log a warning
+} \ No newline at end of file
diff --git a/Whisper/Whisper/MelStreamer.h b/Whisper/Whisper/MelStreamer.h
new file mode 100644
index 0000000..152c1b6
--- /dev/null
+++ b/Whisper/Whisper/MelStreamer.h
@@ -0,0 +1,99 @@
+#pragma once
+#include <deque>
+#include "../MF/PcmReader.h"
+#include "melSpectrogram.h"
+#include "iSpectrogram.h"
+#include <atlbase.h>
+#include "../Utils/parallelFor.h"
+#include "../Utils/ProfileCollection.h"
+
+namespace Whisper
+{
+ // Base class for both single- and multi-threaded MEL streamers
+ class MelStreamer : public iSpectrogram
+ {
+ protected:
+ PcmReader reader;
+ std::deque<PcmMonoChunk> queuePcmMono;
+ using MelChunk = std::array<float, N_MEL>;
+ std::deque<MelChunk> queueMel;
+ size_t streamStartOffset = 0;
+ std::vector<float> tempPcm;
+ std::vector<float> outputMel;
+ SpectrogramContext melContext;
+ bool readerEof = false;
+ ProfileCollection& profiler;
+ std::deque<PcmStereoChunk> queuePcmStereo;
+
+ // If the streamStartOffset value is less than the argument,
+ // remove ( off - streamStartOffset ) chunks from the start of all 3 queues, and advance streamStartOffset to the `off` argument
+ void dropOldChunks( size_t off );
+
+ // Ensure PCM queues have enough chunks to generate specified count of MEL chunks
+ // At the end of the stream, the method delivers less chunks then requested and returns S_FALSE
+ HRESULT ensurePcmChunks( size_t len );
+
+ // Copy mono PCM chunks from the queue (starting at the specified element index) into the continuous tempPcm vector
+ // Returns count of chunks copied there.
+ size_t serializePcm( size_t startOffset );
+
+ size_t lastBufferEnd = ~(size_t)0;
+ float lastBufferMax = 0.0f;
+ void makeTransposedBuffer( size_t off, size_t len );
+
+ size_t getLength() const noexcept override final { return reader.getLength(); }
+
+ public:
+ MelStreamer( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, bool stereo );
+ };
+
+ // Single-threaded MEL streamer: runs these FFTs on-demand, from within makeBuffer() method
+ class MelStreamerSimple : public MelStreamer
+ {
+ HRESULT makeBuffer( size_t offset, size_t length, const float** buffer, size_t& stride ) noexcept override final;
+
+ public:
+ MelStreamerSimple( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, bool stereo ) :
+ MelStreamer( filters, profiler, source, stereo ) { }
+ };
+
+ // Multi threaded MEL streamers: runs FFT on a background thread ahead of time
+ // The background thread tries to keep the queueMel full, this way the makeBuffer() method has very little to do
+ // makeBuffer() only transposes the data, and does clamping + normalization, both steps are pretty fast
+ class MelStreamerThread : public MelStreamer,
+ ThreadPoolWork
+ {
+ HRESULT makeBuffer( size_t offset, size_t length, const float** buffer, size_t& stride ) noexcept override final;
+
+ static DWORD __stdcall threadProcStatic( void* lpParameter );
+ HRESULT run() noexcept;
+ HRESULT threadMain();
+
+ std::vector<MelChunk> pendingChunks;
+ int fftChunks = 0;
+ int fftThreads = 0;
+ std::vector<SpectrogramContext> melContextsWorkers;
+ CComAutoCriticalSection m_cs;
+ CONDITION_VARIABLE wakeMain, wakeBackground;
+ const int workerThreads;
+ enum struct eThreadStatus : uint8_t
+ {
+ NotStarted = 0,
+ Idle,
+ Working,
+ Completed,
+ Failed
+ };
+ eThreadStatus threadStatus;
+ bool shuttingDown = false;
+ CHandle threadHandle;
+
+ HRESULT threadPoolCallback( int ith ) noexcept override final;
+
+ public:
+
+ MelStreamerThread( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, int countThreads, bool stereo );
+
+ ~MelStreamerThread();
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/ModelBuffers.cpp b/Whisper/Whisper/ModelBuffers.cpp
new file mode 100644
index 0000000..7ad8e13
--- /dev/null
+++ b/Whisper/Whisper/ModelBuffers.cpp
@@ -0,0 +1,115 @@
+#include "stdafx.h"
+#include "ModelLoader.h"
+
+
+#if BUILD_BOTH_VERSIONS
+namespace DirectCompute
+{
+ static ModelBuffers s_model;
+ const ModelBuffers& gpuModel = s_model;
+}
+
+using namespace DirectCompute;
+
+ModelLoader::ModelLoader( int encoderLayers, int decoderLayers ) :
+ model( s_model )
+{
+ if( encoderLayers <= 0 || decoderLayers <= 0 )
+ throw E_INVALIDARG;
+ model.enc.layers.resize( (uint32_t)encoderLayers );
+ model.dec.layers.resize( (uint32_t)decoderLayers );
+}
+
+void ModelLoader::add( const ggml_tensor* ggml, Tensor& gpu )
+{
+ if( nullptr == ggml )
+ throw E_POINTER;
+
+ auto res = map.try_emplace( ggml, &gpu );
+ if( !res.second )
+ throw E_INVALIDARG;
+}
+
+Tensor* ModelLoader::lookup( const ggml_tensor* ggml ) const
+{
+ auto it = map.find( ggml );
+ if( it == map.end() )
+ return nullptr;
+ return it->second;
+}
+
+bool ModelLoader::tryLoad( const ggml_tensor* ggml )
+{
+ Tensor* rdi = lookup( ggml );
+ if( nullptr == rdi )
+ return false;
+ HRESULT hr = rdi->create( *ggml, eBufferUse::Immutable, true );
+ if( SUCCEEDED( hr ) )
+ return true;
+ throw hr;
+}
+#endif
+
+__m128i __declspec( noinline ) DirectCompute::TensorPair::getMemoryUse() const
+{
+ return _mm_add_epi64( w.getMemoryUse(), b.getMemoryUse() );
+}
+
+__m128i DirectCompute::LayerEncoder::getMemoryUse() const
+{
+ __m128i v = attnLn0.getMemoryUse();
+ v = _mm_add_epi64( v, attnLn1.getMemoryUse() );
+ v = _mm_add_epi64( v, attnQuery.getMemoryUse() );
+ v = _mm_add_epi64( v, attnKey.getMemoryUse() );
+ v = _mm_add_epi64( v, attnValue.getMemoryUse() );
+ v = _mm_add_epi64( v, mlpLn.getMemoryUse() );
+ v = _mm_add_epi64( v, mlp0.getMemoryUse() );
+ v = _mm_add_epi64( v, mlp1.getMemoryUse() );
+ return v;
+}
+
+__m128i DirectCompute::EncoderBuffers::getMemoryUse() const
+{
+ __m128i v = _mm_cvtsi64_si128( vectorMemoryUse( layers ) );
+ v = _mm_add_epi64( v, positionalEmbedding.getMemoryUse() );
+ v = _mm_add_epi64( v, conv1.getMemoryUse() );
+ v = _mm_add_epi64( v, conv2.getMemoryUse() );
+ v = _mm_add_epi64( v, lnPost.getMemoryUse() );
+ for( const auto& layer : layers )
+ v = _mm_add_epi64( v, layer.getMemoryUse() );
+ return v;
+}
+
+__m128i DirectCompute::LayerDecoder::getMemoryUse() const
+{
+ __m128i v = attnLn0.getMemoryUse();
+ v = _mm_add_epi64( v, attnLn1.getMemoryUse() );
+ v = _mm_add_epi64( v, attnQuery.getMemoryUse() );
+ v = _mm_add_epi64( v, attnKey.getMemoryUse() );
+ v = _mm_add_epi64( v, attnValue.getMemoryUse() );
+ v = _mm_add_epi64( v, crossAttnLn0.getMemoryUse() );
+ v = _mm_add_epi64( v, crossAttnLn1.getMemoryUse() );
+ v = _mm_add_epi64( v, crossAttnQuery.getMemoryUse() );
+ v = _mm_add_epi64( v, crossAttnKey.getMemoryUse() );
+ v = _mm_add_epi64( v, crossAttnValue.getMemoryUse() );
+ v = _mm_add_epi64( v, mlpLn.getMemoryUse() );
+ v = _mm_add_epi64( v, mlp0.getMemoryUse() );
+ v = _mm_add_epi64( v, mlp1.getMemoryUse() );
+ return v;
+}
+
+__m128i DirectCompute::DecoderBuffers::getMemoryUse() const
+{
+ __m128i v = _mm_cvtsi64_si128( vectorMemoryUse( layers ) );
+ v = _mm_add_epi64( v, positionalEmbedding.getMemoryUse() );
+ v = _mm_add_epi64( v, tokenEmbedding.getMemoryUse() );
+ v = _mm_add_epi64( v, ln.getMemoryUse() );
+ for( const auto& layer : layers )
+ v = _mm_add_epi64( v, layer.getMemoryUse() );
+ return v;
+}
+
+__m128i DirectCompute::ModelBuffers::getMemoryUse() const
+{
+ return _mm_add_epi64( enc.getMemoryUse(), dec.getMemoryUse() );
+} \ No newline at end of file
diff --git a/Whisper/Whisper/ModelBuffers.h b/Whisper/Whisper/ModelBuffers.h
new file mode 100644
index 0000000..403d274
--- /dev/null
+++ b/Whisper/Whisper/ModelBuffers.h
@@ -0,0 +1,114 @@
+#pragma once
+#include "../ML/Tensor.h"
+#include <vector>
+
+namespace DirectCompute
+{
+ // A pair of tensors containing weights and biases; apparently, both tensors are of the same shape.
+ struct TensorPair
+ {
+ Tensor w, b;
+
+ __m128i getMemoryUse() const;
+ };
+
+ // A set of tensors for one encoder's layer
+ struct LayerEncoder
+ {
+ // encoder.blocks.*.attn_ln
+ TensorPair attnLn0;
+ // encoder.blocks.*.attn.out
+ TensorPair attnLn1;
+ // encoder.blocks.*.attn.query
+ TensorPair attnQuery;
+ // encoder.blocks.*.attn.key
+ Tensor attnKey;
+ // encoder.blocks.*.attn.value
+ TensorPair attnValue;
+ // encoder.blocks.*.mlp_ln
+ TensorPair mlpLn;
+ // encoder.blocks.*.mlp.0
+ TensorPair mlp0;
+ // encoder.blocks.*.mlp.2
+ TensorPair mlp1;
+
+ __m128i getMemoryUse() const;
+ };
+
+ // A set of tensors for the encoder
+ struct EncoderBuffers
+ {
+ // encoder.positional_embedding
+ Tensor positionalEmbedding;
+ // encoder.conv1
+ TensorPair conv1;
+ // encoder.conv2
+ TensorPair conv2;
+ // encoder.ln_post
+ TensorPair lnPost;
+ // A vector of layers
+ std::vector<LayerEncoder> layers;
+
+ __m128i getMemoryUse() const;
+ };
+
+ // A set of tensors for one decoder's layer
+ struct LayerDecoder
+ {
+ // decoder.blocks.*.attn_ln
+ TensorPair attnLn0;
+ // decoder.blocks.*.attn.out
+ TensorPair attnLn1;
+ // decoder.blocks.*.attn.query
+ TensorPair attnQuery;
+ // decoder.blocks.*.attn.key
+ Tensor attnKey;
+ // decoder.blocks.*.attn.value
+ TensorPair attnValue;
+ // decoder.blocks.*.cross_attn_ln
+ TensorPair crossAttnLn0;
+ // decoder.blocks.*.cross_attn.out
+ TensorPair crossAttnLn1;
+ // decoder.blocks.*.cross_attn.query
+ TensorPair crossAttnQuery;
+ // decoder.blocks.*.cross_attn.key
+ Tensor crossAttnKey;
+ // decoder.blocks.*.cross_attn.value
+ TensorPair crossAttnValue;
+ // decoder.blocks.*.mlp_ln
+ TensorPair mlpLn;
+ // decoder.blocks.*.mlp.0
+ TensorPair mlp0;
+ // decoder.blocks.*.mlp.2
+ TensorPair mlp1;
+
+ __m128i getMemoryUse() const;
+ };
+
+ // A set of tensors for the decoder
+ struct DecoderBuffers
+ {
+ // decoder.positional_embedding
+ Tensor positionalEmbedding;
+ // decoder.token_embedding
+ Tensor tokenEmbedding;
+ // decoder.ln
+ TensorPair ln;
+ // A vector of layers
+ std::vector<LayerDecoder> layers;
+
+ __m128i getMemoryUse() const;
+ };
+
+ // A complete set of tensors for a model
+ struct ModelBuffers
+ {
+ EncoderBuffers enc;
+ DecoderBuffers dec;
+ __m128i getMemoryUse() const;
+ };
+
+#if BUILD_BOTH_VERSIONS
+ extern const ModelBuffers& gpuModel;
+#endif
+} \ No newline at end of file
diff --git a/Whisper/Whisper/ModelImpl.cpp b/Whisper/Whisper/ModelImpl.cpp
new file mode 100644
index 0000000..968c3ce
--- /dev/null
+++ b/Whisper/Whisper/ModelImpl.cpp
@@ -0,0 +1,122 @@
+#include "stdafx.h"
+#include "ModelImpl.h"
+#include "../ML/mlStartup.h"
+#include "ContextImpl.h"
+#include <intrin.h>
+#include "../Utils/ReadStream.h"
+#include "../modelFactory.h"
+using namespace Whisper;
+
+namespace
+{
+ volatile long s_refCounter = 0;
+}
+
+HRESULT ModelImpl::FinalConstruct()
+{
+ if( 1 != InterlockedIncrement( &s_refCounter ) )
+ return S_FALSE;
+ return DirectCompute::mlStartup();
+}
+
+void ModelImpl::FinalRelease()
+{
+ if( 0 == InterlockedDecrement( &s_refCounter ) )
+ DirectCompute::mlShutdown();
+}
+
+HRESULT COMLIGHTCALL ModelImpl::createContext( iContext** pp )
+{
+ ComLight::CComPtr<ComLight::Object<ContextImpl>> obj;
+
+ iModel* m = this;
+ CHECK( ComLight::Object<ContextImpl>::create( obj, model, m ) );
+
+ obj.detach( pp );
+ return S_OK;
+}
+
+HRESULT ModelImpl::load( iReadStream* stm, bool hybrid, const sLoadModelCallbacks* callbacks )
+{
+ return model.load( stm, hybrid, callbacks );
+}
+
+inline bool hasSse41()
+{
+ int cpu_info[ 4 ];
+ __cpuid( cpu_info, 1 );
+ return ( cpu_info[ 2 ] & ( 1 << 19 ) ) != 0;
+}
+
+// True when the current CPU is good enough to run the hybrid model
+inline bool hasAvxAndFma()
+{
+ // AVX needs OS support to preserve the 32-bytes registers across context switches, CPU support alone ain't enough
+ // Calling a kernel API to check that support
+ // The magic number is from there: https://stackoverflow.com/a/35096938/126995
+ if( 0 == ( GetEnabledXStateFeatures() & 4 ) )
+ return false;
+
+ // FMA3 and F16C
+ int cpuInfo[ 4 ];
+ __cpuid( cpuInfo, 1 );
+ // The magic numbers are from "Feature Information" table on Wikipedia:
+ // https://en.wikipedia.org/wiki/CPUID#EAX=1:_Processor_Info_and_Feature_Bits
+ constexpr int requiredBits = ( 1 << 12 ) | ( 1 << 29 );
+ if( requiredBits != ( cpuInfo[ 2 ] & requiredBits ) )
+ return false;
+
+ // BMI1
+ // https://en.wikipedia.org/wiki/CPUID#EAX=7,_ECX=0:_Extended_Features
+ __cpuid( cpuInfo, 7 );
+ if( 0 == ( cpuInfo[ 1 ] & ( 1 << 3 ) ) )
+ return false;
+
+ return true;
+}
+
+HRESULT __stdcall Whisper::loadGpuModel( const wchar_t* path, bool hybrid, const sLoadModelCallbacks* callbacks, iModel** pp )
+{
+ if( nullptr == path || nullptr == pp )
+ return E_POINTER;
+
+ if( hybrid )
+ {
+#if BUILD_HYBRID_VERSION
+ if( !hasAvxAndFma() )
+ {
+ logError( u8"eModelImplementation.Hybrid model requires a CPU with AVX1, FMA3, F16C and BMI1 support" );
+ return ERROR_HV_CPUID_FEATURE_VALIDATION;
+ }
+#else
+ logError( u8"This build of the DLL doesn’t implement eModelImplementation.Hybrid model" );
+ return E_NOTIMPL;
+#endif
+ }
+ else if( !hasSse41() )
+ {
+ logError( u8"eModelImplementation.GPU model requires a CPU with SSE 4.1 support" );
+ return ERROR_HV_CPUID_FEATURE_VALIDATION;
+ }
+
+ ComLight::Object<ReadStream> stream;
+ HRESULT hr = stream.open( path );
+ if( FAILED( hr ) )
+ {
+ logError16( L"Unable to open model binary file \"%s\"", path );
+ return hr;
+ }
+
+ ComLight::CComPtr<ComLight::Object<ModelImpl>> obj;
+ CHECK( ComLight::Object<ModelImpl>::create( obj ) );
+ hr = obj->load( &stream, hybrid, callbacks );
+ if( FAILED( hr ) )
+ {
+ logError16( L"Error loading the model from \"%s\"", path );
+ return hr;
+ }
+
+ obj.detach( pp );
+ logInfo16( L"Loaded model from \"%s\" to VRAM", path );
+ return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/ModelImpl.h b/Whisper/Whisper/ModelImpl.h
new file mode 100644
index 0000000..4bcea12
--- /dev/null
+++ b/Whisper/Whisper/ModelImpl.h
@@ -0,0 +1,40 @@
+#pragma once
+#include "../API/iContext.cl.h"
+#include "../ComLightLib/comLightServer.h"
+#include "WhisperModel.h"
+#include "../ComLightLib/streams.h"
+
+namespace Whisper
+{
+ using ComLight::iReadStream;
+
+ class ModelImpl : public ComLight::ObjectRoot<iModel>
+ {
+ WhisperModel model;
+
+ HRESULT COMLIGHTCALL createContext( iContext** pp ) override final;
+
+ HRESULT COMLIGHTCALL getSpecialTokens( SpecialTokens& rdi ) override final
+ {
+ model.vocab.getSpecialTokens( rdi );
+ return S_OK;
+ }
+
+ HRESULT COMLIGHTCALL isMultilingual() override final
+ {
+ return model.vocab.is_multilingual() ? S_OK : S_FALSE;
+ }
+
+ const char* COMLIGHTCALL stringFromToken( whisper_token token ) override final
+ {
+ return model.vocab.string( token );
+ }
+
+ public:
+
+ HRESULT FinalConstruct();
+ void FinalRelease();
+
+ HRESULT load( iReadStream* stm, bool hybrid, const sLoadModelCallbacks* callbacks );
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/ModelLoader.h b/Whisper/Whisper/ModelLoader.h
new file mode 100644
index 0000000..b136dd2
--- /dev/null
+++ b/Whisper/Whisper/ModelLoader.h
@@ -0,0 +1,29 @@
+#pragma once
+#include "ModelBuffers.h"
+#include <map>
+
+namespace DirectCompute
+{
+ struct ModelLoader
+ {
+ ModelLoader( int encoderLayers, int decoderLayers );
+
+ void add( const ggml_tensor* ggml, Tensor& gpu );
+
+ void add( const ggml_tensor* w, const ggml_tensor* b, TensorPair& gpu )
+ {
+ add( w, gpu.w );
+ add( b, gpu.b );
+ }
+
+ bool tryLoad( const ggml_tensor* ggml );
+
+ ModelBuffers& model;
+
+ private:
+
+ Tensor* lookup( const ggml_tensor* ggml ) const;
+
+ std::map<const ggml_tensor*, Tensor*> map;
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/Spectrogram.cpp b/Whisper/Whisper/Spectrogram.cpp
new file mode 100644
index 0000000..400f9b1
--- /dev/null
+++ b/Whisper/Whisper/Spectrogram.cpp
@@ -0,0 +1,124 @@
+#include "stdafx.h"
+#include "Spectrogram.h"
+#include <memory>
+#define _USE_MATH_DEFINES
+#include <math.h>
+#include "../Utils/parallelFor.h"
+#include "../API/iMediaFoundation.cl.h"
+#include "../ML/testUtils.h"
+#include "melSpectrogram.h"
+using namespace Whisper;
+
+class alignas( 64 ) Spectrogram::MelContext
+{
+ const float* const samples;
+ const size_t countSamples;
+ Spectrogram& result;
+ const int n_threads;
+ SpectrogramContext context;
+
+public:
+
+ MelContext( const float* rsi, size_t len, const Filters& f, Spectrogram& rdi, int countThreads ) :
+ samples( rsi ), countSamples( len ), result( rdi ), n_threads( countThreads ),
+ context( f )
+ { }
+
+ void run( int ith );
+
+ static HRESULT workCallback( int ith, void* ctx ) noexcept;
+};
+
+void Spectrogram::MelContext::run( int ith )
+{
+ std::array<float, N_MEL> arr;
+ for( uint32_t i = ith; i < result.length; i += n_threads )
+ {
+ const int offset = i * FFT_STEP;
+ const float* rsi = samples + offset;
+ context.fft( arr, rsi, countSamples - offset );
+
+ for( size_t j = 0; j < N_MEL; j++ )
+ result.data[ j * result.length + i ] = arr[ j ];
+ }
+}
+
+HRESULT Spectrogram::MelContext::workCallback( int ith, void* ctx ) noexcept
+{
+ std::vector<Spectrogram::MelContext>& contexts = *( std::vector<Spectrogram::MelContext>* )ctx;
+ try
+ {
+ contexts.at( ith ).run( ith );
+ return S_OK;
+ }
+ catch( const std::bad_alloc& )
+ {
+ return E_OUTOFMEMORY;
+ }
+ catch( const std::exception& )
+ {
+ return E_FAIL;
+ }
+}
+
+HRESULT Spectrogram::pcmToMel( const iAudioBuffer* buffer, const Filters& filters, int threads )
+{
+ if( nullptr == buffer )
+ return E_POINTER;
+ const uint32_t countSamples = buffer->countSamples();
+ if( 0 == countSamples )
+ return OLE_E_BLANK;
+ const float* const samples = buffer->getPcmMono();
+
+ length = ( countSamples ) / FFT_STEP;
+ data.resize( N_MEL * length );
+
+ if( threads < 2 )
+ {
+ MelContext ctx{ samples, countSamples, filters, *this, 1 };
+ ctx.run( 0 );
+ }
+ else
+ {
+ std::vector<MelContext> contexts;
+ contexts.reserve( threads );
+ for( int i = 0; i < threads; i++ )
+ contexts.emplace_back( MelContext{ samples, countSamples, filters, *this, (int)threads } );
+ CHECK( parallelFor( &MelContext::workCallback, threads, &contexts ) );
+ }
+
+ // clamping and normalization
+ double mmax = -1e20;
+ for( double f : data )
+ mmax = std::max( mmax, f );
+ //printf("%s: max = %f\n", __func__, mmax);
+
+ mmax -= 8.0;
+
+ for( float& f : data )
+ {
+ if( f < mmax )
+ f = (float)mmax;
+ f = (float)( ( f + 4.0 ) / 4.0 );
+ }
+ // DirectCompute::dbgWriteBinaryFile( LR"(C:\Temp\2remove\ML\mel-my.bin)", data.data(), data.size() * 4 );
+ return S_OK;
+}
+
+void Whisper::computeSignalEnergy( std::vector<float>& result, const iAudioBuffer* buffer, int n_samples_per_half_window )
+{
+ const size_t countSamples = buffer->countSamples();
+ const float* const samples = buffer->getPcmMono();
+
+ const int hw = n_samples_per_half_window;
+ result.resize( countSamples );
+
+ for( size_t i = 0; i < countSamples; i++ )
+ {
+ float sum = 0;
+ for( int j = -hw; j <= hw; j++ )
+ if( i + j >= 0 && i + j < countSamples )
+ sum += fabsf( samples[ i + j ] );
+ result[ i ] = sum / ( 2 * hw + 1 );
+ }
+} \ No newline at end of file
diff --git a/Whisper/Whisper/Spectrogram.h b/Whisper/Whisper/Spectrogram.h
new file mode 100644
index 0000000..04e2c06
--- /dev/null
+++ b/Whisper/Whisper/Spectrogram.h
@@ -0,0 +1,42 @@
+#pragma once
+#include "WhisperModel.h"
+#include "iSpectrogram.h"
+#include "audioConstants.h"
+
+namespace Whisper
+{
+ struct iAudioBuffer;
+
+ class Spectrogram: public iSpectrogram
+ {
+ uint32_t length = 0;
+ static constexpr uint32_t mel = N_MEL;
+ std::vector<float> data;
+
+ HRESULT makeBuffer( size_t off, size_t len, const float** buffer, size_t& stride ) noexcept override final
+ {
+ if( off + len > length )
+ return E_BOUNDS;
+ *buffer = &data[ off ];
+ stride = length;
+ return S_OK;
+ }
+
+ class MelContext;
+
+ public:
+ size_t getLength() const noexcept override final
+ {
+ return length;
+ }
+ HRESULT pcmToMel( const iAudioBuffer* buffer, const Filters& filters, int threads = 1 );
+
+ size_t memoryUsage() const
+ {
+ return data.size() * 4;
+ }
+ };
+
+ // average the fabs of the signal
+ void computeSignalEnergy( std::vector<float>& result, const iAudioBuffer* buffer, int n_samples_per_half_window );
+} \ No newline at end of file
diff --git a/Whisper/Whisper/TranscribeResult.h b/Whisper/Whisper/TranscribeResult.h
new file mode 100644
index 0000000..8f8e408
--- /dev/null
+++ b/Whisper/Whisper/TranscribeResult.h
@@ -0,0 +1,43 @@
+#pragma once
+#include "../API/iTranscribeResult.cl.h"
+#include "../ComLightLib/comLightServer.h"
+
+namespace Whisper
+{
+ class TranscribeResult : public ComLight::ObjectRoot<iTranscribeResult>
+ {
+ HRESULT COMLIGHTCALL getSize( sTranscribeLength& rdi ) const noexcept override final
+ {
+ rdi.countSegments = (uint32_t)segments.size();
+ rdi.countTokens = (uint32_t)tokens.size();
+ return S_OK;
+ }
+ const sSegment* COMLIGHTCALL getSegments() const noexcept override final
+ {
+ if( !segments.empty() )
+ return segments.data();
+ return nullptr;
+ }
+ const sToken* COMLIGHTCALL getTokens() const noexcept override final
+ {
+ if( !tokens.empty() )
+ return tokens.data();
+ return nullptr;
+ }
+
+ public:
+ std::vector<sSegment> segments;
+ std::vector<sToken> tokens;
+ };
+
+ class TranscribeResultStatic : public ComLight::Object<TranscribeResult>
+ {
+ uint32_t COMLIGHTCALL Release() override final
+ {
+ // When the ref.counter reaches zero, Object.Release() method calls `delete this`.
+ // We don't want that for the aggregated object.
+ // Instead we only decrement the ref.counter, but do not delete the object even when the counter reaches zero.
+ return RefCounter::implRelease();
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/Vocabulary.cpp b/Whisper/Whisper/Vocabulary.cpp
new file mode 100644
index 0000000..2e5eaec
--- /dev/null
+++ b/Whisper/Whisper/Vocabulary.cpp
@@ -0,0 +1,129 @@
+#include "stdafx.h"
+#include "Vocabulary.h"
+#include "loaderUtils.h"
+using ComLight::iReadStream;
+using namespace Whisper;
+
+void Vocabulary::addExtra( int index, const char* format, int i )
+{
+ const int len = std::snprintf( nullptr, 0, format, i );
+ const size_t offset = stringData.size();
+ stringData.resize( offset + len + 1 );
+ char* const rdi = stringData.data() + offset;
+ std::snprintf( rdi, len + 1, format, i );
+ rdi[ len ] = '\0';
+ tokens[ index ] = reinterpret_cast<const char*>( offset );
+}
+
+void Vocabulary::completeBuild()
+{
+ stringData.shrink_to_fit();
+
+ const size_t dataLength = stringData.size();
+ for( auto& s : tokens )
+ {
+ // The reason this hack works - on Windows, lower 2GB of address space is reserved to the kernel.
+ // That's why the strings from the read only section of this DLL like "[_PREV_]" are guaranteed to have their addresses much larger than the size of the data buffer
+ const size_t ri = reinterpret_cast<size_t>( s );
+ if( ri < dataLength )
+ s = stringData.data() + ri;
+ }
+
+ int64_t cb = stringData.size();
+ cb += tokens.size() * sizeof( void* );
+ constexpr double mulKb = 1.0 / ( 1 << 10 );
+ logDebug( u8"Loaded vocabulary, %zu strings, %.1f kb RAM", tokens.size(), mulKb * cb );
+}
+
+HRESULT Vocabulary::load( ComLight::iReadStream* stm, int lengthInHeader )
+{
+ if( lengthInHeader <= 0 )
+ return E_INVALIDARG;
+
+ tokens.clear();
+ stringData.clear();
+
+ int countWords = 0;
+ CHECK( readStruct( stm, countWords ) );
+ if( countWords <= 0 )
+ return E_INVALIDARG;
+
+ const size_t count = (uint32_t)countWords;
+ const size_t actualCount = std::max( count, (size_t)lengthInHeader );
+ tokens.resize( actualCount );
+
+ for( int i = 0; i < count; i++ )
+ {
+ int countChars = 0;
+ CHECK( readStruct( stm, countChars ) );
+ if( countChars < 0 )
+ {
+ logError( u8"Vocabulary.load failed: string length is negative" );
+ return E_INVALIDARG;
+ }
+ if( countChars == 0 )
+ {
+ // This happens with `ggml-large.bin` and `ggml-large-v1.bin` models.
+ // A bug in the model maybe?
+ tokens[ i ] = "";
+ continue;
+ }
+ const size_t len = (size_t)countChars;
+
+ const size_t offset = stringData.size();
+ stringData.resize( offset + len + 1 );
+
+ CHECK( readBytes( stm, &stringData[ offset ], len ) );
+ *stringData.rbegin() = '\0';
+
+ tokens[ i ] = reinterpret_cast<const char*>( offset );
+ }
+
+ n_vocab = lengthInHeader;
+
+ if( is_multilingual() )
+ {
+ token_eot++;
+ token_sot++;
+ token_prev++;
+ token_solm++;
+ token_not++;
+ token_beg++;
+ };
+
+ if( countWords < lengthInHeader )
+ {
+ for( int i = countWords; i < lengthInHeader; i++ )
+ {
+ if( i > token_beg )
+ addExtra( i, "[_TT_%i]", i - token_beg );
+ else if( i == token_eot )
+ tokens[ i ] = "[_EOT_]";
+ else if( i == token_sot )
+ tokens[ i ] = "[_SOT_]";
+ else if( i == token_prev )
+ tokens[ i ] = "[_PREV_]";
+ else if( i == token_not )
+ tokens[ i ] = "[_NOT_]";
+ else if( i == token_beg )
+ tokens[ i ] = "[_BEG_]";
+ else
+ addExtra( i, "[_extra_token_%i]", i );
+ }
+ }
+
+ completeBuild();
+ return S_OK;
+}
+
+void Vocabulary::getSpecialTokens( SpecialTokens& rdi ) const
+{
+ rdi.TranscriptionEnd = token_eot;
+ rdi.TranscriptionStart = token_sot;
+ rdi.PreviousWord = token_prev;
+ rdi.SentenceStart = token_solm;
+ rdi.Not = token_not;
+ rdi.TranscriptionBegin = token_beg;
+ rdi.TaskTranslate = token_translate;
+ rdi.TaskTranscribe = token_transcribe;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/Vocabulary.h b/Whisper/Whisper/Vocabulary.h
new file mode 100644
index 0000000..6250494
--- /dev/null
+++ b/Whisper/Whisper/Vocabulary.h
@@ -0,0 +1,58 @@
+#pragma once
+#include "../../ComLightLib/streams.h"
+#include "../API/SpecialTokens.h"
+
+namespace Whisper
+{
+ class Vocabulary
+ {
+ std::vector<const char*> tokens;
+ std::vector<char> stringData;
+
+ void addExtra( int index, const char* format, int i );
+
+ void completeBuild();
+ public:
+
+ int n_vocab = 51864;
+
+ HRESULT load( ComLight::iReadStream* stm, int lengthInHeader );
+
+ using id = int;
+
+ id token_eot = 50256;
+ id token_sot = 50257;
+ id token_prev = 50360;
+ id token_solm = 50361; // ??
+ id token_not = 50362; // no timestamps
+ id token_beg = 50363;
+
+ // available tasks
+ static const id token_translate = 50358;
+ static const id token_transcribe = 50359;
+
+ bool is_multilingual() const
+ {
+ return n_vocab == 51865;
+ }
+
+ const char* string( int id ) const
+ {
+ if( id >= 0 && id < (int)tokens.size() )
+ return tokens[ id ];
+ return nullptr;
+ }
+
+ size_t size() const
+ {
+ return tokens.size();
+ }
+
+ void getSpecialTokens( SpecialTokens& rdi ) const;
+
+ size_t getMemoryUse() const
+ {
+ return vectorMemoryUse( tokens ) + vectorMemoryUse( stringData );
+ }
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/WhisperContext.cpp b/Whisper/Whisper/WhisperContext.cpp
new file mode 100644
index 0000000..c983039
--- /dev/null
+++ b/Whisper/Whisper/WhisperContext.cpp
@@ -0,0 +1,673 @@
+#include "stdafx.h"
+#include "WhisperContext.h"
+#include "ModelBuffers.h"
+#include <optional>
+#include "../Utils/Trace/tracing.h"
+#include "../D3D/RenderDoc/renderDoc.h"
+#include "../ML/testUtils.h"
+using namespace DirectCompute;
+
+namespace
+{
+ // True to measure GPU time of individual shaders which run during the encode step of the algorithm
+ constexpr bool profileEncodeShaders = true;
+ // True to measure GPU time of individual shaders which run during the decode step of the algorithm
+ constexpr bool profileDecodeShaders = true;
+
+ LPCTSTR traceFileNative = LR"(C:\Temp\2remove\Whisper\gpu.bin)";
+ LPCTSTR traceFileHybrid = LR"(C:\Temp\2remove\Whisper\hybrid.bin)";
+
+ TensorsArena::sArenaConfigs defaultArenaConfigs()
+ {
+ TensorsArena::sArenaConfigs res = {};
+ return res;
+ }
+}
+
+WhisperContext::Arenas::Arenas() :
+ enc( defaultArenaConfigs() ), encLayer( defaultArenaConfigs() ), dec( defaultArenaConfigs() ), decLayer( defaultArenaConfigs() )
+{ }
+
+Tensor WhisperContext::DecoderLayerPool::tensor( eDataType type, const std::array<uint32_t, 4>& ne )
+{
+ assert( type == eDataType::FP32 );
+ return result.tensor( eDataType::FP32, ne, &DirectCompute::defaultNewCapacity );
+}
+
+class WhisperContext::ArenaRaii
+{
+ WhisperContext& context;
+ iTensorArena* prevCurrent;
+
+public:
+
+ ArenaRaii( WhisperContext& ctx, iTensorArena& ta ) :
+ context( ctx )
+ {
+ prevCurrent = ctx.currentArena;
+ ctx.currentArena = &ta;
+ }
+
+ ~ArenaRaii()
+ {
+ context.currentArena->reset();
+ context.currentArena = prevCurrent;
+ }
+};
+
+WhisperContext::WhisperContext( const Whisper::WhisperModel& wm, Whisper::ProfileCollection& pc ) :
+ MlContext( pc ),
+ gpuModel( wm.tensors )
+{
+#if BUILD_HYBRID_VERSION
+ if( !wm.hybridTensors.layers.empty() )
+ {
+ hybridContext = std::make_unique<HybridContext>( wm );
+ check( hybridContext->create() );
+#if SAVE_DEBUG_TRACE
+ Tracing::traceCreate( traceFileHybrid );
+#endif
+ }
+ else
+#endif
+ {
+#if SAVE_DEBUG_TRACE
+ Tracing::traceCreate( traceFileNative );
+#endif
+ }
+}
+
+#if BUILD_BOTH_VERSIONS
+namespace
+{
+ thread_local WhisperContext* ts_context = nullptr;
+ const ModelBuffers& getGlobalModel()
+ {
+ return gpuModel;
+ }
+}
+
+/*
+WhisperContext::WhisperContext() :
+ gpuModel( getGlobalModel() ),
+{
+ if( nullptr != ts_context )
+ throw HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );
+ ts_context = this;
+}*/
+
+WhisperContext::~WhisperContext()
+{
+ Tracing::traceClose();
+ if( ts_context == nullptr )
+ return;
+ assert( ts_context == this );
+ ts_context = nullptr;
+}
+
+WhisperContext& WhisperContext::current()
+{
+ WhisperContext* c = ts_context;
+ if( nullptr == c )
+ throw OLE_E_BLANK;
+ return *c;
+}
+#else
+WhisperContext& WhisperContext::current()
+{
+ throw E_NOTIMPL;
+}
+#endif
+
+Tensor WhisperContext::createTensor( eDataType type, const std::array<uint32_t, 4>& ne )
+{
+ // return MlContext::createTensor( type, ne );
+
+ iTensorArena* const ca = currentArena;
+ if( nullptr != ca )
+ return ca->tensor( type, ne );
+ else
+ return MlContext::createTensor( type, ne );
+}
+
+void WhisperContext::fmaRepeat( Tensor& cur, const TensorPair& that )
+{
+ MlContext::fmaRepeat( cur, that.w, that.b );
+}
+
+Tensor WhisperContext::convolutionAndGelu( const Tensor& mel, uint32_t n_ctx )
+{
+ const EncoderBuffers& model = gpuModel.enc;
+ Tensor cur = conv_1d_1s( model.conv1.w, mel );
+ Tracing::tensor( "enc.conv1", cur );
+ addRepeatGelu( cur, model.conv1.b );
+ Tracing::tensor( "enc.temp1", cur );
+
+ cur = conv_1d_2s( model.conv2.w, cur );
+ addRepeatGelu( cur, model.conv2.b );
+
+ const Tensor& posEmbed = model.positionalEmbedding;
+ const uint32_t peStride = posEmbed.ne[ 0 ];
+ constexpr uint32_t peOffset = 0;
+
+ Tensor e_pe = view2d( posEmbed, posEmbed.ne[ 0 ], n_ctx, peStride, peOffset );
+ cur = add( e_pe, transpose( cur ) );
+ return cur;
+}
+
+Tensor WhisperContext::encodeLayer( const Tensor& source, size_t index, uint32_t n_state, uint32_t n_head, uint32_t n_ctx )
+{
+ auto prof = profiler.block( eProfilerBlock::EncodeLayer );
+ ArenaRaii arenaRaii{ *this, arenas.encLayer };
+
+ const LayerEncoder& layer = gpuModel.enc.layers[ index ];
+ // norm
+ Tensor cur = norm( source );
+ if( 0 == index )
+ Tracing::tensor( "enc-norm", cur );
+ fmaRepeat( cur, layer.attnLn0 );
+
+ // self-attention
+ Tensor Qcur;
+ Tensor reshaped;
+ if( gpuInfo.useReshapedMatMul() )
+ {
+ const uint16_t tag = profiler.setNextTag( "enc.layer.1" );
+ reshaped = reshapePanels( cur );
+ profiler.setNextTag( tag );
+ Qcur = mulMatTiledEx( layer.attnQuery.w, reshaped );
+ }
+ else
+ {
+ profiler.setNextTag( "enc.layer.1" );
+ Qcur = mulMat( layer.attnQuery.w, cur );
+ }
+
+ if( 0 == index )
+ Tracing::tensor( "enc-Qcur", Qcur );
+ addRepeat( Qcur, layer.attnQuery.b );
+
+ // note: no bias for Key
+ Tensor Kcur;
+ if( gpuInfo.useReshapedMatMul() )
+ {
+ // Already reshaped by the previous `if`
+ profiler.setNextTag( "enc.layer.2" );
+ Kcur = mulMatTiledEx( layer.attnKey, reshaped );
+ }
+ else
+ {
+ profiler.setNextTag( "enc.layer.2" );
+ Kcur = mulMat( layer.attnKey, cur );
+ }
+
+ if( 0 == index )
+ Tracing::tensor( "enc-Kcur", Kcur );
+
+ Tensor Vcur;
+ if( gpuInfo.useReshapedMatMul() )
+ {
+ // Already reshaped by the previous `if`
+ profiler.setNextTag( "enc.layer.3" );
+ Vcur = mulMatTiledEx( layer.attnValue.w, reshaped );
+ }
+ else
+ {
+ profiler.setNextTag( "enc.layer.3" );
+ Vcur = mulMat( layer.attnValue.w, cur );
+ }
+
+ if( 0 == index )
+ Tracing::tensor( "enc-Vcur", Vcur );
+ addRepeat( Vcur, layer.attnValue.b );
+
+ // ------
+ Tensor Q = permute( copy( Qcur, eDataType::FP16, { n_state / n_head, n_head, n_ctx } ), 0, 2, 1, 3 );
+ Tensor K = permute( copy( Kcur, eDataType::FP16, { n_state / n_head, n_head, n_ctx } ), 0, 2, 1, 3 );
+ Tensor V = copy( permute( Vcur.reshape3d( n_state / n_head, n_head, n_ctx ), 1, 2, 0, 3 ), eDataType::FP16, { n_ctx, n_state / n_head, n_head } );
+ Tensor KQV = flashAttention( Q, K, V, false );
+ if( 0 == index )
+ Tracing::tensor( "enc-KQV", KQV );
+ Tensor KQV_merged = permute( KQV, 0, 2, 1, 3 );
+ copyInPlace( cur, KQV_merged, eDataType::FP32, { n_state, n_ctx } );
+
+ // projection
+ if( gpuInfo.useReshapedMatMul() )
+ {
+ const uint16_t tag = profiler.setNextTag( "enc.layer.4" );
+ cur = reshapePanels( cur );
+ profiler.setNextTag( tag );
+ cur = mulMatTiledEx( layer.attnLn1.w, cur );
+ }
+ else
+ {
+ profiler.setNextTag( "enc.layer.4" );
+ cur = mulMat( layer.attnLn1.w, cur );
+ }
+ addRepeat( cur, layer.attnLn1.b );
+
+ // add the input
+ addInPlace( cur, source );
+
+ // feed-forward network
+ Tensor inpFF = cur;
+
+ cur = norm( inpFF );
+ fmaRepeat( cur, layer.mlpLn );
+
+ // fully connected
+ if( gpuInfo.useReshapedMatMul() )
+ {
+ const uint16_t tag = profiler.setNextTag( "enc.layer.5" );
+ cur = reshapePanels( cur );
+ profiler.setNextTag( tag );
+ cur = mulMatTiledEx( layer.mlp0.w, cur );
+ }
+ else
+ {
+ profiler.setNextTag( "enc.layer.5" );
+ cur = mulMat( layer.mlp0.w, cur );
+ }
+ addRepeatGelu( cur, layer.mlp0.b );
+
+ // projection
+ if( gpuInfo.useReshapedMatMul() )
+ {
+ const uint16_t tag = profiler.setNextTag( "enc.layer.6" );
+ cur = reshapePanels( cur );
+ profiler.setNextTag( tag );
+ cur = mulMatTiledEx( layer.mlp1.w, cur );
+ }
+ else
+ {
+ profiler.setNextTag( "enc.layer.6" );
+ cur = mulMat( layer.mlp1.w, cur );
+ }
+
+ addRepeat( cur, layer.mlp1.b );
+
+ // output from this layer
+ addInPlace( cur, inpFF );
+ return cur;
+}
+
+void WhisperContext::createKeyValueBuffers( const sEncodeParams& encParams )
+{
+ {
+ const uint32_t n_audio_ctx = encParams.n_audio_ctx;
+ const uint32_t n_mem = encParams.n_text_layer * encParams.n_audio_ctx;
+ const uint32_t n_elements = encParams.n_text_state * n_mem;
+ kvCross.resize( n_elements );
+ }
+
+#if BUILD_HYBRID_VERSION
+ if( !hybridContext )
+#endif
+ {
+ const uint32_t n_mem = encParams.n_text_layer * encParams.n_text_ctx;
+ const uint32_t n_elements = encParams.n_text_state * n_mem;
+ kv.resize( n_elements );
+ }
+}
+
+Tensor WhisperContext::encode( Whisper::iSpectrogram& spectrogram, const sEncodeParams& encParams )
+{
+ auto prof = profiler.block( eProfilerBlock::Encode );
+ CaptureRaii renderdocCapture;
+ profiler.profileShaders = profileEncodeShaders;
+
+ createKeyValueBuffers( encParams );
+ // Upload the source
+ check( melInput.create( spectrogram, encParams ) );
+ Tracing::tensor( "enc.input", melInput );
+
+ arenas.enc.clear();
+ ArenaRaii arenaRaii{ *this, arenas.enc };
+
+ // Initial few steps
+ Tensor cur = convolutionAndGelu( melInput, encParams.n_ctx );
+
+ // Process all these layers
+ {
+ const size_t layersCount = encParams.layersCount;
+ for( size_t i = 0; i < layersCount; i++ )
+ {
+ Tracing::tensor( { "enc.layer[ %i ].in", i }, cur );
+ cur = encodeLayer( cur, i, encParams.n_state, encParams.n_head, encParams.n_ctx );
+ }
+ }
+ Tracing::tensor( "enc.layers", cur );
+
+ // A few last steps
+ {
+ cur = norm( cur );
+ // cur = ln_f_g*cur + ln_f_b
+ fmaRepeat( cur, gpuModel.enc.lnPost );
+ }
+
+ // pre-compute cross-attention buffers
+ {
+ Tensor reshaped;
+ if( gpuInfo.useReshapedMatMul() )
+ {
+ if( cur.ne[ 1 ] != 1 )
+ {
+ profiler.setNextTag( "enc.cross" );
+ reshaped = reshapePanels( cur );
+ }
+ else
+ reshaped = cur;
+ }
+
+ const size_t layersCount = encParams.n_text_layer;
+ const uint32_t stride = encParams.n_state * encParams.n_ctx;
+ const float finalScaling = (float)pow( float( encParams.n_state ) / encParams.n_head, -0.25 );
+ for( size_t i = 0; i < layersCount; i++ )
+ {
+ const LayerDecoder& layer = gpuModel.dec.layers[ i ];
+ Tensor Kcross, Vcross;
+ if( gpuInfo.useReshapedMatMul() )
+ Kcross = mulMatEx( layer.crossAttnKey, reshaped, "enc.cross.1" );
+ else
+ {
+ profiler.setNextTag( "enc.cross.1" );
+ Kcross = mulMat( layer.crossAttnKey, cur );
+ }
+ scale( Kcross, finalScaling );
+
+ if( gpuInfo.useReshapedMatMul() )
+ Vcross = mulMatEx( layer.crossAttnValue.w, reshaped, "enc.cross.2" );
+ else
+ {
+ profiler.setNextTag( "enc.cross.2" );
+ Vcross = mulMat( layer.crossAttnValue.w, cur );
+ }
+ addRepeat( Vcross, layer.crossAttnValue.b );
+
+ Tensor k = kvCross.keys.view( stride, stride * (uint32_t)i );
+ copyImpl( Kcross, k, Kcross.getType() == eDataType::FP32 );
+
+ Tensor v = kvCross.values.view( stride, stride * (uint32_t)i );
+ copyImpl( Vcross, v, Vcross.getType() == eDataType::FP32 );
+ }
+ }
+
+#if BUILD_HYBRID_VERSION
+ if( hybridContext )
+ {
+ // When running hybrid model, download cross-attention buffers from VRAM to system RAM
+ check( hybridContext->downloadKeyValues( kvCross ) );
+ }
+#endif
+ return cur;
+}
+
+struct WhisperContext::sLayerDecParams
+{
+ uint32_t n_state, n_head, N;
+ uint32_t n_ctx, n_past, M;
+};
+
+Tensor WhisperContext::decodeLayer( const Tensor& inpL, size_t il, const sLayerDecParams& ldp )
+{
+ auto prof = profiler.block( eProfilerBlock::DecodeLayer );
+ const auto& layer = gpuModel.dec.layers[ il ];
+ std::optional<ArenaRaii> arenaRaii{ std::in_place, *this, arenas.decLayer };
+ if( 0 == il ) Tracing::tensor( "dec-inpL", inpL );
+
+ // norm
+ Tensor cur = norm( inpL );
+ fmaRepeat( cur, layer.attnLn0 );
+ if( 0 == il ) Tracing::tensor( "dec-norm", cur );
+
+ // self-attention
+ {
+ profiler.setNextTag( "dec.layer.1" );
+ Tensor Qcur = mulMat( layer.attnQuery.w, cur );
+ if( 0 == il ) Tracing::tensor( "dec-Qcur-0", Qcur );
+ const float scaling = (float)pow( float( (int)ldp.n_state ) / (int)ldp.n_head, -0.25 );
+ addRepeatScale( Qcur, layer.attnQuery.b, scaling );
+ if( 0 == il ) Tracing::tensor( "dec-Qcur-1", Qcur );
+
+ // note: no bias for Key
+ profiler.setNextTag( "dec.layer.2" );
+ Tensor Kcur = mulMat( layer.attnKey, cur );
+ scale( Kcur, scaling );
+ if( 0 == il ) Tracing::tensor( "dec-Kcur", Kcur );
+
+ profiler.setNextTag( "dec.layer.3" );
+ Tensor Vcur = mulMat( layer.attnValue.w, cur );
+ addRepeat( Vcur, layer.attnValue.b );
+ if( 0 == il ) Tracing::tensor( "dec-Vcur", Vcur );
+
+ // store key and value to memory
+ {
+ const uint32_t len = ldp.N * ldp.n_state;
+ const uint32_t off = ldp.n_state * ( (uint32_t)il * ldp.n_ctx + ldp.n_past );
+ Tensor k = kv.keys.view( len, off );
+ Tensor v = kv.values.view( len, off );
+ copyImpl( Kcur, k, true );
+ copyImpl( Vcur, v, true );
+ }
+
+ // ------
+ Tensor Q = permute( copy( Qcur, eDataType::FP32, { ldp.n_state / ldp.n_head, ldp.n_head, ldp.N } ), 0, 2, 1, 3 );
+ Tensor K = permute( kv.keys.view( ( ldp.n_past + ldp.N ) * ldp.n_state, (uint32_t)il * ldp.n_ctx * ldp.n_state )
+ .reshape3d( ldp.n_state / ldp.n_head, ldp.n_head, ldp.n_past + ldp.N ),
+ 0, 2, 1, 3 );
+ profiler.setNextTag( "dec.layer.4" );
+ Tensor KQ = mulMat( K, Q );
+ if( 0 == il ) Tracing::tensor( "dec-KQ-0", KQ );
+ diagMaskInf( KQ, ldp.n_past );
+ if( 0 == il ) Tracing::tensor( "dec-KQ-1", KQ );
+ softMax( KQ );
+ if( 0 == il ) Tracing::tensor( "dec-KQ-2", KQ );
+
+ Tensor V_trans = permute(
+ kv.values
+ .view( ( ldp.n_past + ldp.N ) * ldp.n_state, (uint32_t)il * ldp.n_ctx * ldp.n_state )
+ .reshape3d( ldp.n_state / ldp.n_head, ldp.n_head, ldp.n_past + ldp.N ),
+ 1, 2, 0, 3 );
+
+ profiler.setNextTag( "dec.layer.5" );
+ Tensor KQV = mulMat( V_trans, KQ );
+ if( 0 == il ) Tracing::tensor( "dec-KQV", KQV );
+
+ Tensor KQV_merged = permute( KQV, 0, 2, 1, 3 );
+ copyInPlace( cur, KQV_merged, eDataType::FP32, { ldp.n_state, ldp.N } );
+ }
+
+ {
+ profiler.setNextTag( "dec.layer.6" );
+ cur = mulMat( layer.attnLn1.w, cur );
+ addRepeat( cur, layer.attnLn1.b );
+ }
+
+ // add the input
+ Tensor inpCA = add( cur, inpL );
+
+ // norm
+ {
+ cur = norm( inpCA );
+ fmaRepeat( cur, layer.crossAttnLn0 );
+ }
+
+ // cross-attention
+ {
+ profiler.setNextTag( "dec.layer.7" );
+ Tensor Qcur = mulMat( layer.crossAttnQuery.w, cur );
+ addRepeatScale( Qcur, layer.crossAttnQuery.b, (float)pow( float( (int)ldp.n_state ) / (int)ldp.n_head, -0.25 ) );
+
+ // Kcross is already scaled
+ const uint32_t len = ldp.M * ldp.n_state;
+ const uint32_t off = (uint32_t)il * len;
+ Tensor Kcross = kvCross.keys.view( len, off ).reshape3d( ldp.n_state / ldp.n_head, ldp.n_head, ldp.M );
+ Tensor Vcross = kvCross.values.view( len, off ).reshape3d( ldp.n_state / ldp.n_head, ldp.n_head, ldp.M );
+
+ // ------
+ Tensor Q = permute( copy( Qcur, eDataType::FP32, { ldp.n_state / ldp.n_head, ldp.n_head, ldp.N } ), 0, 2, 1, 3 );
+ Tensor K = permute( Kcross, 0, 2, 1, 3 );
+ profiler.setNextTag( "dec.layer.8" );
+ Tensor KQ = mulMat( K, Q );
+ softMax( KQ );
+ Tensor V_trans = permute( Vcross, 1, 2, 0, 3 );
+ profiler.setNextTag( "dec.layer.9" );
+ Tensor KQV = mulMat( V_trans, KQ );
+ if( 0 == il ) Tracing::tensor( "dec-KQV", KQV );
+ Tensor KQV_merged = permute( KQV, 0, 2, 1, 3 );
+
+ copyInPlace( cur, KQV_merged, eDataType::FP32, { ldp.n_state, ldp.N } );
+ }
+
+ // projection
+ {
+ profiler.setNextTag( "dec.layer.10" );
+ cur = mulMat( layer.crossAttnLn1.w, cur );
+ addRepeat( cur, layer.crossAttnLn1.b );
+ }
+ // add the input
+ addInPlace( cur, inpCA );
+ Tensor inpFF = cur;
+
+ // feed-forward network
+ {
+ // norm
+ cur = norm( inpFF );
+ fmaRepeat( cur, layer.mlpLn );
+
+ if( gpuInfo.useReshapedMatMul() )
+ cur = mulMatEx( layer.mlp0.w, cur, "dec.layer.11" );
+ else
+ {
+ profiler.setNextTag( "dec.layer.11" );
+ cur = mulMat( layer.mlp0.w, cur );
+ }
+
+ addRepeatGelu( cur, layer.mlp0.b );
+
+ // projection
+ if( gpuInfo.useReshapedMatMul() )
+ {
+ if( cur.ne[ 1 ] != 1 )
+ {
+ const uint16_t tag = profiler.setNextTag( "dec.layer.12" );
+ cur = reshapePanels( cur );
+
+ // The mulMatTiledEx() line creates a layer output tensor. We have a special pool for such tensors so they survive the destruction of the arena.
+ arenaRaii.emplace( *this, decPool );
+ profiler.setNextTag( tag );
+ cur = mulMatTiledEx( layer.mlp1.w, cur );
+ }
+ else
+ {
+ // The mulMatByRowTiledEx() line creates a layer output tensor. We have a special pool for such tensors so they survive the destruction of the arena.
+ arenaRaii.emplace( *this, decPool );
+ profiler.setNextTag( "dec.layer.12" );
+ cur = mulMatByRowTiledEx( layer.mlp1.w, cur );
+ }
+ }
+ else
+ {
+ // The mulMat() line creates a layer output tensor. We have a special pool for such tensors so they survive the destruction of the arena.
+ arenaRaii.emplace( *this, decPool );
+ profiler.setNextTag( "dec.layer.12" );
+ cur = mulMat( layer.mlp1.w, cur );
+ }
+ addRepeat( cur, layer.mlp1.b );
+ }
+
+ // output from this layer
+ addInPlace( cur, inpFF );
+ return cur;
+}
+
+void WhisperContext::decode( const int* tokens, const int n_tokens, const sDecodeParams& decParams, std::vector<float>& probs, int threads )
+{
+ auto cppp = profiler.cpuBlock( Whisper::eCpuBlock::DecodeStep );
+
+#if BUILD_HYBRID_VERSION
+ if( hybridContext )
+ {
+ HybridContext::sDecParams sdp;
+ sdp.n_threads = threads;
+ sdp.M = decParams.M;
+ check( hybridContext->decode( tokens, n_tokens, decParams.n_past, sdp, probs ) );
+ return;
+ }
+#endif
+
+ auto prof = profiler.block( eProfilerBlock::DecodeStep );
+ CaptureRaii renderdocCapture;
+ profiler.profileShaders = profileDecodeShaders;
+ ArenaRaii arenaRaii{ *this, arenas.dec };
+
+ assert( n_tokens > 0 );
+ const uint32_t N = (uint32_t)n_tokens;
+ decoderInput.resize( N );
+
+ Tensor embd = decoderInput.embedding( tokens );
+ Tensor cur = addRows( gpuModel.dec.tokenEmbedding, gpuModel.dec.positionalEmbedding, embd, decParams.n_past );
+ Tracing::tensor( "dec-rows", cur );
+
+ {
+ sLayerDecParams ldp;
+ ldp.n_state = decParams.n_state;
+ ldp.n_head = decParams.n_head;
+ ldp.N = N;
+ ldp.n_ctx = decParams.n_ctx;
+ ldp.n_past = decParams.n_past;
+ ldp.M = decParams.M;
+#if 1
+ for( size_t i = 0; i < decParams.n_text_layer; i++ )
+ cur = decodeLayer( cur, i, ldp );
+#else
+ dbgDecodeTest = decodeLayer( cur, 0, ldp );
+ return;
+#endif
+ }
+
+ // norm
+ cur = norm( cur );
+ fmaRepeat( cur, gpuModel.dec.ln );
+
+ profiler.setNextTag( "dec.logits" );
+ cur = mulMat( gpuModel.dec.tokenEmbedding, cur );
+
+ // logits -> probs
+ softMax( cur );
+
+ decoderOutput.copyFromVram( cur );
+ assert( decoderOutput.size() == N * decParams.n_vocab );
+
+ decoderOutput.copyToVector( probs );
+ Tracing::vector( "probs", probs );
+}
+
+__m128i WhisperContext::Arenas::getMemoryUse() const
+{
+ __m128i res = enc.getMemoryUse();
+ res = _mm_add_epi64( res, encLayer.getMemoryUse() );
+ res = _mm_add_epi64( res, dec.getMemoryUse() );
+ res = _mm_add_epi64( res, decLayer.getMemoryUse() );
+ return res;
+}
+
+__m128i WhisperContext::DecoderLayerPool::getMemoryUse() const
+{
+ size_t cb = result.getCapacity() * 4;
+ __m128i res = _mm_setzero_si128();
+ return _mm_insert_epi64( res, (int64_t)cb, 1 );
+}
+
+__m128i WhisperContext::getMemoryUse() const
+{
+ __m128i res = MlContext::getMemoryUse();
+ res = _mm_add_epi64( res, arenas.getMemoryUse() );
+ res = _mm_add_epi64( res, decPool.getMemoryUse() );
+ res = _mm_add_epi64( res, melInput.getMemoryUse() );
+ res = _mm_add_epi64( res, kv.getMemoryUse() );
+ res = _mm_add_epi64( res, kvCross.getMemoryUse() );
+ res = _mm_add_epi64( res, decoderInput.getMemoryUse() );
+ res = _mm_add_epi64( res, decoderOutput.getMemoryUse() );
+ return res;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/WhisperContext.h b/Whisper/Whisper/WhisperContext.h
new file mode 100644
index 0000000..227e6b6
--- /dev/null
+++ b/Whisper/Whisper/WhisperContext.h
@@ -0,0 +1,126 @@
+#pragma once
+#include "../ML/MlContext.h"
+#include "MelInputTensor.h"
+#include "KeyValueBuffers.h"
+#include "sEncodeParams.h"
+#include "DecoderInputBuffers.h"
+#include "DecoderResultBuffer.h"
+#include "../ML/TensorsArena.h"
+#include "iSpectrogram.h"
+#include "../Hybrid/HybridContext.h"
+#include <memory>
+#include "WhisperModel.h"
+#include <tuple>
+#include <optional>
+
+namespace DirectCompute
+{
+ struct TensorPair;
+ struct ModelBuffers;
+
+ class WhisperContext : public MlContext
+ {
+ struct Arenas
+ {
+ TensorsArena enc;
+ TensorsArena encLayer;
+ TensorsArena dec;
+ TensorsArena decLayer;
+
+ Arenas();
+ __m128i getMemoryUse() const;
+ };
+
+ iTensorArena* currentArena = nullptr;
+ Arenas arenas;
+
+ // Specialized tensor arena for decoder layer outputs, with just a single tensor
+ class DecoderLayerPool : public iTensorArena
+ {
+ PooledTensor result;
+ public:
+ Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne ) override final;
+ void reset() override final { }
+ __m128i getMemoryUse() const;
+ };
+
+ DecoderLayerPool decPool;
+
+ class ArenaRaii;
+
+ MelInputTensor melInput;
+ KeyValueBuffers kv, kvCross;
+ DecoderInputBuffers decoderInput;
+ DecoderResultBuffer decoderOutput;
+ const ModelBuffers& gpuModel;
+#if BUILD_HYBRID_VERSION
+ std::unique_ptr<HybridContext> hybridContext;
+#endif
+ struct sWhisperMel
+ {
+ uint32_t n_len;
+ uint32_t n_mel;
+ const std::vector<float>& data;
+ };
+
+ void createKeyValueBuffers( const sEncodeParams& encParams );
+ // Encoder methods
+ Tensor convolutionAndGelu( const Tensor& mel, uint32_t n_ctx );
+ Tensor encodeLayer( const Tensor& source, size_t index, uint32_t n_state, uint32_t n_head, uint32_t n_ctx );
+
+ struct sLayerDecParams;
+
+ // Decoder methods
+ Tensor decodeLayer( const Tensor& source, size_t index, const sLayerDecParams& ldp );
+
+ // cur = add( mul( repeat( that.w, cur ), cur ), repeat( that.b, cur ) );
+ void fmaRepeat( Tensor& cur, const TensorPair& that );
+
+ Tensor createTensor( eDataType type, const std::array<uint32_t, 4>& ne ) override final;
+
+ public:
+#if BUILD_BOTH_VERSIONS
+ WhisperContext();
+ ~WhisperContext();
+#else
+ ~WhisperContext() = default;
+#endif
+ WhisperContext( const Whisper::WhisperModel& wm, Whisper::ProfileCollection& pc );
+ WhisperContext( const WhisperContext& ) = delete;
+
+ Tensor encode( Whisper::iSpectrogram& spectrogram, const sEncodeParams& encParams );
+
+ void decode( const int* tokens, const int n_tokens, const sDecodeParams& decParams, std::vector<float>& probs, int threads );
+
+ static WhisperContext& current();
+
+ // Create a RAII object which measures both CPU and GPU time for the complete runFull() method
+ decltype( auto ) completeProfiler()
+ {
+ return std::make_tuple(
+ profiler.cpuBlock( Whisper::eCpuBlock::Run ),
+ profiler.block( eProfilerBlock::Run ) );
+ }
+
+ // Create a RAII object which measures CPU and optionally GPU time for the loop which calls decode() method
+ decltype( auto ) decodeProfiler()
+ {
+#if BUILD_HYBRID_VERSION
+ if( hybridContext )
+ return std::make_tuple(
+ profiler.cpuBlock( Whisper::eCpuBlock::Decode ),
+ std::optional<GpuProfiler::BlockRaii>{} );
+ else
+ return std::make_tuple(
+ profiler.cpuBlock( Whisper::eCpuBlock::Decode ),
+ std::optional<GpuProfiler::BlockRaii>{ std::in_place, profiler.block( eProfilerBlock::Decode ) } );
+#else
+ return std::make_tuple(
+ profiler.cpuBlock( Whisper::eCpuBlock::Decode ),
+ profiler.block( eProfilerBlock::Decode ) );
+#endif
+ }
+
+ __m128i getMemoryUse() const;
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/WhisperModel.cpp b/Whisper/Whisper/WhisperModel.cpp
new file mode 100644
index 0000000..28e4540
--- /dev/null
+++ b/Whisper/Whisper/WhisperModel.cpp
@@ -0,0 +1,511 @@
+#include "stdafx.h"
+#include "WhisperModel.h"
+#include "loaderUtils.h"
+#include "../D3D/createBuffer.h"
+#include <atlcoll.h>
+#include <atlstr.h>
+#include "../Utils/GpuProfilerSimple.h"
+#include "../Utils/CpuProfiler.h"
+#include "../CPU/HybridLoader.h"
+#include "../ML/Reshaper.h"
+using namespace Whisper;
+using namespace DirectCompute;
+
+namespace
+{
+ struct ParamsAndMelHeader
+ {
+ sModelParams mp;
+ uint32_t n_mel = 0, n_fft = 0;
+ };
+
+ enum struct ePostProcessing : uint8_t
+ {
+ None = 0,
+ MakePanels = 1
+ };
+
+ struct PendingTensor
+ {
+ DirectCompute::Tensor* dest = nullptr;
+ ePostProcessing postProcessing = ePostProcessing::None;
+
+ PendingTensor() = default;
+ PendingTensor( const PendingTensor& ) = default;
+ PendingTensor( DirectCompute::Tensor& tensor, ePostProcessing pp = ePostProcessing::None ) :
+ dest( &tensor ), postProcessing( pp ) { }
+
+#if RESHAPED_MATRIX_MULTIPLY
+ // If you wonder why not reshape them after all tensors are loaded, doing that on the fly is faster because CPU and GPU work in parallel
+ // In the current version, CPU reads data for a next tensor, while in the meantime GPU reshapes a previously loaded tensor.
+ HRESULT postProcess( Reshaper& rs, eDataType dt )
+ {
+ switch( postProcessing )
+ {
+ case ePostProcessing::None:
+ return S_OK;
+ case ePostProcessing::MakePanels:
+ if( gpuInfo.useReshapedMatMul() )
+ {
+ // GpuInfo structure says we should use that new method
+ return rs.makePanels( *dest, dt );
+ }
+ else
+ {
+ // The feature ain't enabled on the current user's GPU
+ return S_OK;
+ }
+ default:
+ return E_UNEXPECTED;
+ }
+ }
+#endif
+ };
+
+ void populateEncodeTensorsMap( CAtlMap<CStringA, PendingTensor>& map, int layersEnc, DirectCompute::ModelBuffers& tensors )
+ {
+ tensors.enc.layers.resize( layersEnc );
+
+ CStringA tempString;
+ // Encoder tensors
+ auto& enc = tensors.enc;
+
+ map[ "encoder.positional_embedding" ] = enc.positionalEmbedding;
+ map[ "encoder.conv1.weight" ] = enc.conv1.w;
+ map[ "encoder.conv1.bias" ] = enc.conv1.b;
+
+ map[ "encoder.conv2.weight" ] = enc.conv2.w;
+ map[ "encoder.conv2.bias" ] = enc.conv2.b;
+
+ map[ "encoder.ln_post.weight" ] = enc.lnPost.w;
+ map[ "encoder.ln_post.bias" ] = enc.lnPost.b;
+
+ auto add = [ & ]( const char* name, int i, DirectCompute::Tensor& t, ePostProcessing pp = ePostProcessing::None )
+ {
+ tempString.Format( "encoder.blocks.%i.%s", i, name );
+ map[ tempString ] = PendingTensor{ t, pp };
+ };
+ auto add2 = [ & ]( const char* name, int i, DirectCompute::TensorPair& t, ePostProcessing ppWeight = ePostProcessing::None, ePostProcessing ppBias = ePostProcessing::None )
+ {
+ tempString.Format( "encoder.blocks.%i.%s.weight", i, name );
+ map[ tempString ] = PendingTensor{ t.w, ppWeight };
+ tempString.Format( "encoder.blocks.%i.%s.bias", i, name );
+ map[ tempString ] = PendingTensor{ t.b, ppBias };
+ };
+
+ for( int i = 0; i < layersEnc; i++ )
+ {
+ auto& gpu = enc.layers[ i ];
+ add2( "mlp_ln", i, gpu.mlpLn );
+ add2( "mlp.0", i, gpu.mlp0, ePostProcessing::MakePanels );
+ add2( "mlp.2", i, gpu.mlp1, ePostProcessing::MakePanels );
+ add2( "attn_ln", i, gpu.attnLn0 );
+ add2( "attn.query", i, gpu.attnQuery, ePostProcessing::MakePanels );
+ add( "attn.key.weight", i, gpu.attnKey, ePostProcessing::MakePanels );
+ add2( "attn.value", i, gpu.attnValue, ePostProcessing::MakePanels );
+ add2( "attn.out", i, gpu.attnLn1, ePostProcessing::MakePanels );
+ }
+ }
+
+ void populateDecodeTensorsMap( CAtlMap<CStringA, PendingTensor>& map, int layersDec, DirectCompute::ModelBuffers& tensors, bool hybrid )
+ {
+ tensors.dec.layers.resize( layersDec );
+ CStringA tempString;
+ // Decoder tensors
+
+ auto& dec = tensors.dec;
+ if( !hybrid )
+ {
+ map[ "decoder.positional_embedding" ] = dec.positionalEmbedding;
+ map[ "decoder.token_embedding.weight" ] = dec.tokenEmbedding;
+ map[ "decoder.ln.weight" ] = dec.ln.w;
+ map[ "decoder.ln.bias" ] = dec.ln.b;
+ }
+
+ auto add = [ & ]( const char* name, int i, DirectCompute::Tensor& t, ePostProcessing pp = ePostProcessing::None )
+ {
+ tempString.Format( "decoder.blocks.%i.%s", i, name );
+ map[ tempString ] = PendingTensor{ t, pp };
+ };
+ auto add2 = [ & ]( const char* name, int i, DirectCompute::TensorPair& t, ePostProcessing ppWeight = ePostProcessing::None, ePostProcessing ppBias = ePostProcessing::None )
+ {
+ tempString.Format( "decoder.blocks.%i.%s.weight", i, name );
+ map[ tempString ] = PendingTensor{ t.w, ppWeight };
+ tempString.Format( "decoder.blocks.%i.%s.bias", i, name );
+ map[ tempString ] = PendingTensor{ t.b, ppBias };
+ };
+
+ for( int i = 0; i < layersDec; i++ )
+ {
+ auto& gpu = dec.layers[ i ];
+ add( "cross_attn.key.weight", i, gpu.crossAttnKey, ePostProcessing::MakePanels );
+ add2( "cross_attn.value", i, gpu.crossAttnValue, ePostProcessing::MakePanels );
+ if( hybrid )
+ continue;
+
+ add2( "mlp_ln", i, gpu.mlpLn );
+ add2( "mlp.0", i, gpu.mlp0, ePostProcessing::MakePanels );
+ add2( "mlp.2", i, gpu.mlp1, ePostProcessing::MakePanels );
+ add2( "attn_ln", i, gpu.attnLn0 );
+ add2( "attn.query", i, gpu.attnQuery );
+ add( "attn.key.weight", i, gpu.attnKey );
+ add2( "attn.value", i, gpu.attnValue );
+ add2( "attn.out", i, gpu.attnLn1 );
+ add2( "cross_attn_ln", i, gpu.crossAttnLn0 );
+ add2( "cross_attn.query", i, gpu.crossAttnQuery );
+ add2( "cross_attn.out", i, gpu.crossAttnLn1 );
+ }
+ }
+
+ void populateTensorsMap( CAtlMap<CStringA, PendingTensor>& map, int layersEnc, int layersDec, DirectCompute::ModelBuffers& tensors, bool hybrid )
+ {
+ populateEncodeTensorsMap( map, layersEnc, tensors );
+ populateDecodeTensorsMap( map, layersDec, tensors, hybrid );
+ }
+
+ struct sTensorHeader
+ {
+ int n_dims, length, ftype;
+ };
+
+ // compare signed int32 lanes for a <= b
+ inline __m128i cmple( __m128i a, __m128i b )
+ {
+ __m128i i = _mm_min_epi32( a, b );
+ return _mm_cmpeq_epi32( a, i );
+ }
+
+ inline bool allPositive( const std::array<int, 4>& ne )
+ {
+ const __m128i v = _mm_loadu_si128( ( const __m128i* )ne.data() );
+ const __m128i le = cmple( v, _mm_setzero_si128() );
+ return (bool)_mm_testz_si128( le, le );
+ }
+
+ inline const char* cstr( const CStringA& s ) { return s; }
+}
+
+class WhisperModel::CallbacksImpl : public CpuCompute::iLoaderProgressSink
+{
+ sLoadModelCallbacks lmcb;
+ int64_t fileSize;
+
+ HRESULT gotBytes( int64_t cb ) override final
+ {
+ if( nullptr != lmcb.cancel )
+ {
+ HRESULT hr = lmcb.cancel( lmcb.pv );
+ CHECK( hr );
+ if( S_OK != hr )
+ return HRESULT_FROM_WIN32( ERROR_CANCELLED );
+ }
+
+ if( nullptr != lmcb.progress )
+ {
+ postponedBytes -= cb;
+ assert( postponedBytes >= 0 );
+ int64_t pos = fileSize - postponedBytes;
+ const double progressVal = (double)pos / (double)fileSize;
+ HRESULT hr = lmcb.progress( progressVal, lmcb.pv );
+ CHECK( hr );
+ }
+ return S_OK;
+ }
+public:
+ int64_t postponedBytes;
+
+ CallbacksImpl()
+ {
+ lmcb.progress = nullptr;
+ lmcb.cancel = nullptr;
+ lmcb.pv = nullptr;
+ fileSize = 0;
+ postponedBytes = 0;
+ }
+
+ HRESULT initialize( ComLight::iReadStream* stm, const sLoadModelCallbacks* rsi )
+ {
+ if( nullptr == rsi )
+ return S_OK;
+ lmcb = *rsi;
+ if( nullptr != lmcb.progress )
+ CHECK( stm->getLength( fileSize ) );
+ return S_OK;
+ }
+
+ HRESULT call( ComLight::iReadStream* stm )
+ {
+ if( nullptr != lmcb.cancel )
+ {
+ HRESULT hr = lmcb.cancel( lmcb.pv );
+ CHECK( hr );
+ if( S_OK != hr )
+ return HRESULT_FROM_WIN32( ERROR_CANCELLED );
+ }
+
+ if( nullptr != lmcb.progress )
+ {
+ int64_t pos;
+ CHECK( stm->getPosition( pos ) );
+ pos -= postponedBytes;
+ const double progressVal = (double)pos / (double)fileSize;
+ HRESULT hr = lmcb.progress( progressVal, lmcb.pv );
+ CHECK( hr );
+ }
+ return S_OK;
+ }
+};
+
+HRESULT WhisperModel::loadGpu( ComLight::iReadStream* stm, CallbacksImpl& callbacks )
+{
+ CAtlMap<CStringA, PendingTensor> map;
+ populateTensorsMap( map, parameters.n_audio_layer, parameters.n_text_layer, tensors, false );
+
+#if RESHAPED_MATRIX_MULTIPLY
+ DirectCompute::Reshaper reshape;
+#endif
+
+ std::vector<uint8_t> bytesVector;
+ size_t countLoaded = 0;
+ CStringA name;
+ int64_t cb = 0;
+ while( true )
+ {
+ CHECK( callbacks.call( stm ) );
+
+ sTensorHeader header;
+ HRESULT hr = readStruct( stm, header );
+ if( hr == E_EOF )
+ break;
+ if( FAILED( hr ) )
+ return hr;
+ if( header.n_dims < 1 || header.n_dims>3 )
+ return E_INVALIDARG;
+
+ std::array<int, 4> ne = { 1, 1, 1, 1 };
+ CHECK( readBytes( stm, ne.data(), header.n_dims * 4 ) );
+ if( !allPositive( ne ) )
+ return E_INVALIDARG;
+
+ char* nameBuffer = name.GetBufferSetLength( header.length );
+ hr = readBytes( stm, nameBuffer, header.length );
+ name.ReleaseBuffer();
+ if( FAILED( hr ) )
+ return hr;
+
+ auto p = map.Lookup( name );
+ if( nullptr == p )
+ {
+ logError( u8"%s: unknown tensor '%s' in model file", __func__, cstr( name ) );
+ return E_INVALIDARG;
+ }
+
+ DirectCompute::eDataType dt;
+ size_t cbElement;
+ if( header.ftype == 0 )
+ {
+ dt = DirectCompute::eDataType::FP32;
+ cbElement = 4;
+ }
+ else
+ {
+ dt = DirectCompute::eDataType::FP16;
+ cbElement = 2;
+ }
+
+ const size_t totalElts = (size_t)(uint32_t)ne[ 0 ] * (uint32_t)ne[ 1 ] * (uint32_t)ne[ 2 ];
+ if( totalElts * cbElement > UINT_MAX )
+ return DISP_E_OVERFLOW;
+
+ try
+ {
+ bytesVector.resize( cbElement * totalElts );
+ }
+ catch( const std::bad_alloc& )
+ {
+ return E_OUTOFMEMORY;
+ }
+ CHECK( readBytes( stm, bytesVector.data(), bytesVector.size() ) );
+ cb += bytesVector.size();
+ CHECK( p->m_value.dest->createImmutable( dt, ne, bytesVector.data() ) );
+#if RESHAPED_MATRIX_MULTIPLY
+ CHECK( p->m_value.postProcess( reshape, dt ) );
+#endif
+ countLoaded++;
+ }
+
+ if( countLoaded != map.GetCount() )
+ {
+ logError( u8"Not all tensors loaded from model file - expected %zu, got %zu", map.GetCount(), countLoaded );
+ return E_INVALIDARG;
+ }
+
+ constexpr double mulMb = 1.0 / ( 1 << 20 );
+ logDebug( u8"Loaded %zu GPU tensors, %g MB VRAM", countLoaded, mulMb * cb );
+ return S_OK;
+}
+
+#if BUILD_HYBRID_VERSION
+HRESULT WhisperModel::loadHybrid( ComLight::iReadStream* stm, CallbacksImpl& callbacks )
+{
+ CAtlMap<CStringA, PendingTensor> map;
+ populateTensorsMap( map, parameters.n_audio_layer, parameters.n_text_layer, tensors, true );
+
+#if RESHAPED_MATRIX_MULTIPLY
+ DirectCompute::Reshaper reshape;
+#endif
+
+ CpuCompute::HybridLoader loader( hybridTensors, parameters.n_text_layer );
+
+ std::vector<uint8_t> bytesVector;
+ size_t countLoaded = 0;
+ CStringA name;
+ int64_t cb = 0;
+ while( true )
+ {
+ CHECK( callbacks.call( stm ) );
+
+ sTensorHeader header;
+ HRESULT hr = readStruct( stm, header );
+ if( hr == E_EOF )
+ break;
+ if( FAILED( hr ) )
+ return hr;
+ if( header.n_dims < 1 || header.n_dims > 3 )
+ return E_INVALIDARG;
+
+ std::array<int, 4> ne = { 1, 1, 1, 1 };
+ CHECK( readBytes( stm, ne.data(), header.n_dims * 4 ) );
+ if( !allPositive( ne ) )
+ return E_INVALIDARG;
+
+ char* nameBuffer = name.GetBufferSetLength( header.length );
+ hr = readBytes( stm, nameBuffer, header.length );
+ name.ReleaseBuffer();
+ if( FAILED( hr ) )
+ return hr;
+
+ auto p = map.Lookup( name );
+ if( nullptr == p )
+ {
+ HRESULT hr = loader.setupTensor( name, header.n_dims, header.ftype, ne, stm, callbacks.postponedBytes );
+ if( hr == S_OK )
+ continue;
+ logError( u8"%s: unknown tensor '%s' in model file", __func__, cstr( name ) );
+ return E_INVALIDARG;
+ }
+
+ DirectCompute::eDataType dt;
+ size_t cbElement;
+ if( header.ftype == 0 )
+ {
+ dt = DirectCompute::eDataType::FP32;
+ cbElement = 4;
+ }
+ else
+ {
+ dt = DirectCompute::eDataType::FP16;
+ cbElement = 2;
+ }
+
+ const size_t totalElts = (size_t)(uint32_t)ne[ 0 ] * (uint32_t)ne[ 1 ] * (uint32_t)ne[ 2 ];
+ if( totalElts * cbElement > UINT_MAX )
+ return DISP_E_OVERFLOW;
+
+ try
+ {
+ bytesVector.resize( cbElement * totalElts );
+ }
+ catch( const std::bad_alloc& )
+ {
+ return E_OUTOFMEMORY;
+ }
+ CHECK( readBytes( stm, bytesVector.data(), bytesVector.size() ) );
+ CHECK( p->m_value.dest->createImmutable( dt, ne, bytesVector.data() ) );
+#if RESHAPED_MATRIX_MULTIPLY
+ CHECK( p->m_value.postProcess( reshape, dt ) );
+#endif
+ countLoaded++;
+ cb += bytesVector.size();
+ }
+
+ if( countLoaded != map.GetCount() )
+ {
+ logError( u8"Not all tensors loaded from model file - expected %zu, got %zu", map.GetCount(), countLoaded );
+ return E_INVALIDARG;
+ }
+
+ constexpr double mulMb = 1.0 / ( 1 << 20 );
+ logDebug( u8"Loaded %zu GPU tensors, %g MB VRAM", countLoaded, mulMb * cb );
+
+ CHECK( loader.completeLoad( stm, callbacks ) );
+ return S_OK;
+}
+#endif
+
+HRESULT WhisperModel::load( ComLight::iReadStream* stm, bool hybrid, const sLoadModelCallbacks* callbacks )
+{
+ CpuProfiler cpuPerf;
+ CallbacksImpl cb;
+ CHECK( cb.initialize( stm, callbacks ) );
+ // verify magic
+ {
+ uint32_t magic;
+ CHECK( readStruct( stm, magic ) );
+ if( magic != 0x67676d6c )
+ {
+ logError( u8"Invalid model file, bad magic" );
+ return E_INVALIDARG;
+ }
+ }
+
+ // hparams and MEL filters
+ {
+ ParamsAndMelHeader pmh;
+ CHECK( readStruct( stm, pmh ) );
+ parameters = pmh.mp;
+ assert( parameters.n_text_state == parameters.n_audio_state );
+
+ filters.n_mel = pmh.n_mel;
+ filters.n_fft = pmh.n_fft;
+ const size_t len = (size_t)filters.n_mel * filters.n_fft;
+ filters.data.resize( len );
+ CHECK( readBytes( stm, filters.data.data(), len * 4 ) );
+
+ const int64_t cb = len * 4;
+ constexpr double mulKb = 1.0 / ( 1 << 10 );
+ logDebug( u8"Loaded MEL filters, %.1f kb RAM", mulKb * cb );
+ }
+ CHECK( cb.call( stm ) );
+
+ // Vocabulary
+ CHECK( vocab.load( stm, parameters.n_vocab ) );
+ CHECK( cb.call( stm ) );
+
+ DirectCompute::GpuProfilerSimple gpuProfiler;
+ CHECK( gpuProfiler.create() );
+
+ if( hybrid )
+ {
+#if BUILD_HYBRID_VERSION
+ CHECK( loadHybrid( stm, cb ) )
+#else
+ return E_NOTIMPL;
+#endif
+ }
+ else
+ CHECK( loadGpu( stm, cb ) );
+
+ CHECK( gpuProfiler.time( loadTimeGpu ) );
+ loadTimeCpu = cpuPerf.elapsed();
+ return S_OK;
+}
+
+__m128i Whisper::WhisperModel::getMemoryUse() const
+{
+ size_t cb = vocab.getMemoryUse();
+ cb += vectorMemoryUse( filters.data );
+ __m128i v = _mm_cvtsi64_si128( (int64_t)cb );
+ v = _mm_add_epi64( v, tensors.getMemoryUse() );
+ return v;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/WhisperModel.h b/Whisper/Whisper/WhisperModel.h
new file mode 100644
index 0000000..c9b72aa
--- /dev/null
+++ b/Whisper/Whisper/WhisperModel.h
@@ -0,0 +1,54 @@
+#pragma once
+#include "Vocabulary.h"
+#include "ModelBuffers.h"
+#include "../../ComLightLib/streams.h"
+#include "../CPU/DecoderTensors.h"
+#include "../API/sLoadModelCallbacks.h"
+#include "sModelParams.h"
+
+namespace Whisper
+{
+ struct Filters
+ {
+ uint32_t n_mel;
+ uint32_t n_fft;
+ std::vector<float> data;
+ };
+
+ // The complete model, as loaded from a GGML binary file.
+ // The entire model is immutable, and can be safely used from multiple threads in parallel.
+ // The tensors are uploaded to VRAM and don’t stay in system memory, everything else is in the system RAM.
+ struct WhisperModel
+ {
+ sModelParams parameters;
+ Vocabulary vocab;
+ Filters filters;
+ DirectCompute::ModelBuffers tensors;
+
+#if BUILD_HYBRID_VERSION
+ CpuCompute::DecoderTensors hybridTensors;
+#endif
+
+ HRESULT load( ComLight::iReadStream* stm, bool hybrid, const sLoadModelCallbacks* callbacks );
+
+ // A vector of 2 uint64_t values, both numbers are 100 nanosecond ticks:
+ // 0. The time it took to load the model, measured on CPU
+ // 1. The time it took to upload all these tensors to VRAM, measured on GPU
+ __m128i getLoadTimes() const
+ {
+ static_assert( offsetof( WhisperModel, loadTimeCpu ) + 8 == offsetof( WhisperModel, loadTimeGpu ) );
+ return _mm_loadu_si128( ( const __m128i* )( &loadTimeCpu ) );
+ }
+
+ __m128i getMemoryUse() const;
+
+ private:
+ uint64_t loadTimeCpu = 0;
+ uint64_t loadTimeGpu = 0;
+
+ class CallbacksImpl;
+
+ HRESULT loadGpu( ComLight::iReadStream* stm, CallbacksImpl& callbacks );
+ HRESULT loadHybrid( ComLight::iReadStream* stm, CallbacksImpl& callbacks );
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/audioConstants.h b/Whisper/Whisper/audioConstants.h
new file mode 100644
index 0000000..232ab22
--- /dev/null
+++ b/Whisper/Whisper/audioConstants.h
@@ -0,0 +1,14 @@
+#pragma once
+#include <stdint.h>
+
+namespace Whisper
+{
+ // WHISPER_SAMPLE_RATE, 16 kHz
+ constexpr uint32_t SAMPLE_RATE = 16000;
+ // WHISPER_N_FFT, 25 milliseconds
+ constexpr uint32_t FFT_SIZE = 400;
+ // WHISPER_HOP_LENGTH, 10 milliseconds
+ constexpr uint32_t FFT_STEP = 160;
+ // WHISPER_N_MEL
+ constexpr uint32_t N_MEL = 80;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/iSpectrogram.h b/Whisper/Whisper/iSpectrogram.h
new file mode 100644
index 0000000..2e3199d
--- /dev/null
+++ b/Whisper/Whisper/iSpectrogram.h
@@ -0,0 +1,38 @@
+#pragma once
+#include "audioConstants.h"
+
+namespace Whisper
+{
+ __interface iSpectrogram
+ {
+ // Make a buffer with length * N_MEL floats, starting at the specified offset
+ // An implementation of this interface may visualize the spectrogram, making pieces on demand
+ HRESULT makeBuffer( size_t offset, size_t length, const float** buffer, size_t& stride );
+
+ // Apparently, the length unit is 160 input samples = 10 milliseconds of audio
+ size_t getLength() const;
+ };
+
+ // RAII class to deal with iSpectrogram's makeBuffer method.
+ // Throws exceptions when things fail.
+ class MelBufferRaii
+ {
+ const float* pointer;
+ size_t stride;
+ public:
+
+ HRESULT make( iSpectrogram& mel, size_t off, size_t len )
+ {
+ return mel.makeBuffer( off, len, &pointer, stride );
+ }
+
+ const float* operator[]( size_t idx ) const
+ {
+ assert( idx < N_MEL );
+ return pointer + idx * stride;
+ }
+
+ const BYTE* bytePtr() const { return (const BYTE*)pointer; }
+ LONG strideBytes() const { return (LONG)stride * 4; }
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/languageCodez.inl b/Whisper/Whisper/languageCodez.inl
new file mode 100644
index 0000000..c302769
--- /dev/null
+++ b/Whisper/Whisper/languageCodez.inl
@@ -0,0 +1,100 @@
+// This file is generated by a tool, from the `languageCodez.tsv` file in this repository
+Lang{ 0x6661, 68, "afrikaans" },
+Lang{ 0x7173, 58, "albanian" },
+Lang{ 0x6D61, 75, "amharic" },
+Lang{ 0x7261, 13, "arabic" },
+Lang{ 0x7968, 53, "armenian" },
+Lang{ 0x7361, 91, "assamese" },
+Lang{ 0x7A61, 45, "azerbaijani" },
+Lang{ 0x6162, 96, "bashkir" },
+Lang{ 0x7565, 51, "basque" },
+Lang{ 0x6562, 71, "belarusian" },
+Lang{ 0x6E62, 43, "bengali" },
+Lang{ 0x7362, 56, "bosnian" },
+Lang{ 0x7262, 50, "breton" },
+Lang{ 0x6762, 33, "bulgarian" },
+Lang{ 0x6163, 11, "catalan" },
+Lang{ 0x687A, 1, "chinese" },
+Lang{ 0x7268, 32, "croatian" },
+Lang{ 0x7363, 24, "czech" },
+Lang{ 0x6164, 26, "danish" },
+Lang{ 0x6C6E, 12, "dutch" },
+Lang{ 0x6E65, 0, "english" },
+Lang{ 0x7465, 48, "estonian" },
+Lang{ 0x6F66, 79, "faroese" },
+Lang{ 0x6966, 18, "finnish" },
+Lang{ 0x7266, 6, "french" },
+Lang{ 0x6C67, 60, "galician" },
+Lang{ 0x616B, 70, "georgian" },
+Lang{ 0x6564, 2, "german" },
+Lang{ 0x6C65, 22, "greek" },
+Lang{ 0x7567, 74, "gujarati" },
+Lang{ 0x7468, 80, "haitian creole" },
+Lang{ 0x6168, 95, "hausa" },
+Lang{ 0x776168, 93, "hawaiian" },
+Lang{ 0x7769, 20, "hebrew" },
+Lang{ 0x6968, 17, "hindi" },
+Lang{ 0x7568, 27, "hungarian" },
+Lang{ 0x7369, 52, "icelandic" },
+Lang{ 0x6469, 16, "indonesian" },
+Lang{ 0x7469, 15, "italian" },
+Lang{ 0x616A, 7, "japanese" },
+Lang{ 0x776A, 97, "javanese" },
+Lang{ 0x6E6B, 47, "kannada" },
+Lang{ 0x6B6B, 57, "kazakh" },
+Lang{ 0x6D6B, 64, "khmer" },
+Lang{ 0x6F6B, 5, "korean" },
+Lang{ 0x6F6C, 77, "lao" },
+Lang{ 0x616C, 35, "latin" },
+Lang{ 0x766C, 42, "latvian" },
+Lang{ 0x6E6C, 94, "lingala" },
+Lang{ 0x746C, 34, "lithuanian" },
+Lang{ 0x626C, 86, "luxembourgish" },
+Lang{ 0x6B6D, 49, "macedonian" },
+Lang{ 0x676D, 90, "malagasy" },
+Lang{ 0x736D, 23, "malay" },
+Lang{ 0x6C6D, 37, "malayalam" },
+Lang{ 0x746D, 84, "maltese" },
+Lang{ 0x696D, 36, "maori" },
+Lang{ 0x726D, 61, "marathi" },
+Lang{ 0x6E6D, 55, "mongolian" },
+Lang{ 0x796D, 87, "myanmar" },
+Lang{ 0x656E, 54, "nepali" },
+Lang{ 0x6F6E, 29, "norwegian" },
+Lang{ 0x6E6E, 83, "nynorsk" },
+Lang{ 0x636F, 69, "occitan" },
+Lang{ 0x7370, 81, "pashto" },
+Lang{ 0x6166, 41, "persian" },
+Lang{ 0x6C70, 10, "polish" },
+Lang{ 0x7470, 8, "portuguese" },
+Lang{ 0x6170, 62, "punjabi" },
+Lang{ 0x6F72, 25, "romanian" },
+Lang{ 0x7572, 4, "russian" },
+Lang{ 0x6173, 85, "sanskrit" },
+Lang{ 0x7273, 44, "serbian" },
+Lang{ 0x6E73, 65, "shona" },
+Lang{ 0x6473, 73, "sindhi" },
+Lang{ 0x6973, 63, "sinhala" },
+Lang{ 0x6B73, 39, "slovak" },
+Lang{ 0x6C73, 46, "slovenian" },
+Lang{ 0x6F73, 67, "somali" },
+Lang{ 0x7365, 3, "spanish" },
+Lang{ 0x7573, 98, "sundanese" },
+Lang{ 0x7773, 59, "swahili" },
+Lang{ 0x7673, 14, "swedish" },
+Lang{ 0x6C74, 89, "tagalog" },
+Lang{ 0x6774, 72, "tajik" },
+Lang{ 0x6174, 28, "tamil" },
+Lang{ 0x7474, 92, "tatar" },
+Lang{ 0x6574, 40, "telugu" },
+Lang{ 0x6874, 30, "thai" },
+Lang{ 0x6F62, 88, "tibetan" },
+Lang{ 0x7274, 9, "turkish" },
+Lang{ 0x6B74, 82, "turkmen" },
+Lang{ 0x6B75, 21, "ukrainian" },
+Lang{ 0x7275, 31, "urdu" },
+Lang{ 0x7A75, 78, "uzbek" },
+Lang{ 0x6976, 19, "vietnamese" },
+Lang{ 0x7963, 38, "welsh" },
+Lang{ 0x6979, 76, "yiddish" },
+Lang{ 0x6F79, 66, "yoruba" },
diff --git a/Whisper/Whisper/languageCodez.tsv b/Whisper/Whisper/languageCodez.tsv
new file mode 100644
index 0000000..4048618
--- /dev/null
+++ b/Whisper/Whisper/languageCodez.tsv
@@ -0,0 +1,99 @@
+en 0 english
+zh 1 chinese
+de 2 german
+es 3 spanish
+ru 4 russian
+ko 5 korean
+fr 6 french
+ja 7 japanese
+pt 8 portuguese
+tr 9 turkish
+pl 10 polish
+ca 11 catalan
+nl 12 dutch
+ar 13 arabic
+sv 14 swedish
+it 15 italian
+id 16 indonesian
+hi 17 hindi
+fi 18 finnish
+vi 19 vietnamese
+iw 20 hebrew
+uk 21 ukrainian
+el 22 greek
+ms 23 malay
+cs 24 czech
+ro 25 romanian
+da 26 danish
+hu 27 hungarian
+ta 28 tamil
+no 29 norwegian
+th 30 thai
+ur 31 urdu
+hr 32 croatian
+bg 33 bulgarian
+lt 34 lithuanian
+la 35 latin
+mi 36 maori
+ml 37 malayalam
+cy 38 welsh
+sk 39 slovak
+te 40 telugu
+fa 41 persian
+lv 42 latvian
+bn 43 bengali
+sr 44 serbian
+az 45 azerbaijani
+sl 46 slovenian
+kn 47 kannada
+et 48 estonian
+mk 49 macedonian
+br 50 breton
+eu 51 basque
+is 52 icelandic
+hy 53 armenian
+ne 54 nepali
+mn 55 mongolian
+bs 56 bosnian
+kk 57 kazakh
+sq 58 albanian
+sw 59 swahili
+gl 60 galician
+mr 61 marathi
+pa 62 punjabi
+si 63 sinhala
+km 64 khmer
+sn 65 shona
+yo 66 yoruba
+so 67 somali
+af 68 afrikaans
+oc 69 occitan
+ka 70 georgian
+be 71 belarusian
+tg 72 tajik
+sd 73 sindhi
+gu 74 gujarati
+am 75 amharic
+yi 76 yiddish
+lo 77 lao
+uz 78 uzbek
+fo 79 faroese
+ht 80 haitian creole
+ps 81 pashto
+tk 82 turkmen
+nn 83 nynorsk
+mt 84 maltese
+sa 85 sanskrit
+lb 86 luxembourgish
+my 87 myanmar
+bo 88 tibetan
+tl 89 tagalog
+mg 90 malagasy
+as 91 assamese
+tt 92 tatar
+haw 93 hawaiian
+ln 94 lingala
+ha 95 hausa
+ba 96 bashkir
+jw 97 javanese
+su 98 sundanese \ No newline at end of file
diff --git a/Whisper/Whisper/loaderUtils.h b/Whisper/Whisper/loaderUtils.h
new file mode 100644
index 0000000..650556d
--- /dev/null
+++ b/Whisper/Whisper/loaderUtils.h
@@ -0,0 +1,24 @@
+#pragma once
+#include "../../ComLightLib/streams.h"
+
+namespace Whisper
+{
+ inline HRESULT readBytes( ComLight::iReadStream* stm, void* rdi, size_t cb )
+ {
+ if( cb > INT_MAX )
+ return DISP_E_OVERFLOW;
+ if( cb == 0 )
+ return S_FALSE;
+ int n;
+ CHECK( stm->read( rdi, (int)cb, n ) );
+ if( n != (int)cb )
+ return E_EOF;
+ return S_OK;
+ }
+
+ template<typename T>
+ inline HRESULT readStruct( ComLight::iReadStream* stm, T& dest )
+ {
+ return readBytes( stm, &dest, sizeof( T ) );
+ }
+} \ No newline at end of file
diff --git a/Whisper/Whisper/melSpectrogram.cpp b/Whisper/Whisper/melSpectrogram.cpp
new file mode 100644
index 0000000..f297557
--- /dev/null
+++ b/Whisper/Whisper/melSpectrogram.cpp
@@ -0,0 +1,298 @@
+#include "stdafx.h"
+#include <cmath>
+#include "melSpectrogram.h"
+
+namespace Whisper
+{
+ HanningWindow::HanningWindow()
+ {
+ for( int i = 0; i < FFT_SIZE; i++ )
+ {
+ // TODO [low]: use XMVectorCos instead
+ hann[ i ] = (float)( 0.5 * ( 1.0 - std::cos( ( 2.0 * M_PI * i ) / ( FFT_SIZE ) ) ) );
+ }
+ }
+ const HanningWindow s_hanning;
+}
+
+namespace
+{
+ using namespace Whisper;
+
+ uint32_t tempVectorSizeRecursion( uint32_t len )
+ {
+ // out.resize( in.size() * 2 );
+ const uint32_t res = len * 2;
+ if( len == 1 )
+ return res;
+ if( len % 2 == 1 )
+ return res; // dft
+
+ const uint32_t even = ( len + 1 ) / 2;
+ const uint32_t odd = len / 2;
+ const uint32_t evenFft = tempVectorSizeRecursion( even );
+ const uint32_t oddFft = tempVectorSizeRecursion( odd );
+ return res + even + odd + evenFft + oddFft;
+ }
+
+ // 6000
+ // const uint32_t tempBufferSize = FFT_SIZE + tempVectorSizeRecursion( FFT_SIZE );
+ constexpr uint32_t tempBufferSize = 6000;
+
+ // naive Discrete Fourier Transform
+ // input is real-valued
+ // output is complex-valued
+ inline void dft( const float* rsi, size_t len, float* rdi )
+ {
+ for( size_t k = 0; k < len; k++ )
+ {
+ float re = 0;
+ float im = 0;
+
+ for( int n = 0; n < len; n++ )
+ {
+ float angle = (float)( 2 * M_PI * (int)k * n / len );
+ re += (float)( rsi[ n ] * std::cosf( angle ) );
+ im -= (float)( rsi[ n ] * std::sinf( angle ) );
+ }
+
+ rdi[ k * 2 + 0 ] = re;
+ rdi[ k * 2 + 1 ] = im;
+ }
+ }
+
+ inline void splitEvenOdd( const float* rsi, size_t len, float* rdiEven, float* rdiOdd )
+ {
+ const float* const rsiEndAligned = rsi + ( len & ( ~(size_t)7 ) );
+ const size_t rem = len % 8;
+
+ for( ; rsi < rsiEndAligned; rsi += 8, rdiEven += 4, rdiOdd += 4 )
+ {
+ const __m128 v1 = _mm_loadu_ps( rsi );
+ const __m128 v2 = _mm_loadu_ps( rsi + 4 );
+ const __m128 e = _mm_shuffle_ps( v1, v2, _MM_SHUFFLE( 2, 0, 2, 0 ) );
+ const __m128 o = _mm_shuffle_ps( v1, v2, _MM_SHUFFLE( 3, 1, 3, 1 ) );
+ _mm_storeu_ps( rdiEven, e );
+ _mm_storeu_ps( rdiOdd, o );
+ }
+
+#pragma loop( no_vector )
+ for( size_t i = 0; i < rem; i++, rsi++ )
+ {
+ if( i % 2 == 0 )
+ {
+ *rdiEven = *rsi;
+ rdiEven++;
+ }
+ else
+ {
+ *rdiOdd = *rsi;
+ rdiOdd++;
+ }
+ }
+ }
+ inline __m128 set2( float f )
+ {
+ __m128 v = _mm_set_ss( f );
+ return _mm_moveldup_ps( v );
+ }
+ inline __m128 load2( const float* rsi )
+ {
+ return _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) );
+ }
+ inline void store2( float* rdi, __m128 vec )
+ {
+ _mm_store_sd( (double*)rdi, _mm_castps_pd( vec ) );
+ }
+ // [ x, y ] => [ x, y, x, y ]
+ inline __m128 dup2( __m128 x )
+ {
+ __m128d v = _mm_castps_pd( x );
+ v = _mm_movedup_pd( v );
+ return _mm_castpd_ps( v );
+ }
+ inline __m128 load2dup( const float* rsi )
+ {
+ return _mm_castpd_ps( _mm_loaddup_pd( (const double*)rsi ) );
+ }
+ inline void store2high( float* rdi, __m128 vec )
+ {
+ _mm_storeh_pd( (double*)rdi, _mm_castps_pd( vec ) );
+ }
+}
+
+using namespace Whisper;
+
+SpectrogramContext::SpectrogramContext( const Filters& flt ) :
+ filters( flt )
+{
+ assert( tempBufferSize == FFT_SIZE + tempVectorSizeRecursion( FFT_SIZE ) );
+ tempBuffer = std::make_unique<float[]>( tempBufferSize );
+}
+
+// Cooley-Tukey FFT
+// poor man's implementation - use something better
+// input is real-valued
+// output is complex-valued
+float* SpectrogramContext::fftRecursion( float* temp, const float* const rsi, const size_t len )
+{
+ float* const out = temp;
+ temp += len * 2;
+ if( len == 1 )
+ {
+ out[ 0 ] = rsi[ 0 ];
+ out[ 1 ] = 0;
+ return temp;
+ }
+
+ if( len % 2 == 1 )
+ {
+ dft( rsi, len, out );
+ return temp;
+ }
+
+ const size_t lenEven = ( len + 1 ) / 2;
+ const size_t lenOdd = len / 2;
+ float* const even = temp;
+ temp += lenEven;
+
+ float* const odd = temp;
+ temp += lenOdd;
+ splitEvenOdd( rsi, len, even, odd );
+
+ const float* const evenFft = temp;
+ temp = fftRecursion( temp, even, lenEven );
+
+ const float* const oddFft = temp;
+ temp = fftRecursion( temp, odd, lenOdd );
+
+ const size_t N = len;
+ const __m128 maskNegateHigh = _mm_setr_ps( 0, 0, -0.0f, -0.0f );
+ for( size_t k = 0; k < N / 2; k++ )
+ {
+ const float theta = (float)( 2 * M_PI * (double)(int)k / N );
+
+ /*
+ const float re = std::cosf( theta );
+ const float im = -std::sinf( theta );
+
+ float re_odd = oddFft[ 2 * k + 0 ];
+ float im_odd = oddFft[ 2 * k + 1 ];
+
+ out[ 2 * k + 0 ] = evenFft[ 2 * k + 0 ] + re * re_odd - im * im_odd;
+ out[ 2 * k + 1 ] = evenFft[ 2 * k + 1 ] + re * im_odd + im * re_odd;
+
+ out[ 2 * ( k + N / 2 ) + 0 ] = evenFft[ 2 * k + 0 ] - re * re_odd + im * im_odd;
+ out[ 2 * ( k + N / 2 ) + 1 ] = evenFft[ 2 * k + 1 ] - re * im_odd - im * re_odd;
+ */
+
+ const __m128 re = _mm_set_ss( std::cosf( theta ) );
+ const __m128 im = _mm_set_ss( std::sinf( theta ) );
+ __m128 reIm = _mm_shuffle_ps( re, im, _MM_SHUFFLE( 0, 0, 0, 0 ) );
+ // [ re, re, im, im ]
+ reIm = _mm_xor_ps( reIm, maskNegateHigh );
+
+ // [ re_odd, im_odd ]
+ __m128 odd = load2( oddFft + 2 * k );
+ // [ re_odd, im_odd, im_odd, re_odd ]
+ odd = _mm_shuffle_ps( odd, odd, _MM_SHUFFLE( 0, 1, 1, 0 ) );
+
+ // re_odd * re, im_odd * re, im_odd * im, re_odd * im ]
+ const __m128 products4 = _mm_mul_ps( reIm, odd );
+
+ // re_odd * re, im_odd * re, re_odd * re, im_odd * re
+ __m128 prod1 = dup2( products4 );
+ // im_odd * im, re_odd * im, im_odd * im, re_odd * im
+ __m128 prod2 = _mm_movehl_ps( products4, products4 );
+
+ // re_odd * re, im_odd * re, -re_odd * re, -im_odd * re
+ prod1 = _mm_xor_ps( prod1, maskNegateHigh );
+ // im_odd * im, re_odd * im, -im_odd * im, -re_odd * im
+ prod2 = _mm_xor_ps( prod2, maskNegateHigh );
+
+ const __m128 even = load2dup( evenFft + 2 * k );
+ __m128 res;
+ res = _mm_add_ps( even, prod1 );
+ res = _mm_addsub_ps( res, prod2 );
+ store2( out + 2 * k, res );
+ store2high( out + 2 * ( k + N / 2 ), res );
+ }
+
+ return temp;
+}
+
+void SpectrogramContext::fft( std::array<float, N_MEL>& rdi, const float* pcm, size_t length )
+{
+ assert( length > 0 );
+ length = std::min( length, (size_t)FFT_SIZE );
+
+ float* const temp = tempBuffer.get();
+ // Apply Hanning window
+ for( size_t i = 0; i < length; i++ )
+ temp[ i ] = pcm[ i ] * s_hanning[ i ];
+ if( length < FFT_SIZE )
+ memset( temp + length, 0, ( FFT_SIZE - length ) * 4 );
+
+ float* const fftOut = temp + FFT_SIZE;
+ float* bufferEnd = fftRecursion( fftOut, temp, FFT_SIZE );
+ assert( bufferEnd == tempBuffer.get() + tempBufferSize );
+
+ // for( size_t j = 0; j < FFT_SIZE; j++ )
+ // fft_out[ j ] = ( fft_out[ 2 * j + 0 ] * fft_out[ 2 * j + 0 ] + fft_out[ 2 * j + 1 ] * fft_out[ 2 * j + 1 ] );
+ for( size_t j = 0; j < 4; j++ )
+ {
+ __m128 tmp = load2( fftOut + 2 * j );
+ tmp = _mm_mul_ps( tmp, tmp );
+ tmp = _mm_add_ss( tmp, _mm_movehdup_ps( tmp ) );
+ _mm_store_ss( fftOut + j, tmp );
+ }
+ for( size_t j = 4; j < FFT_SIZE; j += 4 )
+ {
+ __m128 low = _mm_loadu_ps( fftOut + 2 * j );
+ __m128 high = _mm_loadu_ps( fftOut + 2 * j + 4 );
+ low = _mm_mul_ps( low, low );
+ high = _mm_mul_ps( high, high );
+ __m128 res = _mm_hadd_ps( low, high );
+ _mm_storeu_ps( fftOut + j, res );
+ }
+
+ // for( size_t j = 1; j < FFT_SIZE / 2; j++ )
+ // fftOut[ j ] += fftOut[ FFT_SIZE - j ];
+ for( size_t j = 1; j < 4; j++ )
+ fftOut[ j ] += fftOut[ FFT_SIZE - j ];
+ for( size_t j = 4; j < FFT_SIZE / 2; j += 4 )
+ {
+ __m128 curr = _mm_loadu_ps( fftOut + j );
+ // Too bad _mm_loadr_ps requires alignment
+ __m128 high = _mm_loadu_ps( fftOut + ( FFT_SIZE - 3 ) - j );
+ high = _mm_shuffle_ps( high, high, _MM_SHUFFLE( 0, 1, 2, 3 ) );
+ curr = _mm_add_ps( curr, high );
+ _mm_storeu_ps( fftOut + j, curr );
+ }
+
+ constexpr size_t n_fft = 1 + ( FFT_SIZE / 2 );
+
+ // mel spectrogram
+ for( size_t j = 0; j < N_MEL; j++ )
+ {
+ double sum = 0.0;
+ for( size_t k = 0; k < n_fft; k++ )
+ sum += fftOut[ k ] * filters.data[ j * n_fft + k ];
+ if( sum < 1e-10 )
+ sum = 1e-10;
+ sum = log10( sum );
+ rdi[ j ] = (float)sum;
+ }
+
+ /*
+ const float* ptr = rdi.data();
+ const float* const ptrEnd = ptr + rdi.size();
+ static_assert( 0 == N_MEL % 4 );
+ __m128 ax = _mm_loadu_ps( ptr );
+ for( ptr += 4; ptr < ptrEnd; ptr += 4 )
+ ax = _mm_max_ps( ax, _mm_loadu_ps( ptr ) );
+ ax = _mm_max_ps( ax, _mm_movehl_ps( ax, ax ) );
+ ax = _mm_max_ss( ax, _mm_movehdup_ps( ax ) );
+ return _mm_cvtss_f32( ax );
+ */
+} \ No newline at end of file
diff --git a/Whisper/Whisper/melSpectrogram.h b/Whisper/Whisper/melSpectrogram.h
new file mode 100644
index 0000000..7e66f06
--- /dev/null
+++ b/Whisper/Whisper/melSpectrogram.h
@@ -0,0 +1,34 @@
+#pragma once
+#include "audioConstants.h"
+#include "WhisperModel.h"
+#include <memory>
+
+namespace Whisper
+{
+ class HanningWindow
+ {
+ std::array<float, FFT_SIZE> hann;
+ public:
+ HanningWindow();
+
+ float operator[]( size_t i ) const
+ {
+ return hann[ i ];
+ }
+ };
+
+ extern const HanningWindow s_hanning;
+
+ class SpectrogramContext
+ {
+ const Filters& filters;
+ static float* fftRecursion( float* temp, const float* const rsi, const size_t len );
+ std::unique_ptr<float[]> tempBuffer;
+
+ public:
+ SpectrogramContext( const Filters& flt );
+
+ // First step of the MEL algorithm, and recursively compute the FFT
+ void fft( std::array<float, N_MEL>& rdi, const float* pcm, size_t length );
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/sEncodeParams.h b/Whisper/Whisper/sEncodeParams.h
new file mode 100644
index 0000000..428ef0d
--- /dev/null
+++ b/Whisper/Whisper/sEncodeParams.h
@@ -0,0 +1,20 @@
+#pragma once
+#include <stdint.h>
+
+namespace DirectCompute
+{
+ struct sEncodeParams
+ {
+ uint32_t n_ctx, n_mels, mel_offset;
+ uint32_t layersCount, n_state, n_head;
+ uint32_t n_audio_ctx, n_text_state, n_text_layer, n_text_ctx;
+ };
+
+ struct sDecodeParams
+ {
+ uint32_t n_state, n_head;
+ uint32_t n_ctx, n_past, M;
+ uint32_t n_text_layer;
+ uint32_t n_vocab;
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/sModelParams.h b/Whisper/Whisper/sModelParams.h
new file mode 100644
index 0000000..0419542
--- /dev/null
+++ b/Whisper/Whisper/sModelParams.h
@@ -0,0 +1,19 @@
+#pragma once
+namespace Whisper
+{
+ // default hparams (Whisper tiny)
+ struct sModelParams
+ {
+ int n_vocab = 51864;
+ int n_audio_ctx = 1500;
+ int n_audio_state = 384;
+ int n_audio_head = 6;
+ int n_audio_layer = 4;
+ int n_text_ctx = 448;
+ int n_text_state = 384;
+ int n_text_head = 6;
+ int n_text_layer = 4;
+ int n_mels = 80;
+ int f16 = 1;
+ };
+} \ No newline at end of file
diff --git a/Whisper/Whisper/sTokenData.h b/Whisper/Whisper/sTokenData.h
new file mode 100644
index 0000000..3a50a92
--- /dev/null
+++ b/Whisper/Whisper/sTokenData.h
@@ -0,0 +1,23 @@
+#pragma once
+#include <stdint.h>
+
+namespace Whisper
+{
+ using whisper_token = int;
+
+ struct sTokenData
+ {
+ whisper_token id; // token id
+ whisper_token tid; // forced timestamp token id
+ float p; // probability of the token
+ float pt; // probability of the timestamp token
+ float ptsum; // sum of probabilities of all timestamp tokens
+ float vlen; // voice length of the token
+ // token-level timestamp data
+ // do not use if you haven't computed token-level timestamps
+ int64_t t0; // start time of the token
+ int64_t t1; // end time of the token
+ };
+
+
+} \ No newline at end of file
diff --git a/Whisper/Whisper/voiceActivityDetection.cpp b/Whisper/Whisper/voiceActivityDetection.cpp
new file mode 100644
index 0000000..31763ec
--- /dev/null
+++ b/Whisper/Whisper/voiceActivityDetection.cpp
@@ -0,0 +1,199 @@
+#include "stdafx.h"
+#include "voiceActivityDetection.h"
+#include <DirectXMath.h>
+using namespace Whisper;
+
+// Initially ported (poorly) from there https://github.com/panmasuo/voice-activity-detection MIT license
+// The code is based on that article:
+// https://www.researchgate.net/publication/255667085_A_simple_but_efficient_real-time_voice_activity_detection_algorithm
+
+inline VAD::Feature VAD::defaultPrimaryThresholds()
+{
+ Feature f;
+ f.energy = 40;
+ f.F = 185;
+ f.SFM = 5;
+ return f;
+}
+
+VAD::VAD() :
+ primThresh( defaultPrimaryThresholds() )
+{
+ fft_signal = std::make_unique<cplx[]>( FFT_POINTS );
+}
+
+inline void VAD::fft( cplx* buf, cplx* out, size_t n, size_t step )
+{
+ if( step < n )
+ {
+ fft( out, buf, n, step * 2 );
+ fft( out + step, buf + step, n, step * 2 );
+
+ for( size_t i = 0; i < n; i += 2 * step )
+ {
+ // cplx t = cexp(-I * M_PI * i / n) * out[i + step];
+ // using namespace std::complex_literals;
+ const float mul = (float)M_PI * (float)(int)i / (float)(int)n;
+ // cplx t0 = std::exp( -1.0if * mul ) * out[ i + step ];
+ float sine, cosine;
+ DirectX::XMScalarSinCos( &sine, &cosine, -mul );
+ const cplx exponent{ cosine, sine };
+ cplx t = exponent * out[ i + step ];
+
+ buf[ i / 2 ] = out[ i ] + t;
+ buf[ ( i + n ) / 2 ] = out[ i ] - t;
+ }
+ }
+}
+
+void VAD::fft() const
+{
+ cplx out[ FFT_POINTS ];
+ memcpy( &out[ 0 ], fft_signal.get(), FFT_POINTS * sizeof( cplx ) );
+
+ fft( fft_signal.get(), out, FFT_POINTS, 1 );
+}
+
+constexpr float mulInt16FromFloat = 32768.0;
+
+float VAD::computeEnergy( const float* rsi )
+{
+ // calculate_energy
+ double sum = 0;
+ for( size_t i = 0; i < FFT_POINTS; i++ )
+ {
+ float f = rsi[ i ];
+ f *= mulInt16FromFloat;
+ f *= f;
+ sum += f;
+ }
+ return std::sqrtf( (float)( sum * ( 1.0 / FFT_POINTS ) ) );
+}
+
+float VAD::computeDominant( const cplx* spectrum )
+{
+ // calculate_dominant, reworked heavily
+ float maxMagSquared = 0;
+ int maxFreq = 0;
+
+ for( int i = 0; i < FFT_POINTS / 2; i++ )
+ {
+ const float real = (float)spectrum[ i ].real();
+ const float imag = (float)spectrum[ i ].imag();
+ float sq = real * real + imag * imag;
+ if( sq <= maxMagSquared )
+ continue;
+ maxMagSquared = sq;
+ maxFreq = i;
+ }
+ return (float)maxFreq * FFT_STEP;
+}
+
+float VAD::computreSpectralFlatnessMeasure( const cplx* spectrum )
+{
+ // calculate_sfm
+ double sum_ari = 0;
+ double sum_geo = 0;
+ for( size_t i = 0; i < FFT_POINTS; i++ )
+ {
+ // sig = cabsf( spectrum[ i ] );
+ float sig = std::abs( spectrum[ i ] );
+ sum_ari += sig;
+ sum_geo += std::log( sig );
+ }
+ sum_ari = sum_ari / FFT_POINTS;
+ sum_geo = std::exp( sum_geo / FFT_POINTS );
+ return -10.0f * std::log10f( (float)( sum_geo / sum_ari ) );
+}
+
+void VAD::clear()
+{
+ memset( &state, 0, sizeof( State ) );
+ state.currThresh = primThresh;
+}
+
+size_t VAD::detect( const float* rsi, size_t length )
+{
+ // The cryptic numbers in the comments are from section 3 "Proposed VAD Algorithm" of the article, on page 2550, on the right
+ const size_t frames = length / FFT_POINTS;
+ if( frames <= 0 )
+ {
+ clear();
+ return 0;
+ }
+
+ // Load detection state from the field
+ Feature currThresh = state.currThresh;
+ Feature minFeature = state.minFeature;
+ Feature curr = state.curr;
+
+ size_t lastSpeech = state.lastSpeech;
+ float silenceRun = state.silenceRun;
+ size_t i = state.i;
+
+ // Run the loop just on the [ state.i .. frames ] slice of the input PCM
+ rsi += i * FFT_POINTS;
+ for( ; i < frames; i++, rsi += FFT_POINTS )
+ {
+ // 3-2 calculate FFT
+ for( size_t j = 0; j < FFT_POINTS; j++ )
+ {
+ const float re = rsi[ j ] * mulInt16FromFloat;
+ fft_signal[ j ] = { re, 0.0f };
+ }
+ fft();
+
+ // 3-1 + 3-2 calculate features
+ curr.energy = computeEnergy( rsi );
+ curr.F = computeDominant( fft_signal.get() );
+ curr.SFM = computreSpectralFlatnessMeasure( fft_signal.get() );
+
+ // 3-3 calculate minimum value for first 30 frames
+ if( i == 0 )
+ minFeature = curr;
+ else if( i < 30 )
+ {
+ minFeature.energy = std::min( minFeature.energy, curr.energy );
+ minFeature.F = std::min( minFeature.F, curr.F );
+ minFeature.SFM = std::min( minFeature.SFM, curr.SFM );
+ }
+
+ // 3-4 set thresholds
+ currThresh.energy = primThresh.energy * std::log10f( minFeature.energy );
+
+ // 3-5 calculate decision
+ uint8_t counter = 0;
+ if( ( curr.energy - minFeature.energy ) >= currThresh.energy )
+ counter = 1;
+ if( ( curr.F - minFeature.F ) >= currThresh.F )
+ counter++;
+ if( ( curr.SFM - minFeature.SFM ) >= currThresh.SFM )
+ counter++;
+
+ if( counter > 1 )
+ {
+ // 3-6 If counter > 1 mark the current frame as speech
+ lastSpeech = ( i + 1 ) * FFT_POINTS;
+ silenceRun = 0.0f;
+ }
+ else
+ {
+ silenceRun += 1.0f;
+ // 3-7 If current frame is marked as silence, update the energy minimum value
+ minFeature.energy = ( ( silenceRun * minFeature.energy ) + curr.energy ) / ( silenceRun + 1 );
+ }
+
+ // 3-8
+ currThresh.energy = primThresh.energy * std::log10f( minFeature.energy );
+ }
+
+ // Store the updated detection state back into that field
+ state.currThresh = currThresh;
+ state.minFeature = minFeature;
+ state.curr = curr;
+ state.lastSpeech = (uint32_t)lastSpeech;
+ state.silenceRun = silenceRun;
+ state.i = (uint32_t)i;
+
+ return lastSpeech;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/voiceActivityDetection.h b/Whisper/Whisper/voiceActivityDetection.h
new file mode 100644
index 0000000..80952d2
--- /dev/null
+++ b/Whisper/Whisper/voiceActivityDetection.h
@@ -0,0 +1,54 @@
+#pragma once
+#include <complex>
+#include <memory>
+#include "audioConstants.h"
+
+namespace Whisper
+{
+ class VAD
+ {
+ using cplx = std::complex<float>;
+ std::unique_ptr<cplx[]> fft_signal;
+
+ struct Feature
+ {
+ float energy;
+ float F;
+ float SFM;
+ };
+ const Feature primThresh;
+ static Feature defaultPrimaryThresholds();
+
+ struct State
+ {
+ Feature currThresh;
+ Feature minFeature;
+ Feature curr;
+
+ uint32_t lastSpeech;
+ float silenceRun;
+ uint32_t i;
+ };
+ State state;
+
+ static inline void fft( cplx* buf, cplx* out, size_t n, size_t step );
+ void fft() const;
+
+ static float computeEnergy( const float* rsi );
+ static float computeDominant( const cplx* spectrum );
+ static float computreSpectralFlatnessMeasure( const cplx* spectrum );
+
+ public:
+
+ VAD();
+
+ // When no speech is detected, returns 0
+ // When speech is detected, returns sample position for the end of the speech
+ size_t detect( const float* rsi, size_t length );
+
+ void clear();
+
+ static constexpr uint32_t FFT_POINTS = 256;
+ static constexpr float FFT_STEP = (float)SAMPLE_RATE / (float)FFT_POINTS;
+ };
+} \ No newline at end of file
diff --git a/Whisper/misc.natvis b/Whisper/misc.natvis
new file mode 100644
index 0000000..23b22b3
--- /dev/null
+++ b/Whisper/misc.natvis
@@ -0,0 +1,50 @@
+<?xml version="1.0" encoding="utf-8"?>
+<AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010">
+ <!-- This file is only for Visual Studio debugger, it doesn't effect the binary in any way.
+ More info: https://learn.microsoft.com/en-us/visualstudio/debugger/create-custom-views-of-native-objects -->
+
+ <Type Name="__m128i">
+ <DisplayString>[ { m128i_i32[ 0 ] }, { m128i_i32[ 1 ] }, { m128i_i32[ 2 ] }, { m128i_i32[ 3 ] } ]</DisplayString>
+ </Type>
+ <Type Name="__m128">
+ <DisplayString>[ { m128_f32[ 0 ] }, { m128_f32[ 1 ] }, { m128_f32[ 2 ] }, { m128_f32[ 3 ] } ]</DisplayString>
+ </Type>
+ <Type Name="__m256">
+ <DisplayString>[ { m256_f32[ 0 ] }, { m256_f32[ 1 ] }, { m256_f32[ 2 ] }, { m256_f32[ 3 ] }, { m256_f32[ 4 ] }, { m256_f32[ 5 ] }, { m256_f32[ 6 ] }, { m256_f32[ 7 ] } ]</DisplayString>
+ </Type>
+ <Type Name="__m256i">
+ <DisplayString>[ { m256i_i32[ 0 ] }, { m256i_i32[ 1 ] }, { m256i_i32[ 2 ] }, { m256i_i32[ 3 ] }, { m256i_i32[ 4 ] }, { m256i_i32[ 5 ] }, { m256i_i32[ 6 ] }, { m256i_i32[ 7 ] } ]</DisplayString>
+ </Type>
+
+ <Type Name="std::array&lt;*,4&gt;">
+ <DisplayString>[ { _Elems[ 0 ] }, { _Elems[ 1 ] }, { _Elems[ 2 ] }, { _Elems[ 3 ] } ]</DisplayString>
+ </Type>
+ <Type Name="std::array&lt;*,3&gt;">
+ <DisplayString>[ { _Elems[ 0 ] }, { _Elems[ 1 ] }, { _Elems[ 2 ] } ]</DisplayString>
+ </Type>
+ <Type Name="std::array&lt;*,2&gt;">
+ <DisplayString>[ { _Elems[ 0 ] }, { _Elems[ 1 ] } ]</DisplayString>
+ </Type>
+ <Type Name="DirectCompute::Tensor">
+ <DisplayString>Size { ne }, strides { nb }</DisplayString>
+ </Type>
+ <Type Name="CpuCompute::Tensor">
+ <DisplayString Condition="m_type==DirectCompute::eDataType::FP16">FP16 { ne }, strides { nb }</DisplayString>
+ <DisplayString Condition="m_type==DirectCompute::eDataType::FP32">FP32 { ne }, strides { nb }</DisplayString>
+ <DisplayString Condition="m_type==DirectCompute::eDataType::U32">U32 { ne }, strides { nb }</DisplayString>
+ <Expand>
+ <ArrayItems Condition="m_type==DirectCompute::eDataType::FP16">
+ <Size>ne._Elems[ 0 ] * ne._Elems[ 1 ] * ne._Elems[ 2 ] * ne._Elems[ 3 ]</Size>
+ <ValuePointer>(uint16_t*)m_data</ValuePointer>
+ </ArrayItems>
+ <ArrayItems Condition="m_type==DirectCompute::eDataType::FP32">
+ <Size>ne._Elems[ 0 ] * ne._Elems[ 1 ] * ne._Elems[ 2 ] * ne._Elems[ 3 ]</Size>
+ <ValuePointer>(float*)m_data</ValuePointer>
+ </ArrayItems>
+ <ArrayItems Condition="m_type==DirectCompute::eDataType::U32">
+ <Size>ne._Elems[ 0 ] * ne._Elems[ 1 ] * ne._Elems[ 2 ] * ne._Elems[ 3 ]</Size>
+ <ValuePointer>(uint32_t*)m_data</ValuePointer>
+ </ArrayItems>
+ </Expand>
+ </Type>
+</AutoVisualizer> \ No newline at end of file
diff --git a/Whisper/modelFactory.cpp b/Whisper/modelFactory.cpp
new file mode 100644
index 0000000..a708551
--- /dev/null
+++ b/Whisper/modelFactory.cpp
@@ -0,0 +1,19 @@
+#include "stdafx.h"
+#include "modelFactory.h"
+#include "API/iContext.cl.h"
+
+HRESULT COMLIGHTCALL Whisper::loadModel( const wchar_t* path, eModelImplementation impl, const sLoadModelCallbacks* callbacks, iModel** pp )
+{
+ switch( impl )
+ {
+ case eModelImplementation::GPU:
+ return loadGpuModel( path, false, callbacks, pp );
+ case eModelImplementation::Hybrid:
+ return loadGpuModel( path, true, callbacks, pp );
+ case eModelImplementation::Reference:
+ return loadReferenceCpuModel( path, pp );
+ }
+
+ logError( u8"Unknown model implementation 0x%X", (int)impl );
+ return E_INVALIDARG;
+} \ No newline at end of file
diff --git a/Whisper/modelFactory.h b/Whisper/modelFactory.h
new file mode 100644
index 0000000..ebe77b1
--- /dev/null
+++ b/Whisper/modelFactory.h
@@ -0,0 +1,11 @@
+#pragma once
+#include "API/sLoadModelCallbacks.h"
+
+namespace Whisper
+{
+ struct iModel;
+
+ HRESULT __stdcall loadGpuModel( const wchar_t* path, bool hybrid, const sLoadModelCallbacks* callbacks, iModel** pp );
+
+ HRESULT __stdcall loadReferenceCpuModel( const wchar_t* path, iModel** pp );
+} \ No newline at end of file
diff --git a/Whisper/resource.h b/Whisper/resource.h
new file mode 100644
index 0000000..7ca31da
--- /dev/null
+++ b/Whisper/resource.h
@@ -0,0 +1,14 @@
+//{{NO_DEPENDENCIES}}
+// Microsoft Visual C++ generated include file.
+// Used by Resource.rc
+
+// Next default values for new objects
+//
+#ifdef APSTUDIO_INVOKED
+#ifndef APSTUDIO_READONLY_SYMBOLS
+#define _APS_NEXT_RESOURCE_VALUE 101
+#define _APS_NEXT_COMMAND_VALUE 40001
+#define _APS_NEXT_CONTROL_VALUE 1001
+#define _APS_NEXT_SYMED_VALUE 101
+#endif
+#endif
diff --git a/Whisper/source.compat/Readme.txt b/Whisper/source.compat/Readme.txt
new file mode 100644
index 0000000..affee91
--- /dev/null
+++ b/Whisper/source.compat/Readme.txt
@@ -0,0 +1 @@
+The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change BUILD_BOTH_VERSIONS macro in stdafx.h \ No newline at end of file
diff --git a/Whisper/source.compat/convertThings.cpp b/Whisper/source.compat/convertThings.cpp
new file mode 100644
index 0000000..0e6e8c2
--- /dev/null
+++ b/Whisper/source.compat/convertThings.cpp
@@ -0,0 +1,234 @@
+#include "stdafx.h"
+#if BUILD_BOTH_VERSIONS
+#include "../API/iContext.cl.h"
+#include "convertThings.h"
+using namespace Whisper;
+
+sFullParams makeNewParams( const whisper_full_params& wfp )
+{
+ assert( nullptr == wfp.encoder_begin_callback );
+ assert( nullptr == wfp.new_segment_callback );
+
+ sFullParams res;
+ memset( &res, 0, sizeof( res ) );
+
+ res.strategy = (eSamplingStrategy)wfp.strategy;
+ res.cpuThreads = wfp.n_threads;
+ res.n_max_text_ctx = wfp.n_max_text_ctx;
+ res.offset_ms = wfp.offset_ms;
+ res.duration_ms = wfp.duration_ms;
+
+ // flags
+ uint32_t flags = 0;
+ if( wfp.translate ) flags |= (uint32_t)eFullParamsFlags::Translate;
+ if( wfp.no_context ) flags |= (uint32_t)eFullParamsFlags::NoContext;
+ if( wfp.single_segment ) flags |= (uint32_t)eFullParamsFlags::SingleSegment;
+ if( wfp.print_special ) flags |= (uint32_t)eFullParamsFlags::PrintSpecial;
+ if( wfp.print_progress ) flags |= (uint32_t)eFullParamsFlags::PrintProgress;
+ if( wfp.print_realtime ) flags |= (uint32_t)eFullParamsFlags::PrintRealtime;
+ if( wfp.print_timestamps ) flags |= (uint32_t)eFullParamsFlags::PrintTimestamps;
+ if( wfp.token_timestamps ) flags |= (uint32_t)eFullParamsFlags::TokenTimestamps;
+ if( wfp.speed_up ) flags |= (uint32_t)eFullParamsFlags::SpeedupAudio;
+ res.flags = (eFullParamsFlags)flags;
+
+ res.language = findLanguageKeyA( wfp.language );
+ res.thold_pt = wfp.thold_pt;
+ res.thold_ptsum = wfp.thold_ptsum;
+ res.max_len = wfp.max_len;
+ res.greedy.n_past = wfp.greedy.n_past;
+ res.beam_search.n_past = wfp.beam_search.n_past;
+ res.beam_search.beam_width = wfp.beam_search.beam_width;
+ res.beam_search.n_best = wfp.beam_search.n_best;
+ res.audio_ctx = wfp.audio_ctx;
+ res.prompt_tokens = wfp.prompt_tokens;
+ res.prompt_n_tokens = wfp.prompt_n_tokens;
+
+ return res;
+}
+
+namespace
+{
+ class NewParamsTemp
+ {
+ char language[ 5 ];
+ iContext* newContext;
+ pfnNewSegment newSegment;
+ pfnEncoderBegin encoderBegin;
+
+ static bool encBegin( struct whisper_context* ctx, void* user_data );
+ static void newSeg( struct whisper_context* ctx, int n_new, void* user_data );
+
+ public:
+
+ void initialize( whisper_full_params& res, const Whisper::sFullParams& rsi, Whisper::iContext* context )
+ {
+ *(uint32_t*)( &language[ 0 ] ) = rsi.language;
+ language[ 4 ] = '\0';
+ res.language = language;
+
+ newContext = context;
+
+ if( nullptr != rsi.encoder_begin_callback )
+ {
+ encoderBegin = rsi.encoder_begin_callback;
+ res.encoder_begin_callback = &encBegin;
+ res.encoder_begin_callback_user_data = rsi.encoder_begin_callback_user_data;
+ }
+ else
+ {
+ encoderBegin = nullptr;
+ res.encoder_begin_callback = nullptr;
+ res.encoder_begin_callback_user_data = nullptr;
+ }
+
+ if( nullptr != rsi.new_segment_callback )
+ {
+ newSegment = rsi.new_segment_callback;
+ res.new_segment_callback = &newSeg;
+ res.new_segment_callback_user_data = rsi.new_segment_callback_user_data;
+ }
+ else
+ {
+ newSegment = nullptr;
+ res.new_segment_callback = nullptr;
+ res.new_segment_callback_user_data = nullptr;
+ }
+ }
+ };
+
+ static thread_local NewParamsTemp npTemp;
+
+ bool NewParamsTemp::encBegin( struct whisper_context* ctx, void* user_data )
+ {
+ const NewParamsTemp& tmp = npTemp;
+ HRESULT hr = tmp.encoderBegin( tmp.newContext, user_data );
+ if( SUCCEEDED( hr ) )
+ return S_OK == hr;
+ throw hr;
+ }
+
+ void NewParamsTemp::newSeg( struct whisper_context* ctx, int n_new, void* user_data )
+ {
+ assert( n_new >= 0 );
+ const NewParamsTemp& tmp = npTemp;
+ HRESULT hr = tmp.newSegment( tmp.newContext, (uint32_t)n_new, user_data );
+ if( SUCCEEDED( hr ) )
+ return;
+ throw hr;
+ }
+}
+
+whisper_full_params makeOldParams( const Whisper::sFullParams& rsi, Whisper::iContext* context )
+{
+ whisper_full_params res;
+ memset( &res, 0, sizeof( res ) );
+
+ res.strategy = (whisper_sampling_strategy)rsi.strategy;
+ res.n_threads = rsi.cpuThreads;
+ res.n_max_text_ctx = rsi.n_max_text_ctx;
+ res.offset_ms = rsi.offset_ms;
+ res.duration_ms = rsi.duration_ms;
+
+ // flags
+ const uint32_t flags = (uint32_t)rsi.flags;
+ auto hasFlag = [ = ]( eFullParamsFlags bit ) { return 0 != ( flags & (uint32_t)bit ); };
+
+ res.translate = hasFlag( eFullParamsFlags::Translate );
+ res.no_context = hasFlag( eFullParamsFlags::NoContext );
+ res.single_segment = hasFlag( eFullParamsFlags::SingleSegment );
+ res.print_special = hasFlag( eFullParamsFlags::PrintSpecial );
+ res.print_progress = hasFlag( eFullParamsFlags::PrintProgress );
+ res.print_realtime = hasFlag( eFullParamsFlags::PrintRealtime );
+ res.print_timestamps = hasFlag( eFullParamsFlags::PrintTimestamps );
+ res.token_timestamps = hasFlag( eFullParamsFlags::TokenTimestamps );
+ res.speed_up = hasFlag( eFullParamsFlags::SpeedupAudio );
+
+ res.thold_pt = rsi.thold_pt;
+ res.thold_ptsum = rsi.thold_ptsum;
+ res.max_len = rsi.max_len;
+ res.greedy.n_past = rsi.greedy.n_past;
+ res.beam_search.n_past = rsi.beam_search.n_past;
+ res.beam_search.beam_width = rsi.beam_search.beam_width;
+ res.beam_search.n_best = rsi.beam_search.n_best;
+ res.audio_ctx = rsi.audio_ctx;
+ res.prompt_tokens = rsi.prompt_tokens;
+ res.prompt_n_tokens = rsi.prompt_n_tokens;
+
+ NewParamsTemp& tmp = npTemp;
+ tmp.initialize( res, rsi, context );
+ return res;
+}
+
+#include "../Whisper/TranscribeResult.h"
+#include <mfapi.h>
+
+namespace
+{
+ inline sTimeSpan time( int64_t wt )
+ {
+ int64_t ticks = MFllMulDiv( wt, 10'000'000, 100, 0 );
+ return sTimeSpan{ (uint64_t)ticks };
+ }
+
+ void makeNewResults( whisper_context* ctx, Whisper::eResultFlags flags, TranscribeResult& res )
+ {
+ const bool makeTokens = 0 != ( flags & eResultFlags::Tokens );
+ res.segments.clear();
+ res.tokens.clear();
+
+ const int countSegments = whisper_full_n_segments( ctx );
+ res.segments.resize( countSegments );
+ const int tokenEot = whisper_token_eot( ctx );
+ for( int i = 0; i < countSegments; i++ )
+ {
+ sSegment& seg = res.segments[ i ];
+ seg.text = whisper_full_get_segment_text( ctx, i );
+ seg.time.begin = time( whisper_full_get_segment_t0( ctx, i ) );
+ seg.time.end = time( whisper_full_get_segment_t1( ctx, i ) );
+
+ seg.firstToken = (uint32_t)res.tokens.size();
+ seg.countTokens = 0;
+ if( !makeTokens )
+ continue;
+
+ const int countTokens = whisper_full_n_tokens( ctx, i );
+ seg.countTokens = countTokens;
+ res.tokens.resize( res.tokens.size() + countTokens );
+ for( int t = 0; t < countTokens; t++ )
+ {
+ sToken& tok = res.tokens[ seg.firstToken + t ];
+ tok.text = whisper_full_get_token_text( ctx, i, t );
+
+ const whisper_token_data src = whisper_full_get_token_data( ctx, i, t );
+ tok.time.begin = time( src.t0 );
+ tok.time.end = time( src.t1 );
+ tok.probability = src.p;
+ tok.probabilityTimestamp = src.pt;
+ tok.ptsum = src.ptsum;
+ tok.vlen = src.vlen;
+ tok.id = src.id;
+ uint32_t flags = 0;
+ if( src.id >= tokenEot )
+ flags |= eTokenFlags::Special;
+ tok.flags = (eTokenFlags)flags;
+ }
+ }
+ }
+}
+
+HRESULT makeNewResults( whisper_context* ctx, Whisper::eResultFlags flags, Whisper::iTranscribeResult** pp )
+{
+ static TranscribeResultStatic trs;
+ if( flags & eResultFlags::NewObject )
+ {
+ return E_NOTIMPL;
+ }
+ else
+ {
+ makeNewResults( ctx, flags, trs );
+ *pp = &trs;
+ ( *pp )->AddRef();
+ return S_OK;
+ }
+}
+#endif \ No newline at end of file
diff --git a/Whisper/source.compat/convertThings.h b/Whisper/source.compat/convertThings.h
new file mode 100644
index 0000000..5750734
--- /dev/null
+++ b/Whisper/source.compat/convertThings.h
@@ -0,0 +1,10 @@
+#pragma once
+#include "../source/whisper.h"
+#include "../API/sFullParams.h"
+#include "../API/iTranscribeResult.cl.h"
+
+Whisper::sFullParams makeNewParams( const whisper_full_params& rsi );
+
+whisper_full_params makeOldParams( const Whisper::sFullParams& rsi, Whisper::iContext* context );
+
+HRESULT makeNewResults( whisper_context* ctx, Whisper::eResultFlags flags, Whisper::iTranscribeResult** pp ); \ No newline at end of file
diff --git a/Whisper/source.compat/ggmlMsvc.c b/Whisper/source.compat/ggmlMsvc.c
new file mode 100644
index 0000000..5a9f340
--- /dev/null
+++ b/Whisper/source.compat/ggmlMsvc.c
@@ -0,0 +1,37 @@
+#include <stdint.h>
+#include <immintrin.h>
+#include <stdio.h>
+#include <assert.h>
+#include "../source/ggml.h"
+
+__forceinline float _cvtsh_ss( uint16_t f16 )
+{
+ __m128i i = _mm_cvtsi32_si128( f16 );
+ __m128 f = _mm_cvtph_ps( i );
+ return _mm_cvtss_f32( f );
+}
+
+__forceinline uint16_t _cvtss_sh( float f, int rounding )
+{
+ assert( 0 == rounding );
+ __m128 v = _mm_set_ss( f );
+ __m128i i = _mm_cvtps_ph( v, 0 );
+ return (uint16_t)(uint32_t)_mm_cvtsi128_si32( i );
+}
+
+FILE* fopen_msvc( const char* filename, const char* mode )
+{
+ FILE* stream;
+ errno_t err = fopen_s( &stream, filename, mode );
+ if( err == 0 )
+ return stream;
+ return NULL;
+}
+
+#define fopen fopen_msvc
+
+#include "../ML/testUtilsC.h"
+
+#define __F16C__
+#define __FMA__
+#include "../source/ggml.c" \ No newline at end of file
diff --git a/Whisper/source/LICENSE b/Whisper/source/LICENSE
new file mode 100644
index 0000000..fb7ff0c
--- /dev/null
+++ b/Whisper/source/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Georgi Gerganov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Whisper/source/Readme.txt b/Whisper/source/Readme.txt
new file mode 100644
index 0000000..affee91
--- /dev/null
+++ b/Whisper/source/Readme.txt
@@ -0,0 +1 @@
+The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change BUILD_BOTH_VERSIONS macro in stdafx.h \ No newline at end of file
diff --git a/Whisper/source/ggml.c b/Whisper/source/ggml.c
new file mode 100644
index 0000000..c318b20
--- /dev/null
+++ b/Whisper/source/ggml.c
@@ -0,0 +1,8336 @@
+#include "ggml.h"
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <malloc.h> // using malloc.h with MSC/MINGW
+#elif !defined(__FreeBSD__)
+#include <alloca.h>
+#endif
+
+#include <assert.h>
+#include <time.h>
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+#include <stdio.h>
+
+// if C99 - static_assert is noop
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef static_assert
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+
+#if defined _MSC_VER || defined(__MINGW32__)
+
+#if !defined(__MINGW32__)
+#include <Windows.h>
+#else
+// ref: https://github.com/ggerganov/whisper.cpp/issues/168
+#include <windows.h>
+#include <errno.h>
+#endif
+
+typedef volatile LONG atomic_int;
+typedef atomic_int atomic_bool;
+
+static void atomic_store(atomic_int* ptr, LONG val) {
+ InterlockedExchange(ptr, val);
+}
+static LONG atomic_load(atomic_int* ptr) {
+ return InterlockedCompareExchange(ptr, 0, 0);
+}
+static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
+ return InterlockedExchangeAdd(ptr, inc);
+}
+static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
+ return atomic_fetch_add(ptr, -(dec));
+}
+
+typedef HANDLE pthread_t;
+
+typedef DWORD thread_ret_t;
+static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+ HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
+ if (handle == NULL)
+ {
+ return EAGAIN;
+ }
+
+ *out = handle;
+ return 0;
+}
+
+static int pthread_join(pthread_t thread, void* unused) {
+ return (int) WaitForSingleObject(thread, INFINITE);
+}
+
+static int sched_yield (void) {
+ Sleep (0);
+ return 0;
+}
+#else
+#include <pthread.h>
+#include <stdatomic.h>
+
+typedef void* thread_ret_t;
+#endif
+
+#ifdef __HAIKU__
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#endif
+
+#define GGML_DEBUG 0
+#define GGML_GELU_FP16
+
+#if UINTPTR_MAX == 0xFFFFFFFF
+ #define GGML_MEM_ALIGN 4
+#else
+ #define GGML_MEM_ALIGN 16
+#endif
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+
+#define UNUSED(x) (void)(x)
+#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
+
+#define GGML_ASSERT(x) \
+ do { \
+ if (!(x)) { \
+ logError( u8"GGML_ASSERT: %s:%d: %s", __FILE__, __LINE__, #x); \
+ abort(); \
+ } \
+ } while (0)
+
+#ifdef GGML_USE_ACCELERATE
+#include <Accelerate/Accelerate.h>
+#elif GGML_USE_OPENBLAS
+#include <cblas.h>
+#endif
+
+// floating point type used to accumulate sums
+typedef double ggml_float;
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#ifdef __ARM_NEON
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+float ggml_fp16_to_fp32(ggml_fp16_t x) {
+ return x;
+}
+
+ggml_fp16_t ggml_fp32_to_fp16(float x) {
+ return x;
+}
+
+#define GGML_FP16_TO_FP32(x) (x)
+#define GGML_FP32_TO_FP16(x) (x)
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#include <immintrin.h>
+#endif
+#endif
+
+#ifdef __F16C__
+float ggml_fp16_to_fp32(ggml_fp16_t h) {
+ return _cvtsh_ss(h);
+}
+ggml_fp16_t ggml_fp32_to_fp16(float f) {
+ return _cvtss_sh(f, 0);
+}
+
+#define GGML_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+
+#else
+
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
+static inline float fp32_from_bits(uint32_t w) {
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } fp32;
+ fp32.as_bits = w;
+ return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+ union {
+ float as_value;
+ uint32_t as_bits;
+ } fp32;
+ fp32.as_value = f;
+ return fp32.as_bits;
+}
+
+float ggml_fp16_to_fp32(ggml_fp16_t h) {
+ const uint32_t w = (uint32_t) h << 16;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ const uint32_t two_w = w + w;
+
+ const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+ const float exp_scale = 0x1.0p-112f;
+#else
+ const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+ const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+ const uint32_t magic_mask = UINT32_C(126) << 23;
+ const float magic_bias = 0.5f;
+ const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+ const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+ const uint32_t result = sign |
+ (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+ return fp32_from_bits(result);
+}
+
+ggml_fp16_t ggml_fp32_to_fp16(float f) {
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+ const float scale_to_inf = 0x1.0p+112f;
+ const float scale_to_zero = 0x1.0p-110f;
+#else
+ const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+ const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+ float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+ const uint32_t w = fp32_to_bits(f);
+ const uint32_t shl1_w = w + w;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+ if (bias < UINT32_C(0x71000000)) {
+ bias = UINT32_C(0x71000000);
+ }
+
+ base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+ const uint32_t bits = fp32_to_bits(base);
+ const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+ const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+ const uint32_t nonsign = exp_bits + mantissa_bits;
+ return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+}
+
+#define GGML_FP16_TO_FP32(x) ggml_fp16_to_fp32(x)
+#define GGML_FP32_TO_FP16(x) ggml_fp32_to_fp16(x)
+
+#endif // __F16C__
+
+#endif // __ARM_NEON
+
+//
+// global data
+//
+
+// precomputed gelu table for f16 (128 KB)
+static ggml_fp16_t table_gelu_f16[1 << 16];
+
+// precomputed exp table for f16 (128 KB)
+static ggml_fp16_t table_exp_f16[1 << 16];
+
+//
+// timing
+//
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+static int64_t timer_freq;
+void ggml_time_init(void) {
+ LARGE_INTEGER frequency;
+ QueryPerformanceFrequency(&frequency);
+ timer_freq = frequency.QuadPart;
+}
+int64_t ggml_time_ms(void) {
+ LARGE_INTEGER t;
+ QueryPerformanceCounter(&t);
+ return (t.QuadPart * 1000) / timer_freq;
+}
+int64_t ggml_time_us(void) {
+ LARGE_INTEGER t;
+ QueryPerformanceCounter(&t);
+ return (t.QuadPart * 1000000) / timer_freq;
+}
+#else
+void ggml_time_init(void) {}
+int64_t ggml_time_ms(void) {
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;
+}
+
+int64_t ggml_time_us(void) {
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
+}
+#endif
+
+int64_t ggml_cycles(void) {
+ return clock();
+}
+
+int64_t ggml_cycles_per_ms(void) {
+ return CLOCKS_PER_SEC/1000;
+}
+
+#ifdef GGML_PERF
+#define ggml_perf_time_ms() ggml_time_ms()
+#define ggml_perf_time_us() ggml_time_us()
+#define ggml_perf_cycles() ggml_cycles()
+#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms()
+#else
+#define ggml_perf_time_ms() 0
+#define ggml_perf_time_us() 0
+#define ggml_perf_cycles() 0
+#define ggml_perf_cycles_per_ms() 0
+#endif
+
+//
+// cache line
+//
+
+#if defined(__cpp_lib_hardware_interference_size)
+#define CACHE_LINE_SIZE hardware_destructive_interference_size
+#else
+#define CACHE_LINE_SIZE 64
+#endif
+
+static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
+
+//
+// simd mappings
+//
+
+// we define a common set of C macros which map to specific intrinsics based on the current architecture
+// we then implement the fundamental computation operations below using only these macros
+// adding support for new architectures requires to define the corresponding SIMD macros
+//
+// GGML_F32_STEP / GGML_F16_STEP
+// number of elements to process in a single step
+//
+// GGML_F32_EPR / GGML_F16_EPR
+// number of elements to fit in a single register
+//
+
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
+
+#define GGML_SIMD
+
+// F32 NEON
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 float32x4_t
+#define GGML_F32x4_ZERO vdupq_n_f32(0.0f)
+#define GGML_F32x4_SET1(x) vdupq_n_f32(x)
+#define GGML_F32x4_LOAD vld1q_f32
+#define GGML_F32x4_STORE vst1q_f32
+#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
+#define GGML_F32x4_ADD vaddq_f32
+#define GGML_F32x4_MUL vmulq_f32
+#if defined(__ARM_FEATURE_QRDMX)
+ #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
+#else
+ #define GGML_F32x4_REDUCE_ONE(x) \
+ (vgetq_lane_f32(x, 0) + \
+ vgetq_lane_f32(x, 1) + \
+ vgetq_lane_f32(x, 2) + \
+ vgetq_lane_f32(x, 3))
+#endif
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+ for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
+ x[2*i] = vaddq_f32(x[2*i], x[2*i+1]); \
+ } \
+ for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
+ x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \
+ } \
+ for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
+ x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \
+ } \
+ res = GGML_F32x4_REDUCE_ONE(x[0]); \
+}
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+ #define GGML_F16_STEP 32
+ #define GGML_F16_EPR 8
+
+ #define GGML_F16x8 float16x8_t
+ #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
+ #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
+ #define GGML_F16x8_LOAD vld1q_f16
+ #define GGML_F16x8_STORE vst1q_f16
+ #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
+ #define GGML_F16x8_ADD vaddq_f16
+ #define GGML_F16x8_MUL vmulq_f16
+ #define GGML_F16x8_REDUCE(res, x) \
+ { \
+ for (int i = 0; i < GGML_F16_ARR/2; ++i) { \
+ x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \
+ } \
+ for (int i = 0; i < GGML_F16_ARR/4; ++i) { \
+ x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \
+ } \
+ for (int i = 0; i < GGML_F16_ARR/8; ++i) { \
+ x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \
+ } \
+ const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
+ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
+ res = vaddvq_f32(vaddq_f32(t0, t1)); \
+ }
+
+ #define GGML_F16_VEC GGML_F16x8
+ #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
+ #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
+ #define GGML_F16_VEC_LOAD GGML_F16x8_LOAD
+ #define GGML_F16_VEC_STORE GGML_F16x8_STORE
+ #define GGML_F16_VEC_FMA GGML_F16x8_FMA
+ #define GGML_F16_VEC_ADD GGML_F16x8_ADD
+ #define GGML_F16_VEC_MUL GGML_F16x8_MUL
+ #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
+#else
+ // if FP16 vector arithmetic is not supported, we use FP32 instead
+ // and take advantage of the vcvt_ functions to convert to/from FP16
+
+ #define GGML_F16_STEP 16
+ #define GGML_F16_EPR 4
+
+ #define GGML_F32Cx4 float32x4_t
+ #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
+ #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
+ #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x))
+ #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
+ #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
+ #define GGML_F32Cx4_ADD vaddq_f32
+ #define GGML_F32Cx4_MUL vmulq_f32
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
+
+ #define GGML_F16_VEC GGML_F32Cx4
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
+ #define GGML_F16_VEC_LOAD GGML_F32Cx4_LOAD
+ #define GGML_F16_VEC_STORE GGML_F32Cx4_STORE
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
+#endif
+
+#elif defined(__AVX__)
+
+#define GGML_SIMD
+
+// F32 AVX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR 8
+
+#define GGML_F32x8 __m256
+#define GGML_F32x8_ZERO _mm256_setzero_ps()
+#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
+#define GGML_F32x8_LOAD _mm256_loadu_ps
+#define GGML_F32x8_STORE _mm256_storeu_ps
+#if defined(__FMA__)
+ #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
+#else
+ #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
+#endif
+#define GGML_F32x8_ADD _mm256_add_ps
+#define GGML_F32x8_MUL _mm256_mul_ps
+#define GGML_F32x8_REDUCE(res, x) \
+{ \
+ for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
+ x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \
+ } \
+ for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
+ x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \
+ } \
+ for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
+ x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \
+ } \
+ const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
+ _mm256_extractf128_ps(x[0], 1)); \
+ const __m128 t1 = _mm_hadd_ps(t0, t0); \
+ res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
+}
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC GGML_F32x8
+#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 AVX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR 8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32
+
+#define GGML_F32Cx8 __m256
+#define GGML_F32Cx8_ZERO _mm256_setzero_ps()
+#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x)
+#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
+#define GGML_F32Cx8_FMA GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD _mm256_add_ps
+#define GGML_F32Cx8_MUL _mm256_mul_ps
+#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC GGML_F32Cx8
+#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD GGML_F32Cx8_LOAD
+#define GGML_F16_VEC_STORE GGML_F32Cx8_STORE
+#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
+
+#elif defined(__POWER9_VECTOR__)
+
+// TODO: uncomment this when it works
+//#define GGML_SIMD
+
+// F32 POWER9
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR 8
+
+// TODO: not tested !!
+#define GGML_F32x4 __vector float
+#define GGML_F32x4_ZERO (__vector float){0.0f, 0.0f, 0.0f, 0.0f}
+#define GGML_F32x4_SET1(x) (__vector float){x, x, x, x}
+#define GGML_F32x4_LOAD vec_vsx_ld
+#define GGML_F32x4_STORE vec_vsx_st
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD vec_add
+#define GGML_F32x4_MUL vec_mul
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+ for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
+ x[2*i] = vec_add(x[2*i], x[2*i+1]); \
+ } \
+ for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
+ x[4*i] = vec_add(x[4*i], x[4*i+2]); \
+ } \
+ for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
+ x[8*i] = vec_add(x[8*i], x[8*i+4]); \
+ } \
+ res = vec_extract(x[0], 0) + \
+ vec_extract(x[0], 1) + \
+ vec_extract(x[0], 2) + \
+ vec_extract(x[0], 3); \
+}
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 POWER9
+// TODO: implement here
+// ...
+
+#elif defined(__wasm_simd128__)
+
+#define GGML_SIMD
+
+// F32 WASM
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 v128_t
+#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f)
+#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x)
+#define GGML_F32x4_LOAD wasm_v128_load
+#define GGML_F32x4_STORE wasm_v128_store
+#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
+#define GGML_F32x4_ADD wasm_f32x4_add
+#define GGML_F32x4_MUL wasm_f32x4_mul
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+ for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
+ x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+ } \
+ for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
+ x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+ } \
+ for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
+ x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+ } \
+ res = wasm_f32x4_extract_lane(x[0], 0) + \
+ wasm_f32x4_extract_lane(x[0], 1) + \
+ wasm_f32x4_extract_lane(x[0], 2) + \
+ wasm_f32x4_extract_lane(x[0], 3); \
+}
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 WASM
+
+#define GGML_F16_STEP 16
+#define GGML_F16_EPR 4
+
+inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
+ float tmp[4];
+
+ tmp[0] = GGML_FP16_TO_FP32(p[0]);
+ tmp[1] = GGML_FP16_TO_FP32(p[1]);
+ tmp[2] = GGML_FP16_TO_FP32(p[2]);
+ tmp[3] = GGML_FP16_TO_FP32(p[3]);
+
+ return wasm_v128_load(tmp);
+}
+
+inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
+ float tmp[4];
+
+ wasm_v128_store(tmp, x);
+
+ p[0] = GGML_FP32_TO_FP16(tmp[0]);
+ p[1] = GGML_FP32_TO_FP16(tmp[1]);
+ p[2] = GGML_FP32_TO_FP16(tmp[2]);
+ p[3] = GGML_FP32_TO_FP16(tmp[3]);
+}
+
+#define GGML_F16x4 v128_t
+#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f)
+#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x)
+#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x)
+#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
+#define GGML_F16x4_FMA GGML_F32x4_FMA
+#define GGML_F16x4_ADD wasm_f32x4_add
+#define GGML_F16x4_MUL wasm_f32x4_mul
+#define GGML_F16x4_REDUCE(res, x) \
+{ \
+ for (int i = 0; i < GGML_F16_ARR/2; ++i) { \
+ x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+ } \
+ for (int i = 0; i < GGML_F16_ARR/4; ++i) { \
+ x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+ } \
+ for (int i = 0; i < GGML_F16_ARR/8; ++i) { \
+ x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+ } \
+ res = wasm_f32x4_extract_lane(x[0], 0) + \
+ wasm_f32x4_extract_lane(x[0], 1) + \
+ wasm_f32x4_extract_lane(x[0], 2) + \
+ wasm_f32x4_extract_lane(x[0], 3); \
+}
+
+#define GGML_F16_VEC GGML_F16x4
+#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO
+#define GGML_F16_VEC_SET1 GGML_F16x4_SET1
+#define GGML_F16_VEC_LOAD GGML_F16x4_LOAD
+#define GGML_F16_VEC_STORE GGML_F16x4_STORE
+#define GGML_F16_VEC_FMA GGML_F16x4_FMA
+#define GGML_F16_VEC_ADD GGML_F16x4_ADD
+#define GGML_F16_VEC_MUL GGML_F16x4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
+
+#endif
+
+// GGML_F32_ARR / GGML_F16_ARR
+// number of registers to use per step
+#ifdef GGML_SIMD
+#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
+#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
+#endif
+
+//
+// fundamental operations
+//
+
+inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
+inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
+
+inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+ ggml_float sumf = 0.0;
+
+#ifdef GGML_SIMD
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+ }
+ }
+
+ // reduce sum0..sum3 to sum0
+ GGML_F32_VEC_REDUCE(sumf, sum);
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ sumf += x[i]*y[i];
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ sumf += x[i]*y[i];
+ }
+#endif
+
+ *s = sumf;
+}
+
+inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+ ggml_float sumf = 0.0;
+
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
+
+ GGML_F16_VEC ax[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR);
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR);
+
+ sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+ }
+ }
+
+ // reduce sum0..sum3 to sum0
+ GGML_F16_VEC_REDUCE(sumf, sum);
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+ }
+#elif defined(__POWER9_VECTOR__)
+ // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without
+ // being able to test it. hoping someone with access to a POWER9 machine can help out here.
+ const int n32 = (n & ~31);
+
+ vector float sum0 = vec_splats (0.0f);
+
+ for (int i = 0; i < n32; i += 32) {
+ // Use vec_xl, not vec_ld, because x is sometimes unaligned.
+ vector unsigned short x0 = vec_xl(i * 2 + 0, x);
+ vector unsigned short x1 = vec_xl(i * 2 + 16, x);
+ vector unsigned short x2 = vec_xl(i * 2 + 32, x);
+ vector unsigned short x3 = vec_xl(i * 2 + 48, x);
+
+ vector unsigned short y0 = vec_xl(i * 2 + 0, y);
+ vector unsigned short y1 = vec_xl(i * 2 + 16, y);
+ vector unsigned short y2 = vec_xl(i * 2 + 32, y);
+ vector unsigned short y3 = vec_xl(i * 2 + 48, y);
+
+ vector float fx0l = vec_extract_fp32_from_shortl(x0);
+ vector float fx0h = vec_extract_fp32_from_shorth(x0);
+ vector float fx1l = vec_extract_fp32_from_shortl(x1);
+ vector float fx1h = vec_extract_fp32_from_shorth(x1);
+ vector float fx2l = vec_extract_fp32_from_shortl(x2);
+ vector float fx2h = vec_extract_fp32_from_shorth(x2);
+ vector float fx3l = vec_extract_fp32_from_shortl(x3);
+ vector float fx3h = vec_extract_fp32_from_shorth(x3);
+
+ vector float fy0l = vec_extract_fp32_from_shortl(y0);
+ vector float fy0h = vec_extract_fp32_from_shorth(y0);
+ vector float fy1l = vec_extract_fp32_from_shortl(y1);
+ vector float fy1h = vec_extract_fp32_from_shorth(y1);
+ vector float fy2l = vec_extract_fp32_from_shortl(y2);
+ vector float fy2h = vec_extract_fp32_from_shorth(y2);
+ vector float fy3l = vec_extract_fp32_from_shortl(y3);
+ vector float fy3h = vec_extract_fp32_from_shorth(y3);
+
+ sum0 = vec_add(sum0, vec_mul(fx0l, fy0l));
+ sum0 = vec_add(sum0, vec_mul(fx0h, fy0h));
+ sum0 = vec_add(sum0, vec_mul(fx1l, fy1l));
+ sum0 = vec_add(sum0, vec_mul(fx1h, fy1h));
+ sum0 = vec_add(sum0, vec_mul(fx2l, fy2l));
+ sum0 = vec_add(sum0, vec_mul(fx2h, fy2h));
+ sum0 = vec_add(sum0, vec_mul(fx3l, fy3l));
+ sum0 = vec_add(sum0, vec_mul(fx3h, fy3h));
+ }
+
+ sumf = vec_extract(sum0, 0) + vec_extract(sum0, 1)
+ + vec_extract(sum0, 2) + vec_extract(sum0, 3);
+
+ for (int i = n32; i < n; ++i) {
+ sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+ }
+#else
+ for (int i = 0; i < n; ++i) {
+ sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+ }
+#endif
+
+ *s = sumf;
+}
+
+inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] += x[i]*v;
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] += x[i]*v;
+ }
+#endif
+}
+
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+ GGML_F16_VEC ax[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR);
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR);
+ ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ GGML_ASSERT(false);
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+ }
+#elif defined(__POWER9_VECTOR__)
+ // TODO: this is temporary because I cannot fit it in the GGML_SIMD pattern like all other architectures without
+ // being able to test it. hoping someone with access to a POWER9 machine can help out here.
+ const int n32 = (n & ~31);
+ for (int i = 0; i < n32; i += 32) {
+ // Use vec_xl, not vec_ld, because x is sometimes unaligned!
+ vector unsigned short x0 = vec_xl(i * 2 + 0, x);
+ vector unsigned short x1 = vec_xl(i * 2 + 16, x);
+ vector unsigned short x2 = vec_xl(i * 2 + 32, x);
+ vector unsigned short x3 = vec_xl(i * 2 + 48, x);
+
+ vector unsigned short y0 = vec_xl(i * 2 + 0, y);
+ vector unsigned short y1 = vec_xl(i * 2 + 16, y);
+ vector unsigned short y2 = vec_xl(i * 2 + 32, y);
+ vector unsigned short y3 = vec_xl(i * 2 + 48, y);
+
+ vector float v4 = vec_splats(v);
+
+ vector float fx0l = vec_extract_fp32_from_shortl(x0);
+ vector float fx0h = vec_extract_fp32_from_shorth(x0);
+ vector float fx1l = vec_extract_fp32_from_shortl(x1);
+ vector float fx1h = vec_extract_fp32_from_shorth(x1);
+ vector float fx2l = vec_extract_fp32_from_shortl(x2);
+ vector float fx2h = vec_extract_fp32_from_shorth(x2);
+ vector float fx3l = vec_extract_fp32_from_shortl(x3);
+ vector float fx3h = vec_extract_fp32_from_shorth(x3);
+
+ vector float fy0l = vec_extract_fp32_from_shortl(y0);
+ vector float fy0h = vec_extract_fp32_from_shorth(y0);
+ vector float fy1l = vec_extract_fp32_from_shortl(y1);
+ vector float fy1h = vec_extract_fp32_from_shorth(y1);
+ vector float fy2l = vec_extract_fp32_from_shortl(y2);
+ vector float fy2h = vec_extract_fp32_from_shorth(y2);
+ vector float fy3l = vec_extract_fp32_from_shortl(y3);
+ vector float fy3h = vec_extract_fp32_from_shorth(y3);
+
+ fy0l = vec_madd(fx0l, v4, fy0l);
+ fy0h = vec_madd(fx0h, v4, fy0h);
+ fy1l = vec_madd(fx1l, v4, fy1l);
+ fy1h = vec_madd(fx1h, v4, fy1h);
+ fy2l = vec_madd(fx2l, v4, fy2l);
+ fy2h = vec_madd(fx2h, v4, fy2h);
+ fy3l = vec_madd(fx3l, v4, fy3l);
+ fy3h = vec_madd(fx3h, v4, fy3h);
+
+ y0 = vec_pack_to_short_fp32(fy0h, fy0l);
+ y1 = vec_pack_to_short_fp32(fy1h, fy1l);
+ y2 = vec_pack_to_short_fp32(fy2h, fy2l);
+ y3 = vec_pack_to_short_fp32(fy3h, fy3l);
+
+ vec_xst(y0, i * 2 + 0, y);
+ vec_xst(y1, i * 2 + 16, y);
+ vec_xst(y2, i * 2 + 32, y);
+ vec_xst(y3, i * 2 + 48, y);
+ }
+
+ for (int i = n32; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+ }
+#else
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+ }
+#endif
+}
+
+//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
+inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
+
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] *= v;
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] *= v;
+ }
+#endif
+}
+
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
+inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
+inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
+inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
+inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
+inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+
+static const ggml_float GELU_COEF_A = 0.044715;
+static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+
+inline static float ggml_gelu_f32(float x) {
+ return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
+}
+
+inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+ const uint16_t * i16 = (const uint16_t *) x;
+ for (int i = 0; i < n; ++i) {
+ y[i] = table_gelu_f16[i16[i]];
+ }
+}
+
+#ifdef GGML_GELU_FP16
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+ uint16_t t;
+ for (int i = 0; i < n; ++i) {
+ ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+ memcpy(&t, &fp16, sizeof(uint16_t));
+ y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]);
+ }
+}
+#else
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+ for (int i = 0; i < n; ++i) {
+ y[i] = ggml_gelu_f32(x[i]);
+ }
+}
+#endif
+
+inline static void ggml_vec_sum_f32 (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; }
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
+
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) logDebug( __VA_ARGS__ )
+
+//
+// data types
+//
+
+static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
+ sizeof(int8_t ),
+ sizeof(int16_t),
+ sizeof(int32_t),
+ sizeof(ggml_fp16_t),
+ sizeof(float ),
+};
+
+static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
+ "NONE",
+
+ "DUP",
+ "ADD",
+ "SUB",
+ "MUL",
+ "DIV",
+ "SQR",
+ "SQRT",
+ "SUM",
+ "MEAN",
+ "REPEAT",
+ "ABS",
+ "SGN",
+ "NEG",
+ "STEP",
+ "RELU",
+ "GELU",
+ "NORM",
+
+ "MUL_MAT",
+
+ "SCALE",
+ "CPY",
+ "RESHAPE",
+ "VIEW",
+ "PERMUTE",
+ "TRANSPOSE",
+ "GET_ROWS",
+ "DIAG_MASK_INF",
+ "SOFT_MAX",
+ "ROPE",
+ "CONV_1D_1S",
+ "CONV_1D_2S",
+
+ "FLASH_ATTN",
+ "FLASH_FF",
+};
+
+static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+ "none",
+
+ "x",
+ "x+y",
+ "x-y",
+ "x*y",
+ "x/y",
+ "x^2",
+ "√x",
+ "Σx",
+ "Σx/n",
+ "repeat(x)",
+ "abs(x)",
+ "sgn(x)",
+ "-x",
+ "step(x)",
+ "relu(x)",
+ "gelu(x)",
+ "norm(x)",
+
+ "X*Y",
+
+ "x*v",
+ "x-\\>y",
+ "reshape(x)",
+ "view(x)",
+ "permute(x)",
+ "transpose(x)",
+ "get_rows(x)",
+ "diag_mask_inf(x)",
+ "soft_max(x)",
+ "rope(x)",
+ "conv_1d_1s(x)",
+ "conv_1d_2s(x)",
+
+ "flash_attn(x)",
+ "flash_ff(x)",
+};
+
+//
+// ggml object
+//
+
+struct ggml_object {
+ size_t offset;
+ size_t size;
+
+ struct ggml_object * next;
+
+ char padding[8];
+};
+
+static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
+static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
+
+//
+// ggml context
+//
+
+struct ggml_context {
+ size_t mem_size;
+ void * mem_buffer;
+ bool mem_buffer_owned;
+
+ int n_objects;
+
+ struct ggml_object * objects_begin;
+ struct ggml_object * objects_end;
+};
+
+struct ggml_context_container {
+ bool used;
+
+ struct ggml_context context;
+};
+
+//
+// compute types
+//
+
+enum ggml_task_type {
+ GGML_TASK_INIT = 0,
+ GGML_TASK_COMPUTE,
+ GGML_TASK_FINALIZE,
+};
+
+struct ggml_compute_params {
+ enum ggml_task_type type;
+
+ int ith, nth;
+
+ // work buffer for all threads
+ size_t wsize;
+ void * wdata;
+};
+
+//
+// ggml state
+//
+
+struct ggml_state {
+ struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
+};
+
+// global state
+static struct ggml_state g_state;
+static atomic_int g_state_barrier = 0;
+
+// barrier via spin lock
+inline static void ggml_critical_section_start() {
+ int processing = atomic_fetch_add(&g_state_barrier, 1);
+
+ while (processing > 0) {
+ // wait for other threads to finish
+ atomic_fetch_sub(&g_state_barrier, 1);
+ sched_yield(); // TODO: reconsider this
+ processing = atomic_fetch_add(&g_state_barrier, 1);
+ }
+}
+
+// TODO: make this somehow automatically executed
+// some sort of "sentry" mechanism
+inline static void ggml_critical_section_end() {
+ atomic_fetch_sub(&g_state_barrier, 1);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_print_object(const struct ggml_object * obj) {
+ GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
+ obj->offset, obj->size, (const void *) obj->next);
+}
+
+void ggml_print_objects(const struct ggml_context * ctx) {
+ struct ggml_object * obj = ctx->objects_begin;
+
+ GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx);
+
+ while (obj != NULL) {
+ ggml_print_object(obj);
+ obj = obj->next;
+ }
+
+ GGML_PRINT("%s: --- end ---\n", __func__);
+}
+
+int ggml_nelements(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
+int ggml_nrows(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
+size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type];
+}
+
+size_t ggml_type_size(enum ggml_type type) {
+ return GGML_TYPE_SIZE[type];
+}
+
+size_t ggml_element_size(const struct ggml_tensor * tensor) {
+ return GGML_TYPE_SIZE[tensor->type];
+}
+
+bool ggml_is_scalar(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_vector(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_matrix(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ (t0->ne[0] == t1->ne[0]) &&
+ (t0->ne[2] == t1->ne[2]) &&
+ (t0->ne[3] == t1->ne[3]);
+}
+
+bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+ tensor->nb[1] == tensor->nb[0]*tensor->ne[0] &&
+ tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
+bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+ tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
+bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ (t0->ne[0] == t1->ne[0] ) &&
+ (t0->ne[1] == t1->ne[1] ) &&
+ (t0->ne[2] == t1->ne[2] ) &&
+ (t0->ne[3] == t1->ne[3] );
+}
+
+// check if t1 can be represented as a repeatition of t0
+bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ (t1->ne[0]%t0->ne[0] == 0) &&
+ (t1->ne[1]%t0->ne[1] == 0) &&
+ (t1->ne[2]%t0->ne[2] == 0) &&
+ (t1->ne[3]%t0->ne[3] == 0);
+}
+
+int ggml_up32(int n) {
+ return (n + 31) & ~31;
+}
+
+int ggml_up64(int n) {
+ return (n + 63) & ~63;
+}
+
+// assert that pointer is aligned to GGML_MEM_ALIGN
+#define ggml_assert_aligned(ptr) \
+ assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_context * ggml_init(struct ggml_init_params params) {
+ // make this function thread safe
+ ggml_critical_section_start();
+
+ static bool is_first_call = true;
+
+ if (is_first_call) {
+ // initialize GELU and EXP tables
+ {
+ const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+
+ ggml_fp16_t ii;
+ for (int i = 0; i < (1 << 16); ++i) {
+ uint16_t ui = i;
+ memcpy(&ii, &ui, sizeof(ii));
+ const float f = GGML_FP16_TO_FP32(ii);
+ table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+ table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f));
+ }
+
+ const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+ GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+ }
+
+ // initialize g_state
+ {
+ const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+
+ g_state = (struct ggml_state) {
+ /*.contexts =*/ { 0 },
+ };
+
+ for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
+ g_state.contexts[i].used = false;
+ }
+
+ const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+ GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+ }
+
+ is_first_call = false;
+ }
+
+ // find non-used context in g_state
+ struct ggml_context * ctx = NULL;
+
+ for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+ if (!g_state.contexts[i].used) {
+ g_state.contexts[i].used = true;
+ ctx = &g_state.contexts[i].context;
+
+ GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i);
+ break;
+ }
+ }
+
+ if (ctx == NULL) {
+ GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
+
+ ggml_critical_section_end();
+
+ return NULL;
+ }
+
+ *ctx = (struct ggml_context) {
+ .mem_size = params.mem_size,
+ .mem_buffer = params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+ .mem_buffer_owned = params.mem_buffer ? false : true,
+ .n_objects = 0,
+ .objects_begin = NULL,
+ .objects_end = NULL,
+ };
+
+ ggml_assert_aligned(ctx->mem_buffer);
+
+ GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
+
+ ggml_critical_section_end();
+
+ return ctx;
+}
+
+void ggml_free(struct ggml_context * ctx) {
+ // make this function thread safe
+ ggml_critical_section_start();
+
+ bool found = false;
+
+ for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+ if (&g_state.contexts[i].context == ctx) {
+ g_state.contexts[i].used = false;
+
+ GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
+ __func__, i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
+
+ if (ctx->mem_buffer_owned) {
+ free(ctx->mem_buffer);
+ }
+
+ found = true;
+ break;
+ }
+ }
+
+ if (!found) {
+ GGML_PRINT_DEBUG("%s: context not found\n", __func__);
+ }
+
+ ggml_critical_section_end();
+}
+
+size_t ggml_used_mem(const struct ggml_context * ctx) {
+ return ctx->objects_end->offset + ctx->objects_end->size;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_tensor * ggml_new_tensor_impl(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int n_dims,
+ const int* ne,
+ void* data) {
+ // always insert objects at the end of the context's memory pool
+ struct ggml_object * obj_cur = ctx->objects_end;
+
+ const size_t cur_offset = obj_cur == NULL ? 0 : obj_cur->offset;
+ const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
+ const size_t cur_end = cur_offset + cur_size;
+
+ size_t size_needed = 0;
+
+ if (data == NULL) {
+ size_needed += GGML_TYPE_SIZE[type];
+ for (int i = 0; i < n_dims; i++) {
+ size_needed *= ne[i];
+ }
+ // align to GGML_MEM_ALIGN
+ size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
+
+ }
+ size_needed += sizeof(struct ggml_tensor);
+
+ if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
+ GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__);
+ assert(false);
+ return NULL;
+ }
+
+ char * const mem_buffer = ctx->mem_buffer;
+
+ struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
+
+ *obj_new = (struct ggml_object) {
+ .offset = cur_end + GGML_OBJECT_SIZE,
+ .size = size_needed,
+ .next = NULL,
+ };
+
+ if (obj_cur != NULL) {
+ obj_cur->next = obj_new;
+ } else {
+ // this is the first object in this context
+ ctx->objects_begin = obj_new;
+ }
+
+ ctx->objects_end = obj_new;
+
+ //GGML_PRINT_DEBUG("%s: inserted new object at %zu\n", __func__, cur_end);
+
+ struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offset);
+
+ ggml_assert_aligned(result);
+
+ *result = (struct ggml_tensor) {
+ /*.type =*/ type,
+ /*.n_dims =*/ n_dims,
+ /*.ne =*/ { 1, 1, 1, 1 },
+ /*.nb =*/ { 0, 0, 0, 0 },
+ /*.op =*/ GGML_OP_NONE,
+ /*.is_param =*/ false,
+ /*.grad =*/ NULL,
+ /*.src0 =*/ NULL,
+ /*.src1 =*/ NULL,
+ /*.opt =*/ { NULL },
+ /*.n_tasks =*/ 0,
+ /*.perf_runs =*/ 0,
+ /*.perf_cycles =*/ 0,
+ /*.perf_time_us =*/ 0,
+ /*.data =*/ data == NULL ? (void *)(result + 1) : data,
+ /*.pad =*/ { 0 },
+ };
+
+ ggml_assert_aligned(result->data);
+
+ for (int i = 0; i < n_dims; i++) {
+ result->ne[i] = ne[i];
+ }
+
+ result->nb[0] = GGML_TYPE_SIZE[type];
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
+ }
+
+ ctx->n_objects++;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_new_tensor(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int n_dims,
+ const int* ne) {
+ return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
+}
+
+struct ggml_tensor * ggml_new_tensor_1d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0) {
+ return ggml_new_tensor(ctx, type, 1, &ne0);
+}
+
+struct ggml_tensor * ggml_new_tensor_2d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1) {
+ const int ne[2] = { ne0, ne1 };
+ return ggml_new_tensor(ctx, type, 2, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_3d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1,
+ int ne2) {
+ const int ne[3] = { ne0, ne1, ne2 };
+ return ggml_new_tensor(ctx, type, 3, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_4d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1,
+ int ne2,
+ int ne3) {
+ const int ne[4] = { ne0, ne1, ne2, ne3 };
+ return ggml_new_tensor(ctx, type, 4, ne);
+}
+
+struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+
+ ggml_set_i32(result, value);
+
+ return result;
+}
+
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+
+ ggml_set_f32(result, value);
+
+ return result;
+}
+
+struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
+ return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL);
+}
+
+struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+ memset(tensor->data, 0, ggml_nbytes(tensor));
+ return tensor;
+}
+
+struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
+ const int n = ggml_nrows(tensor);
+ const int nc = tensor->ne[0];
+ const size_t n1 = tensor->nb[1];
+
+ char * const data = tensor->data;
+
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ assert(tensor->nb[0] == sizeof(int8_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_I16:
+ {
+ assert(tensor->nb[0] == sizeof(int16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_I32:
+ {
+ assert(tensor->nb[0] == sizeof(int32_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_F16:
+ {
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_F32:
+ {
+ assert(tensor->nb[0] == sizeof(float));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+
+ return tensor;
+}
+
+struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
+ const int n = ggml_nrows(tensor);
+ const int nc = tensor->ne[0];
+ const size_t n1 = tensor->nb[1];
+
+ char * const data = tensor->data;
+
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ assert(tensor->nb[0] == sizeof(int8_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_I16:
+ {
+ assert(tensor->nb[0] == sizeof(int16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_I32:
+ {
+ assert(tensor->nb[0] == sizeof(int32_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_F16:
+ {
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_F32:
+ {
+ assert(tensor->nb[0] == sizeof(float));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+
+ return tensor;
+}
+
+int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+ return ((int8_t *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_I16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+ return ((int16_t *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_I32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+ return ((int32_t *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(float));
+ return ((float *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ return 0.0f;
+}
+
+void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+ ((int8_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+ ((int16_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+ ((int32_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+ ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(float));
+ ((float *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+ return ((int8_t *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_I16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+ return ((int16_t *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_I32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+ return ((int32_t *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(float));
+ return ((float *)(tensor->data))[i];
+ } break;
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ return 0.0f;
+}
+
+void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+ ((int8_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+ ((int16_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+ ((int32_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+ ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(float));
+ ((float *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+void * ggml_get_data(const struct ggml_tensor * tensor) {
+ return tensor->data;
+}
+
+float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
+ assert(tensor->type == GGML_TYPE_F32);
+ return (float *)(tensor->data);
+}
+
+struct ggml_tensor * ggml_view_tensor(
+ struct ggml_context * ctx,
+ const struct ggml_tensor * src) {
+ return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// ggml_dup
+
+struct ggml_tensor * ggml_dup_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_DUP;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_dup(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_dup_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_dup_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_dup_impl(ctx, a, true);
+}
+
+// ggml_add
+
+struct ggml_tensor * ggml_add_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_ADD;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_add(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_add_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_add_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_add_impl(ctx, a, b, true);
+}
+
+// ggml_sub
+
+struct ggml_tensor * ggml_sub_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SUB;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sub(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_sub_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_sub_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_sub_impl(ctx, a, b, true);
+}
+
+// ggml_mul
+
+struct ggml_tensor * ggml_mul_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ if (inplace) {
+ assert(is_node == false);
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_MUL;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_mul(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_mul_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_mul_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_mul_impl(ctx, a, b, true);
+}
+
+// ggml_div
+
+struct ggml_tensor * ggml_div_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ if (inplace) {
+ assert(is_node == false);
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_DIV;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_div(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_div_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_div_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_div_impl(ctx, a, b, true);
+}
+
+// ggml_sqr
+
+struct ggml_tensor * ggml_sqr_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SQR;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sqr(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqr_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqr_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqr_impl(ctx, a, true);
+}
+
+// ggml_sqrt
+
+struct ggml_tensor * ggml_sqrt_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SQRT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sqrt(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqrt_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqrt_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqrt_impl(ctx, a, true);
+}
+
+// ggml_sum
+
+struct ggml_tensor * ggml_sum(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+ result->op = GGML_OP_SUM;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_mean
+
+struct ggml_tensor * ggml_mean(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement
+ is_node = true;
+ }
+
+ int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne);
+
+ result->op = GGML_OP_MEAN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_repeat
+
+struct ggml_tensor * ggml_repeat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_can_repeat(a, b));
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ if (ggml_are_same_shape(a, b) && !is_node) {
+ return a;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+ result->op = GGML_OP_REPEAT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_abs
+
+struct ggml_tensor * ggml_abs_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_ABS;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_abs(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_abs_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_abs_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_abs_impl(ctx, a, true);
+}
+
+
+// ggml_sgn
+
+struct ggml_tensor * ggml_sgn_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SGN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sgn(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sgn_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sgn_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sgn_impl(ctx, a, true);
+}
+
+// ggml_neg
+
+struct ggml_tensor * ggml_neg_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_NEG;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_neg(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_neg_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_neg_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_neg_impl(ctx, a, true);
+}
+
+// ggml_step
+
+struct ggml_tensor * ggml_step_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_STEP;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_step(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_step_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_step_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_step_impl(ctx, a, true);
+}
+
+// ggml_relu
+
+struct ggml_tensor * ggml_relu_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_RELU;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_relu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_relu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_relu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_relu_impl(ctx, a, true);
+}
+
+// ggml_gelu
+
+struct ggml_tensor * ggml_gelu_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_GELU;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_gelu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_gelu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_gelu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_gelu_impl(ctx, a, true);
+}
+
+// ggml_norm
+
+struct ggml_tensor * ggml_norm_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_NORM;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL; // TODO: maybe store epsilon here?
+
+ return result;
+}
+
+struct ggml_tensor * ggml_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_norm_impl(ctx, a, true);
+}
+
+// ggml_mul_mat
+
+struct ggml_tensor * ggml_mul_mat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_can_mul_mat(a, b));
+
+ // printUniqueTensorSize( "ggml_mul_mat", a->ne, b->ne );
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true;
+ }
+
+ const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+ result->op = GGML_OP_MUL_MAT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_scale
+
+struct ggml_tensor * ggml_scale_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_is_scalar(b));
+ assert(ggml_is_padded_1d(a));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: when implement backward, fix this:
+ //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ result->op = GGML_OP_SCALE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_scale(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_scale_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_scale_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_scale_impl(ctx, a, b, true);
+}
+
+// ggml_cpy
+
+struct ggml_tensor * ggml_cpy_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ assert(ggml_nelements(a) == ggml_nelements(b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // make a view of the destination
+ struct ggml_tensor * result = ggml_view_tensor(ctx, b);
+
+ result->op = GGML_OP_CPY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_cpy(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_cpy_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_cpy_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_cpy_impl(ctx, a, b, true);
+}
+
+// ggml_reshape
+
+struct ggml_tensor * ggml_reshape(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_is_contiguous(a));
+ assert(ggml_is_contiguous(b));
+ assert(ggml_nelements(a) == ggml_nelements(b));
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_reshape_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1) {
+ assert(ggml_is_contiguous(a));
+ assert(ggml_nelements(a) == ne0*ne1);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int ne[2] = { ne0, ne1 };
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_reshape_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ int ne2) {
+ assert(ggml_is_contiguous(a));
+ assert(ggml_nelements(a) == ne0*ne1*ne2);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int ne[3] = { ne0, ne1, ne2 };
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_view_1d
+
+struct ggml_tensor * ggml_view_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ size_t offset) {
+ if (a->grad) {
+ assert(false); // gradient propagation is not supported
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
+
+ result->op = GGML_OP_VIEW;
+ result->grad = NULL;
+ result->src0 = a;
+ result->src1 = NULL; // TODO: maybe store the offset here?
+
+ return result;
+}
+
+// ggml_view_2d
+
+struct ggml_tensor * ggml_view_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ size_t nb1,
+ size_t offset) {
+ if (a->grad) {
+ assert(false); // gradient propagation is not supported
+ }
+
+ const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
+
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
+
+ result->nb[1] = nb1;
+ result->nb[2] = result->nb[1]*ne1;
+ result->nb[3] = result->nb[2];
+
+ result->op = GGML_OP_VIEW;
+ result->grad = NULL;
+ result->src0 = a;
+ result->src1 = NULL; // TODO: maybe store the offset here?
+
+ return result;
+}
+
+// ggml_permute
+
+struct ggml_tensor * ggml_permute(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int axis0,
+ int axis1,
+ int axis2,
+ int axis3) {
+ assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
+ assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
+ assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
+ assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
+
+ assert(axis0 != axis1);
+ assert(axis0 != axis2);
+ assert(axis0 != axis3);
+ assert(axis1 != axis2);
+ assert(axis1 != axis3);
+ assert(axis2 != axis3);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ int ne[GGML_MAX_DIMS];
+ int nb[GGML_MAX_DIMS];
+
+ ne[axis0] = a->ne[0];
+ ne[axis1] = a->ne[1];
+ ne[axis2] = a->ne[2];
+ ne[axis3] = a->ne[3];
+
+ nb[axis0] = a->nb[0];
+ nb[axis1] = a->nb[1];
+ nb[axis2] = a->nb[2];
+ nb[axis3] = a->nb[3];
+
+ result->ne[0] = ne[0];
+ result->ne[1] = ne[1];
+ result->ne[2] = ne[2];
+ result->ne[3] = ne[3];
+
+ result->nb[0] = nb[0];
+ result->nb[1] = nb[1];
+ result->nb[2] = nb[2];
+ result->nb[3] = nb[3];
+
+ result->op = GGML_OP_PERMUTE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL; // TODO: maybe store the permutation here?
+
+ return result;
+}
+
+// ggml_transpose
+
+struct ggml_tensor * ggml_transpose(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ result->ne[0] = a->ne[1];
+ result->ne[1] = a->ne[0];
+
+ result->nb[0] = a->nb[1];
+ result->nb[1] = a->nb[0];
+
+ result->op = GGML_OP_TRANSPOSE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_get_rows
+
+struct ggml_tensor * ggml_get_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: implement non F32 return
+ //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
+ struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+
+ result->op = GGML_OP_GET_ROWS;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_diag_mask_inf
+
+struct ggml_tensor * ggml_diag_mask_inf(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past) {
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: when implement backward, fix this:
+ //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+ ((int32_t *) b->data)[0] = n_past;
+
+ result->op = GGML_OP_DIAG_MASK_INF;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_soft_max
+
+struct ggml_tensor * ggml_soft_max(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: when implement backward, fix this:
+ //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ result->op = GGML_OP_SOFT_MAX;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+// ggml_rope
+
+struct ggml_tensor * ggml_rope(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ int n_dims,
+ int mode) {
+ assert(n_past >= 0);
+ bool is_node = false;
+
+ if (a->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: when implement backward, fix this:
+ //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+ ((int32_t *) b->data)[0] = n_past;
+ ((int32_t *) b->data)[1] = n_dims;
+ ((int32_t *) b->data)[2] = mode;
+
+ result->op = GGML_OP_ROPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_conv_1d_1s
+
+struct ggml_tensor * ggml_conv_1d_1s(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_is_matrix(b));
+ assert(a->ne[1] == b->ne[1]);
+ assert(a->ne[3] == 1);
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int ne[4] = { b->ne[0], a->ne[2], 1, 1, };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+
+ result->op = GGML_OP_CONV_1D_1S;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_conv_1d_2s
+
+struct ggml_tensor * ggml_conv_1d_2s(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ assert(ggml_is_matrix(b));
+ assert(a->ne[1] == b->ne[1]);
+ assert(a->ne[3] == 1);
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ assert(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+
+ result->op = GGML_OP_CONV_1D_2S;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+
+ return result;
+}
+
+// ggml_flash_attn
+
+struct ggml_tensor * ggml_flash_attn(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ bool masked) {
+ assert(ggml_can_mul_mat(k, q));
+ // TODO: check if vT can be multiplied by (k*qT)
+
+ bool is_node = false;
+
+ if (q->grad || k->grad || v->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne);
+
+ result->op = GGML_OP_FLASH_ATTN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = q;
+ result->src1 = k;
+ result->opt[0] = v;
+ result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0);
+
+ return result;
+}
+
+// ggml_flash_ff
+
+struct ggml_tensor * ggml_flash_ff(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b0,
+ struct ggml_tensor * b1,
+ struct ggml_tensor * c0,
+ struct ggml_tensor * c1) {
+ assert(ggml_can_mul_mat(b0, a));
+ // TODO: more checks
+
+ bool is_node = false;
+
+ if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne);
+
+ result->op = GGML_OP_FLASH_FF;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b0;
+ result->opt[0] = b1;
+ result->opt[1] = c0;
+ result->opt[2] = c1;
+
+ return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_param(
+ struct ggml_context * ctx,
+ struct ggml_tensor * tensor) {
+ tensor->is_param = true;
+
+ assert(tensor->grad == NULL);
+ tensor->grad = ggml_dup_tensor(ctx, tensor);
+}
+
+// ggml_compute_forward_dup
+
+static void ggml_compute_forward_dup_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_is_contiguous(dst));
+ assert(ggml_nelements(dst) == ggml_nelements(src0));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const size_t nb00 = src0->nb[0];
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+ memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+ return;
+ }
+
+ if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+ if (dst->type == GGML_TYPE_F16) {
+ int id = 0;
+ const size_t rs = ne00*nb00;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ char * dst_ptr = (char *) dst->data + id*rs;
+
+ memcpy(dst_ptr, src0_ptr, rs);
+
+ id++;
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F32) {
+ int id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+ id++;
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ if (dst->type == GGML_TYPE_F32) {
+ int id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+ id++;
+ }
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ int id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = *src0_ptr;
+ id++;
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ }
+}
+
+static void ggml_compute_forward_dup_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(params->ith == 0);
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const size_t nb00 = src0->nb[0];
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+ memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
+ return;
+ }
+
+ if (src0->nb[0] == sizeof(float)) {
+ if (dst->type == GGML_TYPE_F32) {
+ int id = 0;
+ const size_t rs = ne00*nb00;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ char * dst_ptr = (char *) dst->data + id*rs;
+
+ memcpy(dst_ptr, src0_ptr, rs);
+
+ id++;
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ int id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+ id++;
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ if (dst->type == GGML_TYPE_F32) {
+ int id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = *src0_ptr;
+ id++;
+ }
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ int id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+ id++;
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ }
+}
+
+static void ggml_compute_forward_dup(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_dup_f16(params, src0, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_dup_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_add
+
+static void ggml_compute_forward_add_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ const size_t nb00 = src0->nb[0];
+ const size_t nb01 = src0->nb[1];
+
+ const size_t nb10 = src1->nb[0];
+ const size_t nb11 = src1->nb[1];
+
+ const size_t nb0 = dst->nb[0];
+ const size_t nb1 = dst->nb[1];
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ if (nb10 == sizeof(float)) {
+ const int j0 = (n/nth)*ith;
+ const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
+
+ for (int j = j0; j < j1; j++) {
+ ggml_vec_add_f32(nc,
+ (float *) ((char *) dst->data + j*nb1),
+ (float *) ((char *) src0->data + j*nb01),
+ (float *) ((char *) src1->data + j*nb11));
+ }
+ } else {
+ // src1 is not contiguous
+ for (int j = ith; j < n; j += nth) {
+ float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
+ float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+ for (int i = 0; i < nc; i++) {
+ float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+
+ dst_ptr[i] = src0_ptr[i] + *src1_ptr;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_add(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_add_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sub
+
+static void ggml_compute_forward_sub_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+ assert(src1->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sub_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_sub(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sub_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_mul
+
+static void ggml_compute_forward_mul_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+ assert(src1->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_mul_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_mul(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_mul_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_div
+
+static void ggml_compute_forward_div_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+ assert(src1->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_div_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_div(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_div_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sqr
+
+static void ggml_compute_forward_sqr_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sqr_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_sqr(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sqr_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sqrt
+
+static void ggml_compute_forward_sqrt_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sqrt_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_sqrt(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sqrt_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sum
+
+static void ggml_compute_forward_sum_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_is_scalar(dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ assert(ggml_is_scalar(dst));
+ assert(src0->nb[0] == sizeof(float));
+
+ *(float *) (dst->data) = 0.0f;
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ ggml_vec_sum_f32(ne00,
+ (float *) (dst->data),
+ (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_sum(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sum_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_mean
+
+static void ggml_compute_forward_mean_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ assert(src0->nb[0] == sizeof(float));
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+ const int ne2 = dst->ne[2];
+ const int ne3 = dst->ne[3];
+
+ assert(ne0 == 1);
+ assert(ne1 == ne01);
+ assert(ne2 == ne02);
+ assert(ne3 == ne03);
+
+ UNUSED(ne0);
+ UNUSED(ne1);
+ UNUSED(ne2);
+ UNUSED(ne3);
+
+ const size_t nb1 = dst->nb[1];
+ const size_t nb2 = dst->nb[2];
+ const size_t nb3 = dst->nb[3];
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) = 0.0f;
+
+ ggml_vec_sum_f32(ne00,
+ (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+ (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+
+ *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_mean(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_mean_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_repeat
+
+static void ggml_compute_forward_repeat_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_can_repeat(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // TODO: implement support for rank > 2 tensors
+ assert(src0->ne[2] == 1);
+ assert(src0->ne[3] == 1);
+ assert( dst->ne[2] == 1);
+ assert( dst->ne[3] == 1);
+
+ const int nc = dst->ne[0];
+ const int nr = dst->ne[1];
+ const int nc0 = src0->ne[0];
+ const int nr0 = src0->ne[1];
+ const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
+ const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
+
+ // TODO: support for transposed / permuted tensors
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ // TODO: maybe this is not optimal?
+ for (int i = 0; i < nrr; i++) {
+ for (int j = 0; j < ncr; j++) {
+ for (int k = 0; k < nr0; k++) {
+ ggml_vec_cpy_f32(nc0,
+ (float *) ((char *) dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])),
+ (float *) ((char *) src0->data + ( k)*(src0->nb[1])));
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_repeat(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_repeat_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_abs
+
+static void ggml_compute_forward_abs_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_abs_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_abs(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_abs_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sgn
+
+static void ggml_compute_forward_sgn_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sgn_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_sgn(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sgn_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_neg
+
+static void ggml_compute_forward_neg_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_neg_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_neg(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_neg_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_step
+
+static void ggml_compute_forward_step_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_step_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_step(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_step_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_relu
+
+static void ggml_compute_forward_relu_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_relu_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_relu(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_relu_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_gelu
+
+static void ggml_compute_forward_gelu_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_gelu_f32(nc,
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_gelu(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_gelu_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_norm
+
+static void ggml_compute_forward_norm_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ const size_t nb1 = dst->nb[1];
+ const size_t nb2 = dst->nb[2];
+ const size_t nb3 = dst->nb[3];
+
+ const ggml_float eps = 1e-5f; // TODO: make this a parameter
+
+ // TODO: optimize
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = ith; i01 < ne01; i01 += nth) {
+ const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ ggml_float mean = 0.0;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ mean += x[i00];
+ }
+
+ mean /= ne00;
+
+ float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ ggml_float sum2 = 0.0;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ ggml_float v = x[i00] - mean;
+ y[i00] = v;
+ sum2 += v*v;
+ }
+
+ const float scale = 1.0/sqrt(sum2/ne00 + eps);
+
+ ggml_vec_scale_f32(ne00, y, scale);
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_norm(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_norm_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_mul_mat
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+// helper function to determine if it is better to use BLAS or not
+// for large matrices, BLAS is faster
+static bool ggml_compute_forward_mul_mat_use_blas(
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ UNUSED(src0);
+
+ const int ne10 = src1->ne[0];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+
+ // TODO: find the optimal values for these
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) {
+ //printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
+ return true;
+ }
+
+ return false;
+}
+#endif
+
+static void ggml_compute_forward_mul_mat_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ const int ne12 = src1->ne[2];
+ const int ne13 = src1->ne[3];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+ const int ne2 = dst->ne[2];
+ const int ne3 = dst->ne[3];
+ const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ const int nb12 = src1->nb[2];
+ const int nb13 = src1->nb[3];
+
+ const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ const int nb2 = dst->nb[2];
+ const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ assert(ne02 == ne12);
+ assert(ne03 == ne13);
+ assert(ne2 == ne12);
+ assert(ne3 == ne13);
+
+ // TODO: we don't support permuted src0
+ assert(nb00 == sizeof(float) || nb01 == sizeof(float));
+
+ // dst cannot be transposed or permuted
+ assert(nb0 == sizeof(float));
+ assert(nb0 <= nb1);
+ assert(nb1 <= nb2);
+ assert(nb2 <= nb3);
+
+ assert(ne0 == ne01);
+ assert(ne1 == ne11);
+ assert(ne2 == ne02);
+ assert(ne3 == ne03);
+
+ // nb01 >= nb00 - src0 is not transposed
+ // compute by src0 rows
+ //
+ // nb00 < nb01 - src0 is transposed
+ // compute by src0 columns
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+ if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (params->ith != 0) return;
+
+ if (params->type == GGML_TASK_INIT) {
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ const float * x = (float *) (src0->data);
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+ // zT = y * xT
+ {
+ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+ ne11, ne01, ne10,
+ 1.0f, y, ne10,
+ x, ne10,
+ 0.0f, d, ne01);
+ }
+ }
+ }
+
+ //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
+
+ return;
+ }
+#endif
+
+ if (params->type == GGML_TASK_INIT) {
+ if (nb01 >= nb00) {
+ return;
+ }
+
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ if (nb01 >= nb00) {
+ return;
+ }
+
+ // TODO: fix this memset (wsize is overestimated)
+ //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
+
+ float * const wdata = params->wdata;
+
+ // cols per thread
+ const int dc = (ne + nth - 1)/nth;
+
+ // col range for this thread
+ const int ic0 = dc*ith;
+ const int ic1 = MIN(ic0 + dc, ne);
+
+ ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
+
+ for (int k = 1; k < nth; k++) {
+ ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
+ }
+
+ return;
+ }
+
+ if (nb01 >= nb00) {
+ // TODO: do not support transposed src1
+ assert(nb10 == sizeof(float));
+
+ // parallelize by src0 rows using ggml_vec_dot_f32
+
+ // total rows in src0
+ const int nr = ne01*ne02*ne03;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i03 = ir/(ne02*ne01);
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ for (int ic = 0; ic < ne11; ++ic) {
+ // src1 indices
+ const int i13 = i03;
+ const int i12 = i02;
+ const int i11 = ic;
+
+ // dst indices
+ const int i0 = i01;
+ const int i1 = i11;
+ const int i2 = i02;
+ const int i3 = i03;
+
+ ggml_vec_dot_f32(ne00,
+ (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
+ (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
+ }
+ }
+ } else {
+ // parallelize by src1 columns using ggml_vec_mad_f32
+ // each thread has its own work data
+ // during FINALIZE we accumulate all work data into dst
+
+ // total columns in src1
+ const int nc = ne10;
+
+ // columns per thread
+ const int dc = (nc + nth - 1)/nth;
+
+ // column range for this thread
+ const int ic0 = dc*ith;
+ const int ic1 = MIN(ic0 + dc, nc);
+
+ // work data for thread
+ const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
+ float * const wdata = params->wdata;
+
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ for (int i11 = 0; i11 < ne11; ++i11) {
+ for (int ic = ic0; ic < ic1; ++ic) {
+ // src1 indices
+ const int i10 = ic;
+
+ // src0 indices
+ const int i03 = i13;
+ const int i02 = i12;
+ const int i00 = ic;
+
+ // dst indices
+ const int i1 = i11;
+ const int i2 = i12;
+ const int i3 = i13;
+
+ assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+
+ ggml_vec_mad_f32(ne01,
+ (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0),
+ (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)),
+ *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)));
+ }
+ }
+ }
+ }
+ }
+
+ //int64_t t1 = ggml_perf_time_us();
+ //static int64_t acc = 0;
+ //acc += t1 - t0;
+ //if (t1 - t0 > 10) {
+ // printf("\n");
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+ //}
+}
+
+static void ggml_compute_forward_mul_mat_f16_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ const int ne12 = src1->ne[2];
+ const int ne13 = src1->ne[3];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+ const int ne2 = dst->ne[2];
+ const int ne3 = dst->ne[3];
+ const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ const int nb12 = src1->nb[2];
+ const int nb13 = src1->nb[3];
+
+ const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ const int nb2 = dst->nb[2];
+ const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_ASSERT(ne02 == ne12);
+ GGML_ASSERT(ne03 == ne13);
+ GGML_ASSERT(ne2 == ne12);
+ GGML_ASSERT(ne3 == ne13);
+
+ // TODO: we don't support permuted src0
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t));
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ GGML_ASSERT(ne0 == ne01);
+ GGML_ASSERT(ne1 == ne11);
+ GGML_ASSERT(ne2 == ne02);
+ GGML_ASSERT(ne3 == ne03);
+
+ // nb01 >= nb00 - src0 is not transposed
+ // compute by src0 rows
+ //
+ // nb00 < nb01 - src0 is transposed
+ // compute by src0 columns
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+ if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (params->ith != 0) return;
+
+ if (params->type == GGML_TASK_INIT) {
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ float * const wdata = params->wdata;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ {
+ int id = 0;
+ for (int i01 = 0; i01 < ne01; ++i01) {
+ for (int i00 = 0; i00 < ne00; ++i00) {
+ wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
+ }
+ }
+ }
+
+ const float * x = wdata;
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+ // float * z = wdata + ne00*ne01;
+
+ // z = x * yT
+ //{
+ // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+ // ne01, ne11, ne00,
+ // 1.0f, x, ne00,
+ // y, ne00,
+ // 0.0f, z, ne11);
+ //}
+
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+ // transpose z
+ //for (int j = 0; j < ne11; ++j) {
+ // for (int i = 0; i < ne01; ++i) {
+ // d[j*ne01 + i] = z[i*ne11 + j];
+ // }
+ //}
+
+ {
+#if 1
+ // zT = y * xT
+ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+ ne11, ne01, ne10,
+ 1.0f, y, ne00,
+ x, ne00,
+ 0.0f, d, ne01);
+#else
+ // zT = (xT * y)T
+ cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
+ ne01, ne11, ne10,
+ 1.0f, x, ne00,
+ y, ne00,
+ 0.0f, d, ne01);
+#endif
+ }
+ }
+ }
+
+ //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
+
+ return;
+ }
+#endif
+
+ if (params->type == GGML_TASK_INIT) {
+ if (nb01 >= nb00) {
+ ggml_fp16_t * const wdata = params->wdata;
+
+ int id = 0;
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ for (int i11 = 0; i11 < ne11; ++i11) {
+ for (int i10 = 0; i10 < ne10; ++i10) {
+ wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+ }
+ }
+ }
+ }
+
+ GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
+
+ return;
+ }
+
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ if (nb01 >= nb00) {
+ return;
+ }
+
+ // TODO: fix this memset (wsize is overestimated)
+ //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
+
+ ggml_fp16_t * const wdata = params->wdata;
+
+ // cols per thread
+ const int dc = (ne + nth - 1)/nth;
+
+ // col range for this thread
+ const int ic0 = dc*ith;
+ const int ic1 = MIN(ic0 + dc, ne);
+
+ for (int i = ic0; i < ic1; ++i) {
+ ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]);
+ }
+
+ for (int k = 1; k < nth; k++) {
+ for (int i = ic0; i < ic1; ++i) {
+ ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
+ }
+ }
+
+ return;
+ }
+
+ if (nb01 >= nb00) {
+ // fp16 -> half the size, so divide by 2
+ // TODO: do not support transposed src1
+ assert(nb10/2 == sizeof(ggml_fp16_t));
+
+ // parallelize by src0 rows using ggml_vec_dot_f32
+
+ // total rows in src0
+ const int nr = ne01*ne02*ne03;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ ggml_fp16_t * wdata = params->wdata;
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i03 = ir/(ne02*ne01);
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int i13 = i03;
+ const int i12 = i02;
+
+ const int i0 = i01;
+ const int i2 = i02;
+ const int i3 = i03;
+
+ ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+ ggml_fp16_t * src1_col = wdata + (i13*ne12*ne11 + i12*ne11 + 0)*ne00;
+
+ float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+
+ for (int ic = 0; ic < ne11; ++ic) {
+ assert(ne00 % 32 == 0);
+
+ ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
+ }
+ }
+ } else {
+ // parallelize by src1 columns using ggml_vec_mad_f32
+ // each thread has its own work data
+ // during FINALIZE we accumulate all work data into dst
+
+ // total columns in src1
+ const int nc = ne10;
+
+ // columns per thread
+ const int dc = (nc + nth - 1)/nth;
+
+ // column range for this thread
+ const int ic0 = dc*ith;
+ const int ic1 = MIN(ic0 + dc, nc);
+
+ // work data for thread
+ const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
+ ggml_fp16_t * const wdata = params->wdata;
+
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ for (int i11 = 0; i11 < ne11; ++i11) {
+ // dst indices
+ const int i1 = i11;
+ const int i2 = i12;
+ const int i3 = i13;
+
+ ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+
+ for (int ic = ic0; ic < ic1; ++ic) {
+ // src1 indices
+ const int i10 = ic;
+
+ // src0 indices
+ const int i03 = i13;
+ const int i02 = i12;
+ const int i00 = ic;
+
+ assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+
+ ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
+ float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+
+ ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val);
+ }
+ }
+ }
+ }
+ }
+
+ //int64_t t1 = ggml_time_us();
+ //static int64_t acc = 0;
+ //acc += t1 - t0;
+ //if (t1 - t0 > 10) {
+ // printf("\n");
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+ //}
+}
+
+static void ggml_compute_forward_mul_mat(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_scale
+
+static void ggml_compute_forward_scale_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_scalar(src1));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // scale factor
+ const float v = *(float *) src1->data;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), v);
+ }
+}
+
+static void ggml_compute_forward_scale(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_scale_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_cpy
+
+static void ggml_compute_forward_cpy(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ ggml_compute_forward_dup(params, src0, dst);
+}
+
+// ggml_compute_forward_reshape
+
+static void ggml_compute_forward_reshape(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ // NOP
+ UNUSED(params);
+ UNUSED(src0);
+ UNUSED(dst);
+}
+
+// ggml_compute_forward_view
+
+static void ggml_compute_forward_view(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0) {
+ // NOP
+ UNUSED(params);
+ UNUSED(src0);
+}
+
+// ggml_compute_forward_permute
+
+static void ggml_compute_forward_permute(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0) {
+ // NOP
+ UNUSED(params);
+ UNUSED(src0);
+}
+
+// ggml_compute_forward_transpose
+
+static void ggml_compute_forward_transpose(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0) {
+ // NOP
+ UNUSED(params);
+ UNUSED(src0);
+}
+
+// ggml_compute_forward_get_rows
+
+static void ggml_compute_forward_get_rows_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nelements(src1);
+
+ assert( dst->ne[0] == nc);
+ assert( dst->ne[1] == nr);
+ assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+ for (int i = 0; i < nr; ++i) {
+ const int r = ((int32_t *) src1->data)[i];
+
+ for (int j = 0; j < nc; ++j) {
+ ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
+ ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+ }
+ }
+}
+
+static void ggml_compute_forward_get_rows_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nelements(src1);
+
+ assert( dst->ne[0] == nc);
+ assert( dst->ne[1] == nr);
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < nr; ++i) {
+ const int r = ((int32_t *) src1->data)[i];
+
+ ggml_vec_cpy_f32(nc,
+ (float *) ((char *) dst->data + i*dst->nb[1]),
+ (float *) ((char *) src0->data + r*src0->nb[1]));
+ }
+}
+
+static void ggml_compute_forward_get_rows(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_diag_mask_inf
+
+static void ggml_compute_forward_diag_mask_inf_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(src1->type == GGML_TYPE_I32);
+ assert(ggml_nelements(src1) == 1);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n_past = ((int32_t *) src1->data)[0];
+
+ // TODO: handle transposed/permuted matrices
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+ const int nr = src0->ne[1];
+ const int nz = n/nr;
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int k = 0; k < nz; k++) {
+ for (int j = 0; j < nr; j++) {
+ for (int i = n_past; i < nc; i++) {
+ if (i > n_past + j) {
+ *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = -INFINITY;
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_diag_mask_inf(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_soft_max
+
+static void ggml_compute_forward_soft_max_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // TODO: handle transposed/permuted matrices
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float *p = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ assert(!isnan(p[i]));
+ }
+#endif
+
+ float max = -INFINITY;
+ for (int i = 0; i < nc; i++) {
+ max = MAX(max, p[i]);
+ }
+
+ ggml_float sum = 0.0;
+
+ uint16_t ss;
+ for (int i = 0; i < nc; i++) {
+ if (p[i] == -INFINITY) {
+ p[i] = 0.0;
+ } else {
+ //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
+ ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
+ memcpy(&ss, &s, sizeof(ss));
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
+ sum += val;
+ p[i] = val;
+ }
+ }
+
+ assert(sum > 0.0f);
+
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(nc, p, sum);
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ assert(!isnan(p[i]));
+ assert(!isinf(p[i]));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_soft_max(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_soft_max_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_rope
+
+static void ggml_compute_forward_rope_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(src1->type == GGML_TYPE_I32);
+ assert(ggml_nelements(src1) == 3);
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n_past = ((int32_t *) src1->data)[0];
+ const int n_dims = ((int32_t *) src1->data)[1];
+ const int mode = ((int32_t *) src1->data)[2];
+
+ //const int ne0 = src0->ne[0];
+ const int ne1 = src0->ne[1];
+ const int ne2 = src0->ne[2];
+ const int ne3 = src0->ne[3];
+
+ const int nb0 = src0->nb[0];
+ const int nb1 = src0->nb[1];
+ const int nb2 = src0->nb[2];
+ const int nb3 = src0->nb[3];
+
+ //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+ //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+ assert(nb0 == sizeof(float));
+
+ // TODO: optimize
+ for (int i3 = 0; i3 < ne3; i3++) {
+ for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
+ const int p = (mode == 0 ? n_past + i2 : i2);
+ for (int i1 = 0; i1 < ne1; i1++) {
+ for (int i0 = 0; i0 < n_dims; i0 += 2) {
+ const double theta = pow(10000.0, ((double)-i0)/n_dims);
+
+ const double cos_theta = cos(p*theta);
+ const double sin_theta = sin(p*theta);
+
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ double x0 = src[0];
+ double x1 = src[1];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_rope(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_rope_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_conv_1d_1s
+
+static void ggml_compute_forward_conv_1d_1s_f16_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ //const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ //const int ne12 = src1->ne[2];
+ //const int ne13 = src1->ne[3];
+
+ //const int ne0 = dst->ne[0];
+ //const int ne1 = dst->ne[1];
+ //const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+ //const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ //const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ //const int nb12 = src1->nb[2];
+ //const int nb13 = src1->nb[3];
+
+ //const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ //const int nb2 = dst->nb[2];
+ //const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00;
+ const int nh = nk/2;
+
+ const int ew0 = ggml_up32(ne01);
+
+ GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (params->type == GGML_TASK_INIT) {
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+
+ // prepare kernel data (src0)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+ ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_data[i00*ew0 + i01] = src[i00];
+ }
+ }
+ }
+ }
+
+ // prepare source data (src1)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
+
+ for (int i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
+ ggml_fp16_t * dst_data = wdata;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
+ }
+ }
+ }
+
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // total rows in dst
+ const int nr = ne02;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
+ for (int i0 = 0; i0 < ne10; ++i0) {
+ dst_data[i0] = 0;
+ for (int k = -nh; k <= nh; k++) {
+ float v = 0.0f;
+ ggml_vec_dot_f16(ew0, &v,
+ (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
+ (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+ dst_data[i0] += v;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_conv_1d_1s_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ //const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ //const int ne12 = src1->ne[2];
+ //const int ne13 = src1->ne[3];
+
+ //const int ne0 = dst->ne[0];
+ //const int ne1 = dst->ne[1];
+ //const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+ //const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ //const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ //const int nb12 = src1->nb[2];
+ //const int nb13 = src1->nb[3];
+
+ //const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ //const int nb2 = dst->nb[2];
+ //const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00;
+ const int nh = nk/2;
+
+ const int ew0 = ggml_up32(ne01);
+
+ GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (params->type == GGML_TASK_INIT) {
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+
+ // prepare kernel data (src0)
+ {
+ float * const wdata = (float *) params->wdata + 0;
+
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+ float * dst_data = wdata + i02*ew0*ne00;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_data[i00*ew0 + i01] = src[i00];
+ }
+ }
+ }
+ }
+
+ // prepare source data (src1)
+ {
+ float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
+
+ for (int i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
+ float * dst_data = wdata;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ dst_data[(i10 + nh)*ew0 + i11] = src[i10];
+ }
+ }
+ }
+
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // total rows in dst
+ const int nr = ne02;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
+ for (int i0 = 0; i0 < ne10; ++i0) {
+ dst_data[i0] = 0;
+ for (int k = -nh; k <= nh; k++) {
+ float v = 0.0f;
+ ggml_vec_dot_f32(ew0, &v,
+ (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
+ (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+ dst_data[i0] += v;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_conv_1d_1s(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_conv_1d_2s
+
+static void ggml_compute_forward_conv_1d_2s_f16_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ //const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ //const int ne12 = src1->ne[2];
+ //const int ne13 = src1->ne[3];
+
+ //const int ne0 = dst->ne[0];
+ //const int ne1 = dst->ne[1];
+ //const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+ //const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ //const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ //const int nb12 = src1->nb[2];
+ //const int nb13 = src1->nb[3];
+
+ //const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ //const int nb2 = dst->nb[2];
+ //const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00;
+ const int nh = nk/2;
+
+ const int ew0 = ggml_up32(ne01);
+
+ GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (params->type == GGML_TASK_INIT) {
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+
+ // prepare kernel data (src0)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+ ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_data[i00*ew0 + i01] = src[i00];
+ }
+ }
+ }
+ }
+
+ // prepare source data (src1)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
+
+ for (int i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
+ ggml_fp16_t * dst_data = wdata;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
+ }
+ }
+ }
+
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // total rows in dst
+ const int nr = ne02;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
+ for (int i0 = 0; i0 < ne10; i0 += 2) {
+ dst_data[i0/2] = 0;
+ for (int k = -nh; k <= nh; k++) {
+ float v = 0.0f;
+ ggml_vec_dot_f16(ew0, &v,
+ (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
+ (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+ dst_data[i0/2] += v;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_conv_1d_2s_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ //const int ne03 = src0->ne[3];
+
+ const int ne10 = src1->ne[0];
+ const int ne11 = src1->ne[1];
+ //const int ne12 = src1->ne[2];
+ //const int ne13 = src1->ne[3];
+
+ //const int ne0 = dst->ne[0];
+ //const int ne1 = dst->ne[1];
+ //const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+ //const int ne = ne0*ne1*ne2*ne3;
+
+ const int nb00 = src0->nb[0];
+ const int nb01 = src0->nb[1];
+ const int nb02 = src0->nb[2];
+ //const int nb03 = src0->nb[3];
+
+ const int nb10 = src1->nb[0];
+ const int nb11 = src1->nb[1];
+ //const int nb12 = src1->nb[2];
+ //const int nb13 = src1->nb[3];
+
+ //const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ //const int nb2 = dst->nb[2];
+ //const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00;
+ const int nh = nk/2;
+
+ const int ew0 = ggml_up32(ne01);
+
+ GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (params->type == GGML_TASK_INIT) {
+ // TODO: fix this memset (wsize is overestimated)
+ memset(params->wdata, 0, params->wsize);
+
+ // prepare kernel data (src0)
+ {
+ float * const wdata = (float *) params->wdata + 0;
+
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+ float * dst_data = wdata + i02*ew0*ne00;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_data[i00*ew0 + i01] = src[i00];
+ }
+ }
+ }
+ }
+
+ // prepare source data (src1)
+ {
+ float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
+
+ for (int i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
+ float * dst_data = wdata;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ dst_data[(i10 + nh)*ew0 + i11] = src[i10];
+ }
+ }
+ }
+
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // total rows in dst
+ const int nr = ne02;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
+ for (int i0 = 0; i0 < ne10; i0 += 2) {
+ dst_data[i0/2] = 0;
+ for (int k = -nh; k <= nh; k++) {
+ float v = 0.0f;
+ ggml_vec_dot_f32(ew0, &v,
+ (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0,
+ (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0);
+
+ dst_data[i0/2] += v;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_conv_1d_2s(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_flash_attn
+
+static void ggml_compute_forward_flash_attn_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * q,
+ const struct ggml_tensor * k,
+ const struct ggml_tensor * v,
+ const bool masked,
+ struct ggml_tensor * dst) {
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int neq0 = q->ne[0];
+ const int neq1 = q->ne[1];
+ const int neq2 = q->ne[2];
+ const int neq3 = q->ne[3];
+
+ const int nek0 = k->ne[0];
+ const int nek1 = k->ne[1];
+ //const int nek2 = k->ne[2];
+ //const int nek3 = k->ne[3];
+
+ //const int nev0 = v->ne[0];
+ const int nev1 = v->ne[1];
+ //const int nev2 = v->ne[2];
+ //const int nev3 = v->ne[3];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+ //const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+
+ const int nbk0 = k->nb[0];
+ const int nbk1 = k->nb[1];
+ const int nbk2 = k->nb[2];
+ const int nbk3 = k->nb[3];
+
+ const int nbq0 = q->nb[0];
+ const int nbq1 = q->nb[1];
+ const int nbq2 = q->nb[2];
+ const int nbq3 = q->nb[3];
+
+ const int nbv0 = v->nb[0];
+ const int nbv1 = v->nb[1];
+ const int nbv2 = v->nb[2];
+ const int nbv3 = v->nb[3];
+
+ const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ const int nb2 = dst->nb[2];
+ const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int D = neq0;
+ const int N = neq1;
+ const int P = nek1 - N;
+ const int M = P + N;
+
+ GGML_ASSERT(ne0 == D);
+ GGML_ASSERT(ne1 == N);
+ GGML_ASSERT(P >= 0);
+
+ GGML_ASSERT(nbq0 == sizeof(float));
+ GGML_ASSERT(nbk0 == sizeof(float));
+ GGML_ASSERT(nbv0 == sizeof(float));
+
+ GGML_ASSERT(neq0 == D);
+ GGML_ASSERT(nek0 == D);
+ GGML_ASSERT(nev1 == D);
+
+ GGML_ASSERT(neq1 == N);
+ GGML_ASSERT(nek1 == N + P);
+ GGML_ASSERT(nev1 == D);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ if (params->type == GGML_TASK_INIT) {
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // parallelize by q rows using ggml_vec_dot_f32
+
+ // total rows in q
+ const int nr = neq1*neq2*neq3;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ const float scale = 1.0/sqrt((double) D);
+
+ //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // q indices
+ const int iq3 = ir/(neq2*neq1);
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+ float * S = (float *) params->wdata + ith*(M + CACHE_LINE_SIZE_F32);
+
+ for (int ic = 0; ic < nek1; ++ic) {
+ // k indices
+ const int ik3 = iq3;
+ const int ik2 = iq2;
+ const int ik1 = ic;
+
+ // S indices
+ const int i1 = ik1;
+
+ ggml_vec_dot_f32(neq0,
+ S + i1,
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+ }
+
+ // scale
+ ggml_vec_scale_f32(nek1, S, scale);
+
+ if (masked) {
+ for (int i = P; i < M; i++) {
+ if (i > P + iq1) {
+ S[i] = -INFINITY;
+ }
+ }
+ }
+
+ // softmax
+ {
+ float max = -INFINITY;
+ for (int i = 0; i < M; i++) {
+ max = MAX(max, S[i]);
+ }
+
+ ggml_float sum = 0.0;
+
+ uint16_t ss;
+ for (int i = 0; i < M; i++) {
+ if (S[i] == -INFINITY) {
+ S[i] = 0.0;
+ } else {
+ //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
+ ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max);
+ memcpy(&ss, &s, sizeof(ss));
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
+ sum += val;
+ S[i] = val;
+ }
+ }
+
+ assert(sum > 0.0f);
+
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(M, S, sum);
+ }
+
+ for (int ic = 0; ic < nev1; ++ic) {
+ // dst indices
+ const int i1 = iq1;
+ const int i2 = iq2;
+ const int i3 = iq3;
+
+ ggml_vec_dot_f32(nek1,
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
+ S);
+ }
+ }
+}
+
+static void ggml_compute_forward_flash_attn_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * q,
+ const struct ggml_tensor * k,
+ const struct ggml_tensor * v,
+ const bool masked,
+ struct ggml_tensor * dst) {
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int neq0 = q->ne[0];
+ const int neq1 = q->ne[1];
+ const int neq2 = q->ne[2];
+ const int neq3 = q->ne[3];
+
+ const int nek0 = k->ne[0];
+ const int nek1 = k->ne[1];
+ //const int nek2 = k->ne[2];
+ //const int nek3 = k->ne[3];
+
+ //const int nev0 = v->ne[0];
+ const int nev1 = v->ne[1];
+ //const int nev2 = v->ne[2];
+ //const int nev3 = v->ne[3];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+ //const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+
+ const int nbk0 = k->nb[0];
+ const int nbk1 = k->nb[1];
+ const int nbk2 = k->nb[2];
+ const int nbk3 = k->nb[3];
+
+ const int nbq0 = q->nb[0];
+ const int nbq1 = q->nb[1];
+ const int nbq2 = q->nb[2];
+ const int nbq3 = q->nb[3];
+
+ const int nbv0 = v->nb[0];
+ const int nbv1 = v->nb[1];
+ const int nbv2 = v->nb[2];
+ const int nbv3 = v->nb[3];
+
+ const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ const int nb2 = dst->nb[2];
+ const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int D = neq0;
+ const int N = neq1;
+ const int P = nek1 - N;
+ const int M = P + N;
+
+ GGML_ASSERT(ne0 == D);
+ GGML_ASSERT(ne1 == N);
+ GGML_ASSERT(P >= 0);
+
+ GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+ GGML_ASSERT(neq0 == D);
+ GGML_ASSERT(nek0 == D);
+ GGML_ASSERT(nev1 == D);
+
+ GGML_ASSERT(neq1 == N);
+ GGML_ASSERT(nek1 == N + P);
+ GGML_ASSERT(nev1 == D);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ if (params->type == GGML_TASK_INIT) {
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // parallelize by q rows using ggml_vec_dot_f32
+
+ // total rows in q
+ const int nr = neq1*neq2*neq3;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ const float scale = 1.0/sqrt((double) D);
+
+ //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // q indices
+ const int iq3 = ir/(neq2*neq1);
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+ float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
+
+ for (int ic = 0; ic < nek1; ++ic) {
+ // k indices
+ const int ik3 = iq3;
+ const int ik2 = iq2;
+ const int ik1 = ic;
+
+ // S indices
+ const int i1 = ik1;
+
+ ggml_vec_dot_f16(neq0,
+ S + i1,
+ (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+ (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+ }
+
+ // scale
+ ggml_vec_scale_f32(nek1, S, scale);
+
+ if (masked) {
+ for (int i = P; i < M; i++) {
+ if (i > P + iq1) {
+ S[i] = -INFINITY;
+ }
+ }
+ }
+
+ // softmax
+ {
+ float max = -INFINITY;
+ for (int i = 0; i < M; i++) {
+ max = MAX(max, S[i]);
+ }
+
+ ggml_float sum = 0.0;
+
+ uint16_t ss;
+ for (int i = 0; i < M; i++) {
+ if (S[i] == -INFINITY) {
+ S[i] = 0.0;
+ } else {
+ //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
+ ggml_fp16_t s = GGML_FP32_TO_FP16(S[i] - max);
+ memcpy(&ss, &s, sizeof(ss));
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[ss]);
+ sum += val;
+ S[i] = val;
+ }
+ }
+
+ assert(sum > 0.0f);
+
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(M, S, sum);
+ }
+
+ ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
+
+ for (int i = 0; i < M; i++) {
+ S16[i] = GGML_FP32_TO_FP16(S[i]);
+ }
+
+ for (int ic = 0; ic < nev1; ++ic) {
+ // dst indices
+ const int i1 = iq1;
+ const int i2 = iq2;
+ const int i3 = iq3;
+
+ ggml_vec_dot_f16(nek1,
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
+ S16);
+ }
+ }
+}
+
+static void ggml_compute_forward_flash_attn(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * q,
+ const struct ggml_tensor * k,
+ const struct ggml_tensor * v,
+ const bool masked,
+ struct ggml_tensor * dst) {
+ switch (q->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_flash_ff
+
+static void ggml_compute_forward_flash_ff_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * a, // F16
+ const struct ggml_tensor * b0, // F16 fc_w
+ const struct ggml_tensor * b1, // F32 fc_b
+ const struct ggml_tensor * c0, // F16 proj_w
+ const struct ggml_tensor * c1, // F32 proj_b
+ struct ggml_tensor * dst) {
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ const int nea0 = a->ne[0];
+ const int nea1 = a->ne[1];
+ const int nea2 = a->ne[2];
+ const int nea3 = a->ne[3];
+
+ const int neb00 = b0->ne[0];
+ const int neb01 = b0->ne[1];
+ //const int neb02 = b0->ne[2];
+ //const int neb03 = b0->ne[3];
+
+ const int neb10 = b1->ne[0];
+ const int neb11 = b1->ne[1];
+ //const int neb12 = b1->ne[2];
+ //const int neb13 = b1->ne[3];
+
+ const int nec00 = c0->ne[0];
+ const int nec01 = c0->ne[1];
+ //const int nec02 = c0->ne[2];
+ //const int nec03 = c0->ne[3];
+
+ const int nec10 = c1->ne[0];
+ const int nec11 = c1->ne[1];
+ //const int nec12 = c1->ne[2];
+ //const int nec13 = c1->ne[3];
+
+ const int ne0 = dst->ne[0];
+ const int ne1 = dst->ne[1];
+ const int ne2 = dst->ne[2];
+ //const int ne3 = dst->ne[3];
+
+ const int nba0 = a->nb[0];
+ const int nba1 = a->nb[1];
+ const int nba2 = a->nb[2];
+ const int nba3 = a->nb[3];
+
+ const int nbb00 = b0->nb[0];
+ const int nbb01 = b0->nb[1];
+ const int nbb02 = b0->nb[2];
+ const int nbb03 = b0->nb[3];
+
+ const int nbb10 = b1->nb[0];
+ //const int nbb11 = b1->nb[1];
+ //const int nbb12 = b1->nb[2];
+ //const int nbb13 = b1->nb[3];
+
+ const int nbc00 = c0->nb[0];
+ const int nbc01 = c0->nb[1];
+ const int nbc02 = c0->nb[2];
+ const int nbc03 = c0->nb[3];
+
+ const int nbc10 = c1->nb[0];
+ //const int nbc11 = c1->nb[1];
+ //const int nbc12 = c1->nb[2];
+ //const int nbc13 = c1->nb[3];
+
+ const int nb0 = dst->nb[0];
+ const int nb1 = dst->nb[1];
+ const int nb2 = dst->nb[2];
+ const int nb3 = dst->nb[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int D = nea0;
+ //const int N = nea1;
+ const int M = neb01;
+
+ GGML_ASSERT(ne0 == nea0);
+ GGML_ASSERT(ne1 == nea1);
+ GGML_ASSERT(ne2 == nea2);
+
+ GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nbb10 == sizeof(float));
+ GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nbc10 == sizeof(float));
+
+ GGML_ASSERT(neb00 == D);
+ GGML_ASSERT(neb01 == M);
+ GGML_ASSERT(neb10 == M);
+ GGML_ASSERT(neb11 == 1);
+
+ GGML_ASSERT(nec00 == M);
+ GGML_ASSERT(nec01 == D);
+ GGML_ASSERT(nec10 == D);
+ GGML_ASSERT(nec11 == 1);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ if (params->type == GGML_TASK_INIT) {
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // parallelize by a rows using ggml_vec_dot_f32
+
+ // total rows in a
+ const int nr = nea1*nea2*nea3;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // a indices
+ const int ia3 = ir/(nea2*nea1);
+ const int ia2 = (ir - ia3*nea2*nea1)/nea1;
+ const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
+
+ float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
+
+ for (int ic = 0; ic < neb01; ++ic) {
+ // b0 indices
+ const int ib03 = ia3;
+ const int ib02 = ia2;
+ const int ib01 = ic;
+
+ // S indices
+ const int i1 = ib01;
+
+ ggml_vec_dot_f16(nea0,
+ S + i1,
+ (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
+ (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
+ }
+
+ ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
+ //ggml_vec_gelu_f32(neb01, S, S);
+
+ ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
+
+ for (int i = 0; i < M; i++) {
+ S16[i] = GGML_FP32_TO_FP16(S[i]);
+ }
+
+ ggml_vec_gelu_f16(neb01, S16, S16);
+
+ {
+ // dst indices
+ const int i1 = ia1;
+ const int i2 = ia2;
+ const int i3 = ia3;
+
+ for (int ic = 0; ic < nec01; ++ic) {
+
+ ggml_vec_dot_f16(neb01,
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)),
+ S16);
+ }
+
+ ggml_vec_add_f32(nec01,
+ (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
+ (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
+ (float *) c1->data);
+ }
+ }
+}
+
+static void ggml_compute_forward_flash_ff(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * a,
+ const struct ggml_tensor * b0,
+ const struct ggml_tensor * b1,
+ const struct ggml_tensor * c0,
+ const struct ggml_tensor * c1,
+ struct ggml_tensor * dst) {
+ switch (b0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(false); // TODO
+ } break;
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+}
+
+/////////////////////////////////
+
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+ assert(params);
+
+ switch (tensor->op) {
+ case GGML_OP_DUP:
+ {
+ ggml_compute_forward_dup(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_ADD:
+ {
+ ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_SUB:
+ {
+ ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_MUL:
+ {
+ ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_DIV:
+ {
+ ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_SQR:
+ {
+ ggml_compute_forward_sqr(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_SQRT:
+ {
+ ggml_compute_forward_sqrt(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_SUM:
+ {
+ ggml_compute_forward_sum(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_MEAN:
+ {
+ ggml_compute_forward_mean(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_REPEAT:
+ {
+ ggml_compute_forward_repeat(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_ABS:
+ {
+ ggml_compute_forward_abs(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_SGN:
+ {
+ ggml_compute_forward_sgn(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_NEG:
+ {
+ ggml_compute_forward_neg(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_STEP:
+ {
+ ggml_compute_forward_step(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_RELU:
+ {
+ ggml_compute_forward_relu(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_GELU:
+ {
+ ggml_compute_forward_gelu(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_NORM:
+ {
+ ggml_compute_forward_norm(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_SCALE:
+ {
+ ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_CPY:
+ {
+ ggml_compute_forward_cpy(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_RESHAPE:
+ {
+ ggml_compute_forward_reshape(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_VIEW:
+ {
+ ggml_compute_forward_view(params, tensor->src0);
+ } break;
+ case GGML_OP_PERMUTE:
+ {
+ ggml_compute_forward_permute(params, tensor->src0);
+ } break;
+ case GGML_OP_TRANSPOSE:
+ {
+ ggml_compute_forward_transpose(params, tensor->src0);
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ ggml_compute_forward_soft_max(params, tensor->src0, tensor);
+ } break;
+ case GGML_OP_ROPE:
+ {
+ ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_CONV_1D_1S:
+ {
+ ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_CONV_1D_2S:
+ {
+ ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor);
+ } break;
+ case GGML_OP_FLASH_ATTN:
+ {
+ int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+ GGML_ASSERT(t == 0 || t == 1);
+ bool masked = t != 0;
+ ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
+ } break;
+ case GGML_OP_FLASH_FF:
+ {
+ ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
+ } break;
+ case GGML_OP_NONE:
+ {
+ // nop
+ } break;
+ case GGML_OP_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
+ struct ggml_tensor * src0 = tensor->src0;
+ struct ggml_tensor * src1 = tensor->src1;
+
+ switch (tensor->op) {
+ case GGML_OP_DUP:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ }
+ } break;
+ case GGML_OP_ADD:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ }
+ if (src1->grad) {
+ src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
+ }
+ } break;
+ case GGML_OP_SUB:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+ }
+ if (src1->grad) {
+ src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
+ }
+ } break;
+ case GGML_OP_MUL:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_mul(ctx, src1, tensor->grad),
+ inplace);
+ }
+ if (src1->grad) {
+ src1->grad =
+ ggml_add_impl(ctx,
+ src1->grad,
+ ggml_mul(ctx, src0, tensor->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_DIV:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_div(ctx, tensor->grad, src1),
+ inplace);
+ }
+ if (src1->grad) {
+ src1->grad =
+ ggml_sub_impl(ctx,
+ src1->grad,
+ ggml_mul(ctx,
+ tensor->grad,
+ ggml_div(ctx, tensor, src1)),
+ inplace);
+ }
+ } break;
+ case GGML_OP_SQR:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_mul(ctx,
+ ggml_mul(ctx, src0, tensor->grad),
+ ggml_repeat(ctx, ggml_new_f32(ctx, 2.0f), src0)),
+ inplace);
+ }
+ } break;
+ case GGML_OP_SQRT:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_div(ctx,
+ ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
+ tensor),
+ inplace);
+ }
+ } break;
+ case GGML_OP_SUM:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_repeat(ctx, tensor->grad, src0->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_MEAN:
+ {
+ assert(false); // TODO: implement
+ } break;
+ case GGML_OP_REPEAT:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_sum(ctx, tensor->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_ABS:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_impl(ctx,
+ src0->grad,
+ ggml_mul(ctx,
+ ggml_sgn(ctx, src0),
+ tensor->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_SGN:
+ {
+ if (src0->grad) {
+ // noop
+ }
+ } break;
+ case GGML_OP_NEG:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+ }
+ } break;
+ case GGML_OP_STEP:
+ {
+ if (src0->grad) {
+ // noop
+ }
+ } break;
+ case GGML_OP_RELU:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_sub_impl(ctx,
+ src0->grad,
+ ggml_mul(ctx,
+ ggml_step(ctx, src0),
+ tensor->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_GELU:
+ {
+ assert(false); // TODO: not implemented
+ } break;
+ case GGML_OP_NORM:
+ {
+ assert(false); // TODO: not implemented
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ if (src0->grad) {
+ // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
+ assert(false);
+ }
+ if (src1->grad) {
+ src1->grad =
+ ggml_add_impl(ctx,
+ src1->grad,
+ // TODO: fix transpose, the node will break the graph connections
+ ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+ inplace);
+ }
+ } break;
+ case GGML_OP_SCALE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_CPY:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_RESHAPE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_VIEW:
+ {
+ GGML_ASSERT(false); // not supported
+ } break;
+ case GGML_OP_PERMUTE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_TRANSPOSE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_ROPE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_CONV_1D_1S:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_CONV_1D_2S:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_FLASH_ATTN:
+ {
+ GGML_ASSERT(false); // not supported
+ } break;
+ case GGML_OP_FLASH_FF:
+ {
+ GGML_ASSERT(false); // not supported
+ } break;
+ case GGML_OP_NONE:
+ {
+ // nop
+ } break;
+ case GGML_OP_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+ if (node->grad == NULL) {
+ // this usually happens when we generate intermediate nodes from constants in the backward pass
+ // it can also happen during forward pass, if the user performs computations with constants
+ if (node->op != GGML_OP_NONE) {
+ //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op);
+ }
+ }
+
+ // check if already visited
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (cgraph->nodes[i] == node) {
+ return;
+ }
+ }
+
+ for (int i = 0; i < cgraph->n_leafs; i++) {
+ if (cgraph->leafs[i] == node) {
+ return;
+ }
+ }
+
+ if (node->src0) {
+ ggml_visit_parents(cgraph, node->src0);
+ }
+
+ if (node->src1) {
+ ggml_visit_parents(cgraph, node->src1);
+ }
+
+ for (int i = 0; i < GGML_MAX_OPT; ++i) {
+ if (node->opt[i]) {
+ ggml_visit_parents(cgraph, node->opt[i]);
+ }
+ }
+
+ if (node->op == GGML_OP_NONE && node->grad == NULL) {
+ // reached a leaf node, not part of the gradient graph (e.g. a constant)
+ assert(cgraph->n_leafs < GGML_MAX_NODES);
+
+ cgraph->leafs[cgraph->n_leafs] = node;
+ cgraph->n_leafs++;
+ } else {
+ assert(cgraph->n_nodes < GGML_MAX_NODES);
+
+ cgraph->nodes[cgraph->n_nodes] = node;
+ cgraph->grads[cgraph->n_nodes] = node->grad;
+ cgraph->n_nodes++;
+ }
+}
+
+static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+ if (!expand) {
+ cgraph->n_nodes = 0;
+ cgraph->n_leafs = 0;
+ }
+
+ const int n0 = cgraph->n_nodes;
+ UNUSED(n0);
+
+ ggml_visit_parents(cgraph, tensor);
+
+ const int n_new = cgraph->n_nodes - n0;
+ GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
+
+ if (n_new > 0) {
+ // the last added node should always be starting point
+ assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
+ }
+}
+
+void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
+ ggml_build_forward_impl(cgraph, tensor, true);
+}
+
+struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
+ struct ggml_cgraph result = {
+ /*.n_nodes =*/ 0,
+ /*.n_leafs =*/ 0,
+ /*.n_threads =*/ 0,
+ /*.work_size =*/ 0,
+ /*.work =*/ NULL,
+ /*.nodes =*/ { NULL },
+ /*.grads =*/ { NULL },
+ /*.leafs =*/ { NULL },
+ /*.perf_runs =*/ 0,
+ /*.perf_cycles =*/ 0,
+ /*.perf_time_us =*/ 0,
+ };
+
+ ggml_build_forward_impl(&result, tensor, false);
+
+ return result;
+}
+
+struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
+ struct ggml_cgraph result = *gf;
+
+ assert(gf->n_nodes > 0);
+
+ // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
+ if (keep) {
+ for (int i = 0; i < gf->n_nodes; i++) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ if (node->grad) {
+ node->grad = ggml_dup_tensor(ctx, node);
+ gf->grads[i] = node->grad;
+ }
+ }
+ }
+
+ for (int i = gf->n_nodes - 1; i >= 0; i--) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ // because we detached the grad nodes from the original graph, we can afford inplace operations
+ if (node->grad) {
+ ggml_compute_backward(ctx, node, keep);
+ }
+ }
+
+ for (int i = gf->n_nodes - 1; i >= 0; i--) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ if (node->is_param) {
+ GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+ ggml_build_forward_impl(&result, node->grad, true);
+ }
+ }
+
+ return result;
+}
+
+//
+// thread data
+//
+// synchronization is done via busy loops
+// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops
+//
+
+#ifdef __APPLE__
+
+//#include <os/lock.h>
+
+//typedef os_unfair_lock ggml_lock_t;
+//
+//#define ggml_lock_init(x) UNUSED(x)
+//#define ggml_lock_destroy(x) UNUSED(x)
+//#define ggml_lock_lock os_unfair_lock_lock
+//#define ggml_lock_unlock os_unfair_lock_unlock
+//
+//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT
+
+typedef int ggml_lock_t;
+
+#define ggml_lock_init(x) UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#define ggml_lock_lock(x) UNUSED(x)
+#define ggml_lock_unlock(x) UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+
+typedef pthread_t ggml_thread_t;
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join pthread_join
+
+#else
+
+//typedef pthread_spinlock_t ggml_lock_t;
+
+//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE)
+//#define ggml_lock_destroy pthread_spin_destroy
+//#define ggml_lock_lock pthread_spin_lock
+//#define ggml_lock_unlock pthread_spin_unlock
+
+typedef int ggml_lock_t;
+
+#define ggml_lock_init(x) UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#define ggml_lock_lock(x) UNUSED(x)
+#define ggml_lock_unlock(x) UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+
+typedef pthread_t ggml_thread_t;
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join pthread_join
+
+#endif
+
+struct ggml_compute_state_shared {
+ ggml_lock_t spin;
+
+ int n_threads;
+
+ // synchronization primitives
+ atomic_int n_ready;
+ atomic_bool has_work;
+ atomic_bool stop; // stop all threads
+};
+
+struct ggml_compute_state {
+ ggml_thread_t thrd;
+
+ struct ggml_compute_params params;
+ struct ggml_tensor * node;
+
+ struct ggml_compute_state_shared * shared;
+};
+
+static thread_ret_t ggml_graph_compute_thread(void * data) {
+ struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+
+ const int n_threads = state->shared->n_threads;
+
+ while (true) {
+ if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) {
+ atomic_store(&state->shared->has_work, false);
+ } else {
+ while (atomic_load(&state->shared->has_work)) {
+ if (atomic_load(&state->shared->stop)) {
+ return 0;
+ }
+ ggml_lock_lock (&state->shared->spin);
+ ggml_lock_unlock(&state->shared->spin);
+ }
+ }
+
+ atomic_fetch_sub(&state->shared->n_ready, 1);
+
+ // wait for work
+ while (!atomic_load(&state->shared->has_work)) {
+ if (atomic_load(&state->shared->stop)) {
+ return 0;
+ }
+ ggml_lock_lock (&state->shared->spin);
+ ggml_lock_unlock(&state->shared->spin);
+ }
+
+ // check if we should stop
+ if (atomic_load(&state->shared->stop)) {
+ break;
+ }
+
+ if (state->node) {
+ ggml_compute_forward(&state->params, state->node);
+ state->node = NULL;
+ } else {
+ break;
+ }
+ }
+
+ return 0;
+}
+
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+ if (cgraph->n_threads <= 0) {
+ cgraph->n_threads = 8;
+ }
+
+ const int n_threads = cgraph->n_threads;
+
+ struct ggml_compute_state_shared state_shared = {
+ /*.spin =*/ GGML_LOCK_INITIALIZER,
+ /*.n_threads =*/ n_threads,
+ /*.n_ready =*/ 0,
+ /*.has_work =*/ false,
+ /*.stop =*/ false,
+ };
+ struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
+
+ // create thread pool
+ if (n_threads > 1) {
+ ggml_lock_init(&state_shared.spin);
+
+ atomic_store(&state_shared.has_work, true);
+
+ for (int j = 0; j < n_threads - 1; j++) {
+ workers[j] = (struct ggml_compute_state) {
+ .thrd = 0,
+ .params = {
+ .type = GGML_TASK_COMPUTE,
+ .ith = j + 1,
+ .nth = n_threads,
+ .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+ .wdata = cgraph->work ? cgraph->work->data : NULL,
+ },
+ .node = NULL,
+ .shared = &state_shared,
+ };
+ int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+ assert(rc == 0);
+ UNUSED(rc);
+ }
+ }
+
+ // initialize tasks + work buffer
+ {
+ size_t work_size = 0;
+
+ // thread scheduling for the different operations
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * node = cgraph->nodes[i];
+
+ switch (node->op) {
+ case GGML_OP_DUP:
+ {
+ node->n_tasks = 1;
+ } break;
+ case GGML_OP_ADD:
+ {
+ node->n_tasks = n_threads;
+ } break;
+ case GGML_OP_SUB:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_SUM:
+ case GGML_OP_MEAN:
+ case GGML_OP_REPEAT:
+ case GGML_OP_ABS:
+ case GGML_OP_SGN:
+ case GGML_OP_NEG:
+ case GGML_OP_STEP:
+ case GGML_OP_RELU:
+ {
+ node->n_tasks = 1;
+ } break;
+ case GGML_OP_GELU:
+ {
+ node->n_tasks = n_threads;
+ } break;
+ case GGML_OP_NORM:
+ {
+ node->n_tasks = n_threads;
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ // TODO: use different scheduling for different matrix sizes
+ node->n_tasks = n_threads;
+
+ size_t cur = 0;
+
+ // TODO: better way to determine if the matrix is transposed
+ if (node->src0->nb[1] < node->src0->nb[0]) {
+ cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1)
+ } else {
+ if (node->src0->type == GGML_TYPE_F16 &&
+ node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+ if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+ cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
+ } else {
+ cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
+ }
+#else
+ cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
+#endif
+ } else if (node->src0->type == GGML_TYPE_F32 &&
+ node->src1->type == GGML_TYPE_F32) {
+ cur = 0;
+ } else {
+ GGML_ASSERT(false);
+ }
+ }
+
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_SCALE:
+ {
+ node->n_tasks = n_threads;
+ } break;
+ case GGML_OP_CPY:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_GET_ROWS:
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ node->n_tasks = 1;
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ node->n_tasks = n_threads;
+ } break;
+ case GGML_OP_ROPE:
+ {
+ node->n_tasks = 1;
+ } break;
+ case GGML_OP_CONV_1D_1S:
+ case GGML_OP_CONV_1D_2S:
+ {
+ node->n_tasks = n_threads;
+
+ GGML_ASSERT(node->src0->ne[3] == 1);
+ GGML_ASSERT(node->src1->ne[2] == 1);
+ GGML_ASSERT(node->src1->ne[3] == 1);
+
+ size_t cur = 0;
+ const int nk = node->src0->ne[0];
+
+ if (node->src0->type == GGML_TYPE_F16 &&
+ node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(ggml_fp16_t)*(
+ nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+ ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+ );
+ } else if (node->src0->type == GGML_TYPE_F32 &&
+ node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*(
+ nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+ ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+ );
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_FLASH_ATTN:
+ {
+ node->n_tasks = n_threads;
+
+ size_t cur = 0;
+
+ if (node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+ }
+
+ if (node->src1->type == GGML_TYPE_F16) {
+ cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+ }
+
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_FLASH_FF:
+ {
+ node->n_tasks = n_threads;
+
+ size_t cur = 0;
+
+ if (node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+ }
+
+ if (node->src1->type == GGML_TYPE_F16) {
+ cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+ }
+
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_NONE:
+ {
+ node->n_tasks = 1;
+ } break;
+ case GGML_OP_COUNT:
+ {
+ assert(false);
+ } break;
+ }
+ }
+
+ if (cgraph->work != NULL && work_size > cgraph->work_size) {
+ assert(false); // TODO: better handling
+ }
+
+ if (work_size > 0 && cgraph->work == NULL) {
+ cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+
+ GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
+ cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+ }
+ }
+
+ const int64_t perf_start_cycles = ggml_perf_cycles();
+ const int64_t perf_start_time_us = ggml_perf_time_us();
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
+
+ struct ggml_tensor * node = cgraph->nodes[i];
+
+ // TODO: this could be used to avoid unnecessary computations, but it needs to be improved
+ //if (node->grad == NULL && node->perf_runs > 0) {
+ // continue;
+ //}
+
+ const int64_t perf_node_start_cycles = ggml_perf_cycles();
+ const int64_t perf_node_start_time_us = ggml_perf_time_us();
+
+ // INIT
+ struct ggml_compute_params params = {
+ /*.type =*/ GGML_TASK_INIT,
+ /*.ith =*/ 0,
+ /*.nth =*/ node->n_tasks,
+ /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+ /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+ };
+
+ ggml_compute_forward(&params, node);
+
+ // COMPUTE
+ if (node->n_tasks > 1) {
+ if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+ atomic_store(&state_shared.has_work, false);
+ }
+
+ while (atomic_load(&state_shared.has_work)) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ // launch thread pool
+ for (int j = 0; j < n_threads - 1; j++) {
+ workers[j].params = (struct ggml_compute_params) {
+ .type = GGML_TASK_COMPUTE,
+ .ith = j + 1,
+ .nth = n_threads,
+ .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+ .wdata = cgraph->work ? cgraph->work->data : NULL,
+ };
+ workers[j].node = node;
+ }
+
+ atomic_fetch_sub(&state_shared.n_ready, 1);
+
+ while (atomic_load(&state_shared.n_ready) > 0) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ atomic_store(&state_shared.has_work, true);
+ }
+
+ params.type = GGML_TASK_COMPUTE;
+ ggml_compute_forward(&params, node);
+
+ // wait for thread pool
+ if (node->n_tasks > 1) {
+ if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+ atomic_store(&state_shared.has_work, false);
+ }
+
+ while (atomic_load(&state_shared.has_work)) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ atomic_fetch_sub(&state_shared.n_ready, 1);
+
+ while (atomic_load(&state_shared.n_ready) != 0) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+ }
+
+ // FINALIZE
+ if (node->n_tasks > 1) {
+ if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+ atomic_store(&state_shared.has_work, false);
+ }
+
+ while (atomic_load(&state_shared.has_work)) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ // launch thread pool
+ for (int j = 0; j < n_threads - 1; j++) {
+ workers[j].params = (struct ggml_compute_params) {
+ .type = GGML_TASK_FINALIZE,
+ .ith = j + 1,
+ .nth = n_threads,
+ .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+ .wdata = cgraph->work ? cgraph->work->data : NULL,
+ };
+ workers[j].node = node;
+ }
+
+ atomic_fetch_sub(&state_shared.n_ready, 1);
+
+ while (atomic_load(&state_shared.n_ready) > 0) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ atomic_store(&state_shared.has_work, true);
+ }
+
+ params.type = GGML_TASK_FINALIZE;
+ ggml_compute_forward(&params, node);
+
+ // wait for thread pool
+ if (node->n_tasks > 1) {
+ if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
+ atomic_store(&state_shared.has_work, false);
+ }
+
+ while (atomic_load(&state_shared.has_work)) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+
+ atomic_fetch_sub(&state_shared.n_ready, 1);
+
+ while (atomic_load(&state_shared.n_ready) != 0) {
+ ggml_lock_lock (&state_shared.spin);
+ ggml_lock_unlock(&state_shared.spin);
+ }
+ }
+
+ // performance stats (node)
+ {
+ int64_t perf_cycles_cur = ggml_perf_cycles() - perf_node_start_cycles;
+ int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us;
+
+ node->perf_runs++;
+ node->perf_cycles += perf_cycles_cur;
+ node->perf_time_us += perf_time_us_cur;
+ }
+ }
+
+ // join thread pool
+ if (n_threads > 1) {
+ atomic_store(&state_shared.stop, true);
+ atomic_store(&state_shared.has_work, true);
+
+ for (int j = 0; j < n_threads - 1; j++) {
+ int rc = ggml_thread_join(workers[j].thrd, NULL);
+ assert(rc == 0);
+ UNUSED(rc);
+ }
+
+ ggml_lock_destroy(&state_shared.spin);
+ }
+
+ // performance stats (graph)
+ {
+ int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles;
+ int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us;
+
+ cgraph->perf_runs++;
+ cgraph->perf_cycles += perf_cycles_cur;
+ cgraph->perf_time_us += perf_time_us_cur;
+
+ GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n",
+ __func__, cgraph->perf_runs,
+ (double) perf_cycles_cur / (double) ggml_cycles_per_ms(),
+ (double) cgraph->perf_cycles / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs,
+ (double) perf_time_us_cur / 1000.0,
+ (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
+ }
+}
+
+void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * grad = cgraph->grads[i];
+
+ if (grad) {
+ ggml_set_zero(grad);
+ }
+ }
+}
+
+void ggml_graph_print(const struct ggml_cgraph * cgraph) {
+ int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0};
+
+ GGML_PRINT("=== GRAPH ===\n");
+
+ GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
+ GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
+
+ GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * node = cgraph->nodes[i];
+
+ perf_total_per_op_us[node->op] += node->perf_time_us;
+
+ GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+ i,
+ node->ne[0], node->ne[1], node->ne[2],
+ GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
+ (double) node->perf_cycles / (double) ggml_cycles_per_ms(),
+ (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
+ (double) node->perf_time_us / 1000.0,
+ (double) node->perf_time_us / 1000.0 / node->perf_runs);
+ }
+
+ GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs);
+ for (int i = 0; i < cgraph->n_leafs; i++) {
+ struct ggml_tensor * node = cgraph->leafs[i];
+
+ GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n",
+ i,
+ node->ne[0], node->ne[1],
+ GGML_OP_LABEL[node->op]);
+ }
+
+ for (int i = 0; i < GGML_OP_COUNT; i++) {
+ GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
+ }
+
+ GGML_PRINT("========================================\n");
+}
+
+// check if node is part of the graph
+static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+ if (cgraph == NULL) {
+ return true;
+ }
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (cgraph->nodes[i] == node) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * parent = cgraph->nodes[i];
+
+ if (parent->grad == node) {
+ return parent;
+ }
+ }
+
+ return NULL;
+}
+
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
+ char color[16];
+
+ FILE * fp = fopen(filename, "w");
+ assert(fp);
+
+ fprintf(fp, "digraph G {\n");
+ fprintf(fp, " newrank = true;\n");
+ fprintf(fp, " rankdir = LR;\n");
+
+ for (int i = 0; i < gb->n_nodes; i++) {
+ struct ggml_tensor * node = gb->nodes[i];
+
+ if (ggml_graph_get_parent(gb, node) != NULL) {
+ continue;
+ }
+
+ if (node->is_param) {
+ snprintf(color, sizeof(color), "yellow");
+ } else if (node->grad) {
+ if (ggml_graph_find(gf, node)) {
+ snprintf(color, sizeof(color), "green");
+ } else {
+ snprintf(color, sizeof(color), "lightblue");
+ }
+ } else {
+ snprintf(color, sizeof(color), "white");
+ }
+
+ fprintf(fp, " \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"%d [%d, %d] | <x>%s",
+ (void *) node, color,
+ i, node->ne[0], node->ne[1],
+ GGML_OP_SYMBOL[node->op]);
+
+ if (node->grad) {
+ fprintf(fp, " | <g>%s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]);
+ } else {
+ fprintf(fp, "\"; ]\n");
+ }
+ }
+
+ for (int i = 0; i < gb->n_leafs; i++) {
+ struct ggml_tensor * node = gb->leafs[i];
+
+ snprintf(color, sizeof(color), "pink");
+
+ if (ggml_nelements(node) == 1) {
+ fprintf(fp, " \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"<x>%.1e\"; ]\n",
+ (void *) node, color, ggml_get_f32_1d(node, 0));
+ } else {
+ fprintf(fp, " \"%p\" [ \
+style = filled; fillcolor = %s; shape = record; \
+label=\"<x>CONST %d [%d, %d]\"; ]\n",
+ (void *) node, color,
+ i, node->ne[0], node->ne[1]);
+ }
+ }
+
+ for (int i = 0; i < gb->n_nodes; i++) {
+ struct ggml_tensor * node = gb->nodes[i];
+
+ struct ggml_tensor * parent = ggml_graph_get_parent(gb, node);
+
+ if (node->src0) {
+ struct ggml_tensor * parent0 = ggml_graph_get_parent(gb, node->src0);
+
+ fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"x\"; ]\n",
+ parent0 ? (void *) parent0 : (void *) node->src0,
+ parent0 ? "g" : "x",
+ parent ? (void *) parent : (void *) node,
+ parent ? "g" : "x",
+ parent ? "empty" : "vee",
+ parent ? "dashed" : "solid");
+ }
+
+ if (node->src1) {
+ struct ggml_tensor * parent1 = ggml_graph_get_parent(gb, node->src1);
+
+ fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"y\"; ]\n",
+ parent1 ? (void *) parent1 : (void *) node->src1,
+ parent1 ? "g" : "x",
+ parent ? (void *) parent : (void *) node,
+ parent ? "g" : "x",
+ parent ? "empty" : "vee",
+ parent ? "dashed" : "solid");
+ }
+ }
+
+ for (int i = 0; i < gb->n_leafs; i++) {
+ struct ggml_tensor * node = gb->leafs[i];
+
+ if (node->src0) {
+ fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"x\"; ]\n",
+ (void *) node->src0, "x",
+ (void *) node, "x");
+ }
+
+ if (node->src1) {
+ fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"y\"; ]\n",
+ (void *) node->src1, "x",
+ (void *) node, "x");
+ }
+ }
+
+ fprintf(fp, "}\n");
+
+ fclose(fp);
+
+ GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
+ int i = 0;
+ for (int p = 0; p < np; ++p) {
+ const int ne = ggml_nelements(ps[p]) ;
+ // TODO: add function to set tensor from array
+ for (int j = 0; j < ne; ++j) {
+ ggml_set_f32_1d(ps[p], j, x[i++]);
+ }
+ }
+}
+
+static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
+ int i = 0;
+ for (int p = 0; p < np; ++p) {
+ const int ne = ggml_nelements(ps[p]) ;
+ // TODO: add function to get all elements at once
+ for (int j = 0; j < ne; ++j) {
+ x[i++] = ggml_get_f32_1d(ps[p], j);
+ }
+ }
+}
+
+static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
+ int i = 0;
+ for (int p = 0; p < np; ++p) {
+ const int ne = ggml_nelements(ps[p]) ;
+ // TODO: add function to get all elements at once
+ for (int j = 0; j < ne; ++j) {
+ g[i++] = ggml_get_f32_1d(ps[p]->grad, j);
+ }
+ }
+}
+
+//
+// ADAM
+//
+// ref: https://arxiv.org/pdf/1412.6980.pdf
+//
+
+static enum ggml_opt_result ggml_opt_adam(
+ struct ggml_context * ctx,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb) {
+ assert(ggml_is_scalar(f));
+
+ gf->n_threads = params.n_threads;
+ gb->n_threads = params.n_threads;
+
+ // these will store the parameters we want to optimize
+ struct ggml_tensor * ps[GGML_MAX_PARAMS];
+
+ int np = 0;
+ int nx = 0;
+ for (int i = 0; i < gf->n_nodes; ++i) {
+ if (gf->nodes[i]->is_param) {
+ GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
+
+ assert(np < GGML_MAX_PARAMS);
+
+ ps[np++] = gf->nodes[i];
+ nx += ggml_nelements(gf->nodes[i]);
+ }
+ }
+
+ // constants
+ const float alpha = params.adam.alpha;
+ const float beta1 = params.adam.beta1;
+ const float beta2 = params.adam.beta2;
+ const float eps = params.adam.eps;
+
+ float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
+ float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
+ float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
+ float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
+ float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
+ float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
+ float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
+
+ float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+
+ // initialize
+ ggml_vec_set_f32(nx, m, 0.0f);
+ ggml_vec_set_f32(nx, v, 0.0f);
+
+ // update view
+ ggml_opt_get_params(np, ps, x);
+
+ // compute the function value
+ ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(ctx, gb);
+
+ float fx_prev = ggml_get_f32_1d(f, 0);
+ if (pf) {
+ pf[0] = fx_prev;
+ }
+
+ int n_no_improvement = 0;
+ float fx_best = fx_prev;
+
+ // run the optimizer
+ for (int t = 0; t < params.adam.n_iter; ++t) {
+ GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
+
+ GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
+ GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0));
+ GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0));
+
+ for (int i = 0; i < np; ++i) {
+ GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i,
+ ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0));
+ }
+
+ const int64_t t_start_wall = ggml_time_us();
+ const int64_t t_start_cpu = ggml_cycles();
+ UNUSED(t_start_wall);
+ UNUSED(t_start_cpu);
+
+ {
+ // update the gradient
+ ggml_opt_get_grad(np, ps, g1);
+
+ // m_t = beta1*m_t-1 + (1 - beta1)*g_t
+ ggml_vec_scale_f32(nx, m, beta1);
+ ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1);
+
+ // g2 = g1^2
+ ggml_vec_sqr_f32 (nx, g2, g1);
+
+ // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2
+ ggml_vec_scale_f32(nx, v, beta2);
+ ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2);
+
+ // m^hat = m_t / (1 - beta1^t)
+ // v^hat = v_t / (1 - beta2^t)
+ // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+ ggml_vec_cpy_f32 (nx, mh, m);
+ ggml_vec_cpy_f32 (nx, vh, v);
+
+ ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
+ ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1)));
+
+ ggml_vec_sqrt_f32 (nx, vh, vh);
+ ggml_vec_acc1_f32 (nx, vh, eps);
+
+ ggml_vec_div_f32 (nx, mh, mh, vh);
+ ggml_vec_sub_f32 (nx, x, x, mh);
+
+ // update the parameters
+ ggml_opt_set_params(np, ps, x);
+ }
+
+ ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(ctx, gb);
+
+ const float fx = ggml_get_f32_1d(f, 0);
+
+ // check convergence
+ if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+ GGML_PRINT_DEBUG("converged\n");
+
+ return GGML_OPT_OK;
+ }
+
+ // delta-based convergence test
+ if (pf != NULL) {
+ // need at least params.past iterations to start checking for convergence
+ if (params.past <= t) {
+ const float rate = (pf[t%params.past] - fx)/fx;
+
+ if (fabs(rate) < params.delta) {
+ return GGML_OPT_OK;
+ }
+ }
+
+ pf[t%params.past] = fx;
+ }
+
+ // check for improvement
+ if (params.max_no_improvement > 0) {
+ if (fx_best > fx) {
+ fx_best = fx;
+ n_no_improvement = 0;
+ } else {
+ ++n_no_improvement;
+
+ if (n_no_improvement >= params.max_no_improvement) {
+ return GGML_OPT_OK;
+ }
+ }
+ }
+
+ fx_prev = fx;
+
+ {
+ const int64_t t_end_cpu = ggml_cycles();
+ GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC);
+ UNUSED(t_end_cpu);
+
+ const int64_t t_end_wall = ggml_time_us();
+ GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6);
+ UNUSED(t_end_wall);
+ }
+ }
+
+ return GGML_OPT_DID_NOT_CONVERGE;
+}
+
+//
+// L-BFGS
+//
+// the L-BFGS implementation below is based on the following implementation:
+//
+// https://github.com/chokkan/liblbfgs
+//
+
+struct ggml_lbfgs_iteration_data {
+ float alpha;
+ float ys;
+ float * s;
+ float * y;
+};
+
+static enum ggml_opt_result linesearch_backtracking(
+ struct ggml_context * ctx,
+ const struct ggml_opt_params * params,
+ int nx,
+ float * x,
+ float * fx,
+ float * g,
+ float * d,
+ float * step,
+ const float * xp,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ const int np,
+ struct ggml_tensor * ps[]) {
+ int count = 0;
+
+ float width = 0.0f;
+ float dg = 0.0f;
+ float finit = 0.0f;
+ float dginit = 0.0f;
+ float dgtest = 0.0f;
+
+ const float dec = 0.5f;
+ const float inc = 2.1f;
+
+ if (*step <= 0.) {
+ return GGML_LINESEARCH_INVALID_PARAMETERS;
+ }
+
+ // compute the initial gradient in the search direction
+ ggml_vec_dot_f32(nx, &dginit, g, d);
+
+ // make sure that d points to a descent direction
+ if (0 < dginit) {
+ return GGML_LINESEARCH_FAIL;
+ }
+
+ // initialize local variables
+ finit = *fx;
+ dgtest = params->lbfgs.ftol*dginit;
+
+ while (true) {
+ ggml_vec_cpy_f32(nx, x, xp);
+ ggml_vec_mad_f32(nx, x, d, *step);
+
+ // evaluate the function and gradient values
+ {
+ ggml_opt_set_params(np, ps, x);
+
+ ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(ctx, gb);
+
+ ggml_opt_get_grad(np, ps, g);
+
+ *fx = ggml_get_f32_1d(f, 0);
+ }
+
+ ++count;
+
+ if (*fx > finit + (*step)*dgtest) {
+ width = dec;
+ } else {
+ // Armijo condition is satisfied
+ if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) {
+ return count;
+ }
+
+ ggml_vec_dot_f32(nx, &dg, g, d);
+
+ // check the Wolfe condition
+ if (dg < params->lbfgs.wolfe * dginit) {
+ width = inc;
+ } else {
+ if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) {
+ // regular Wolfe conditions
+ return count;
+ }
+
+ if(dg > -params->lbfgs.wolfe*dginit) {
+ width = dec;
+ } else {
+ // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
+ return count;
+ }
+ return count;
+ }
+ }
+
+ if (*step < params->lbfgs.min_step) {
+ return GGML_LINESEARCH_MINIMUM_STEP;
+ }
+ if (*step > params->lbfgs.max_step) {
+ return GGML_LINESEARCH_MAXIMUM_STEP;
+ }
+ if (params->lbfgs.max_linesearch <= count) {
+ return GGML_LINESEARCH_MAXIMUM_ITERATIONS;
+ }
+
+ (*step) *= width;
+ }
+
+ return GGML_LINESEARCH_FAIL;
+}
+
+static enum ggml_opt_result ggml_opt_lbfgs(
+ struct ggml_context * ctx,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb) {
+ if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
+ params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
+ if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) {
+ return GGML_OPT_INVALID_WOLFE;
+ }
+ }
+
+ gf->n_threads = params.n_threads;
+ gb->n_threads = params.n_threads;
+
+ const int m = params.lbfgs.m;
+
+ // these will store the parameters we want to optimize
+ struct ggml_tensor * ps[GGML_MAX_PARAMS];
+
+ int np = 0;
+ int nx = 0;
+ for (int i = 0; i < gf->n_nodes; ++i) {
+ if (gf->nodes[i]->is_param) {
+ GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
+
+ assert(np < GGML_MAX_PARAMS);
+
+ ps[np++] = gf->nodes[i];
+ nx += ggml_nelements(gf->nodes[i]);
+ }
+ }
+
+ float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
+ float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
+ float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
+ float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
+ float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
+
+ float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+
+ float fx = 0.0f; // cost function value
+ float xnorm = 0.0f; // ||x||
+ float gnorm = 0.0f; // ||g||
+ float step = 0.0f;
+
+ // initialize x from the graph nodes
+ ggml_opt_get_params(np, ps, x);
+
+ // the L-BFGS memory
+ struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
+
+ for (int i = 0; i < m; ++i) {
+ lm[i].alpha = 0.0f;
+ lm[i].ys = 0.0f;
+ lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
+ lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
+ }
+
+ // evaluate the function value and its gradient
+ {
+ ggml_opt_set_params(np, ps, x);
+
+ ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(ctx, gb);
+
+ ggml_opt_get_grad(np, ps, g);
+
+ fx = ggml_get_f32_1d(f, 0);
+ }
+
+ if (pf) {
+ pf[0] = fx;
+ }
+
+ float fx_best = fx;
+
+ // search direction = -gradient
+ ggml_vec_neg_f32(nx, d, g);
+
+ // ||x||, ||g||
+ ggml_vec_norm_f32(nx, &xnorm, x);
+ ggml_vec_norm_f32(nx, &gnorm, g);
+
+ if (xnorm < 1.0f) {
+ xnorm = 1.0f;
+ }
+
+ // already optimized
+ if (gnorm/xnorm <= params.lbfgs.eps) {
+ return GGML_OPT_OK;
+ }
+
+ // initial step
+ ggml_vec_norm_inv_f32(nx, &step, d);
+
+ int j = 0;
+ int k = 1;
+ int ls = 0;
+ int end = 0;
+ int bound = 0;
+ int n_no_improvement = 0;
+
+ float ys = 0.0f;
+ float yy = 0.0f;
+ float beta = 0.0f;
+
+ while (true) {
+ // store the current position and gradient vectors
+ ggml_vec_cpy_f32(nx, xp, x);
+ ggml_vec_cpy_f32(nx, gp, g);
+
+ ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
+
+ if (ls < 0) {
+ // linesearch failed - go back to the previous point and return
+ ggml_vec_cpy_f32(nx, x, xp);
+ ggml_vec_cpy_f32(nx, g, gp);
+
+ return ls;
+ }
+
+ ggml_vec_norm_f32(nx, &xnorm, x);
+ ggml_vec_norm_f32(nx, &gnorm, g);
+
+ GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0));
+
+ if (xnorm < 1.0) {
+ xnorm = 1.0;
+ }
+ if (gnorm/xnorm <= params.lbfgs.eps) {
+ // converged
+ return GGML_OPT_OK;
+ }
+
+ // delta-based convergence test
+ if (pf != NULL) {
+ // need at least params.past iterations to start checking for convergence
+ if (params.past <= k) {
+ const float rate = (pf[k%params.past] - fx)/fx;
+
+ if (fabs(rate) < params.delta) {
+ return GGML_OPT_OK;
+ }
+ }
+
+ pf[k%params.past] = fx;
+ }
+
+ // check for improvement
+ if (params.max_no_improvement > 0) {
+ if (fx < fx_best) {
+ fx_best = fx;
+ n_no_improvement = 0;
+ } else {
+ n_no_improvement++;
+
+ if (n_no_improvement >= params.max_no_improvement) {
+ return GGML_OPT_OK;
+ }
+ }
+ }
+
+ if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
+ // reached the maximum number of iterations
+ return GGML_OPT_DID_NOT_CONVERGE;
+ }
+
+ // update vectors s and y:
+ // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
+ // y_{k+1} = g_{k+1} - g_{k}.
+ //
+ ggml_vec_sub_f32(nx, lm[end].s, x, xp);
+ ggml_vec_sub_f32(nx, lm[end].y, g, gp);
+
+ // compute scalars ys and yy:
+ // ys = y^t \cdot s -> 1 / \rho.
+ // yy = y^t \cdot y.
+ //
+ ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
+ ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
+
+ lm[end].ys = ys;
+
+ // find new search direction
+ // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
+
+ bound = (m <= k) ? m : k;
+ k++;
+ end = (end + 1)%m;
+
+ // initialize search direction with -g
+ ggml_vec_neg_f32(nx, d, g);
+
+ j = end;
+ for (int i = 0; i < bound; ++i) {
+ j = (j + m - 1) % m;
+ // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
+ ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
+ lm[j].alpha /= lm[j].ys;
+ // q_{i} = q_{i+1} - \alpha_{i} y_{i}
+ ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
+ }
+
+ ggml_vec_scale_f32(nx, d, ys/yy);
+
+ for (int i = 0; i < bound; ++i) {
+ // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
+ ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
+ beta /= lm[j].ys;
+ // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
+ ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
+ j = (j + 1)%m;
+ }
+
+ step = 1.0;
+ }
+
+ return GGML_OPT_DID_NOT_CONVERGE;
+}
+
+struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
+ struct ggml_opt_params result;
+
+ switch (type) {
+ case GGML_OPT_ADAM:
+ {
+ result = (struct ggml_opt_params) {
+ .type = GGML_OPT_ADAM,
+ .n_threads = 1,
+ .past = 0,
+ .delta = 1e-5f,
+
+ .max_no_improvement = 100,
+
+ .print_forward_graph = true,
+ .print_backward_graph = true,
+
+ .adam = {
+ .n_iter = 10000,
+ .alpha = 0.001f,
+ .beta1 = 0.9f,
+ .beta2 = 0.999f,
+ .eps = 1e-8f,
+ .eps_f = 1e-5f,
+ .eps_g = 1e-3f,
+ },
+ };
+ } break;
+ case GGML_OPT_LBFGS:
+ {
+ result = (struct ggml_opt_params) {
+ .type = GGML_OPT_LBFGS,
+ .n_threads = 1,
+ .past = 0,
+ .delta = 1e-5f,
+
+ .max_no_improvement = 0,
+
+ .print_forward_graph = true,
+ .print_backward_graph = true,
+
+ .lbfgs = {
+ .m = 6,
+ .n_iter = 100,
+ .max_linesearch = 20,
+
+ .eps = 1e-5f,
+ .ftol = 1e-4f,
+ .wolfe = 0.9f,
+ .min_step = 1e-20f,
+ .max_step = 1e+20f,
+
+ .linesearch = GGML_LINESEARCH_DEFAULT,
+ },
+ };
+ } break;
+ }
+
+ return result;
+}
+
+enum ggml_opt_result ggml_opt(
+ struct ggml_context * ctx,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f) {
+ bool free_ctx = false;
+ if (ctx == NULL) {
+ struct ggml_init_params params_ctx = {
+ .mem_size = 16*1024*1024,
+ .mem_buffer = NULL,
+ };
+
+ ctx = ggml_init(params_ctx);
+ if (ctx == NULL) {
+ return GGML_OPT_NO_CONTEXT;
+ }
+
+ free_ctx = true;
+ }
+
+ enum ggml_opt_result result = GGML_OPT_OK;
+
+ // build forward + backward compute graphs
+ struct ggml_cgraph gf = ggml_build_forward (f);
+ struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false);
+
+ switch (params.type) {
+ case GGML_OPT_ADAM:
+ {
+ result = ggml_opt_adam(ctx, params, f, &gf, &gb);
+ } break;
+ case GGML_OPT_LBFGS:
+ {
+ result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
+ } break;
+ }
+
+ if (params.print_forward_graph) {
+ ggml_graph_print (&gf);
+ ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
+ }
+
+ if (params.print_backward_graph) {
+ ggml_graph_print (&gb);
+ ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+ }
+
+ if (free_ctx) {
+ ggml_free(ctx);
+ }
+
+ return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+int ggml_cpu_has_avx(void) {
+#if defined(__AVX__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_avx2(void) {
+#if defined(__AVX2__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512(void) {
+#if defined(__AVX512F__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_fma(void) {
+#if defined(__FMA__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_neon(void) {
+#if defined(__ARM_NEON)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_arm_fma(void) {
+#if defined(__ARM_FEATURE_FMA)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_f16c(void) {
+#if defined(__F16C__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_fp16_va(void) {
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_wasm_simd(void) {
+#if defined(__wasm_simd128__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_blas(void) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/Whisper/source/ggml.h b/Whisper/source/ggml.h
new file mode 100644
index 0000000..a217d2d
--- /dev/null
+++ b/Whisper/source/ggml.h
@@ -0,0 +1,737 @@
+#pragma once
+
+//
+// GGML Tensor Library
+//
+// This documentation is still a work in progress.
+// If you wish some specific topics to be covered, feel free to drop a comment:
+//
+// https://github.com/ggerganov/whisper.cpp/issues/40
+//
+// ## Overview
+//
+// This library implements:
+//
+// - a set of tensor operations
+// - automatic differentiation
+// - basic optimization algorithms
+//
+// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
+// but is not limited to, the following:
+//
+// - linear regression
+// - support vector machines
+// - neural networks
+//
+// The library allows the user to define a certain function using the available tensor operations. This function
+// definition is represented internally via a computation graph. Each tensor operation in the function definition
+// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
+// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
+// using one of the available optimization algorithms.
+//
+// For example, here we define the function: f(x) = a*x^2 + b
+//
+// {
+// struct ggml_init_params params = {
+// .mem_size = 16*1024*1024,
+// .mem_buffer = NULL,
+// };
+//
+// // memory allocation happens here
+// struct ggml_context * ctx = ggml_init(params);
+//
+// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//
+// ggml_set_param(ctx, x); // x is an input variable
+//
+// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+// struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
+// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
+//
+// ...
+// }
+//
+// Notice that the function definition above does not involve any actual computation. The computation is performed only
+// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
+//
+// {
+// ...
+//
+// struct ggml_cgraph gf = ggml_build_forward(f);
+//
+// // set the input variable and parameter values
+// ggml_set_f32(x, 2.0f);
+// ggml_set_f32(a, 3.0f);
+// ggml_set_f32(b, 4.0f);
+//
+// ggml_graph_compute(ctx0, &gf);
+//
+// printf("f = %f\n", ggml_get_f32_1d(f, 0));
+//
+// ...
+// }
+//
+// The actual computation is performed in the ggml_graph_compute() function.
+//
+// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
+// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
+// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
+// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
+// actually needed.
+//
+// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
+// differentiation and optimization algorithms.
+//
+// The described approach allows to define the function graph once and then compute its forward or backward graphs
+// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
+// the user can avoid the memory allocation overhead at runtime.
+//
+// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
+// citizens, but in theory the library can be extended to support FP8 and integer data types.
+//
+// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
+// and binary operations. Most of the available operations fall into one of these two categories. With time, it became
+// clear that the library needs to support more complex operations. The way to support these operations is not clear
+// yet, but a few examples are demonstrated in the following operations:
+//
+// - ggml_permute()
+// - ggml_conv_1d_1s()
+// - ggml_conv_1d_2s()
+//
+// For each tensor operator, the library implements a forward and backward computation function. The forward function
+// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
+// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
+// calculus class, or watch the following video:
+//
+// What is Automatic Differentiation?
+// https://www.youtube.com/watch?v=wG_nF1awSSY
+//
+//
+// ## Tensor data (struct ggml_tensor)
+//
+// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
+// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
+// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
+//
+// {
+// struct ggml_tensor * c = ggml_add(ctx, a, b);
+//
+// assert(c->src[0] == a);
+// assert(c->src[1] == b);
+// }
+//
+// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
+// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows
+// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
+// permutation. All tensor operations have to take the stride into account and not assume that the tensor is
+// contiguous in memory.
+//
+// The data of the tensor is accessed via the "data" pointer. For example:
+//
+// {
+// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
+//
+// // a[1, 2] = 1.0f;
+// *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
+//
+// // a[2, 0] = 2.0f;
+// *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f;
+//
+// ...
+// }
+//
+// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.
+//
+// ## The matrix multiplication operator (ggml_mul_mat)
+//
+// TODO
+//
+//
+// ## Multi-threading
+//
+// TODO
+//
+//
+// ## Overview of ggml.c
+//
+// TODO
+//
+//
+// ## SIMD optimizations
+//
+// TODO
+//
+//
+// ## Debugging ggml
+//
+// TODO
+//
+//
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <stdint.h>
+#include <stddef.h>
+#include <stdbool.h>
+
+#define GGML_MAX_DIMS 4
+#define GGML_MAX_NODES 4096
+#define GGML_MAX_PARAMS 16
+#define GGML_MAX_CONTEXTS 64
+#define GGML_MAX_OPT 4
+
+#ifdef __ARM_NEON
+// we use the built-in 16-bit float type
+typedef __fp16 ggml_fp16_t;
+#else
+typedef uint16_t ggml_fp16_t;
+#endif
+
+// convert FP16 <-> FP32
+float ggml_fp16_to_fp32(ggml_fp16_t x);
+ggml_fp16_t ggml_fp32_to_fp16(float x);
+
+struct ggml_object;
+struct ggml_context;
+
+enum ggml_type {
+ GGML_TYPE_I8,
+ GGML_TYPE_I16,
+ GGML_TYPE_I32,
+ GGML_TYPE_F16,
+ GGML_TYPE_F32,
+ GGML_TYPE_COUNT,
+};
+
+// available tensor operations:
+enum ggml_op {
+ GGML_OP_NONE = 0,
+
+ GGML_OP_DUP,
+ GGML_OP_ADD,
+ GGML_OP_SUB,
+ GGML_OP_MUL,
+ GGML_OP_DIV,
+ GGML_OP_SQR,
+ GGML_OP_SQRT,
+ GGML_OP_SUM,
+ GGML_OP_MEAN,
+ GGML_OP_REPEAT,
+ GGML_OP_ABS,
+ GGML_OP_SGN,
+ GGML_OP_NEG,
+ GGML_OP_STEP,
+ GGML_OP_RELU,
+ GGML_OP_GELU,
+ GGML_OP_NORM, // normalize
+
+ GGML_OP_MUL_MAT,
+
+ GGML_OP_SCALE,
+ GGML_OP_CPY,
+ GGML_OP_RESHAPE,
+ GGML_OP_VIEW,
+ GGML_OP_PERMUTE,
+ GGML_OP_TRANSPOSE,
+ GGML_OP_GET_ROWS,
+ GGML_OP_DIAG_MASK_INF,
+ GGML_OP_SOFT_MAX,
+ GGML_OP_ROPE,
+ GGML_OP_CONV_1D_1S,
+ GGML_OP_CONV_1D_2S,
+
+ GGML_OP_FLASH_ATTN,
+ GGML_OP_FLASH_FF,
+
+ GGML_OP_COUNT,
+};
+
+// n-dimensional tensor
+struct ggml_tensor {
+ enum ggml_type type;
+
+ int n_dims;
+ int ne[GGML_MAX_DIMS]; // number of elements
+ size_t nb[GGML_MAX_DIMS]; // stride in bytes:
+ // nb[0] = sizeof(type)
+ // nb[1] = nb[0] * ne[0] + padding
+ // nb[i] = nb[i-1] * ne[i-1]
+
+ // compute data
+ enum ggml_op op;
+
+ bool is_param;
+
+ struct ggml_tensor * grad;
+ struct ggml_tensor * src0;
+ struct ggml_tensor * src1;
+ struct ggml_tensor * opt[GGML_MAX_OPT];
+
+ // thread scheduling
+ int n_tasks;
+
+ // performance
+ int perf_runs;
+ int64_t perf_cycles;
+ int64_t perf_time_us;
+
+ void * data;
+ char padding[8];
+};
+
+// computation graph
+struct ggml_cgraph {
+ int n_nodes;
+ int n_leafs;
+ int n_threads;
+
+ size_t work_size;
+ struct ggml_tensor * work;
+
+ struct ggml_tensor * nodes[GGML_MAX_NODES];
+ struct ggml_tensor * grads[GGML_MAX_NODES];
+ struct ggml_tensor * leafs[GGML_MAX_NODES];
+
+ // performance
+ int perf_runs;
+ int64_t perf_cycles;
+ int64_t perf_time_us;
+};
+
+struct ggml_init_params {
+ // memory pool
+ size_t mem_size; // bytes
+ void * mem_buffer; // if NULL, memory will be allocated internally
+};
+
+void ggml_time_init(void); // call this once at the beginning of the program
+int64_t ggml_time_ms(void);
+int64_t ggml_time_us(void);
+int64_t ggml_cycles(void);
+int64_t ggml_cycles_per_ms(void);
+
+void ggml_print_object (const struct ggml_object * obj);
+void ggml_print_objects(const struct ggml_context * ctx);
+
+int ggml_nelements(const struct ggml_tensor * tensor);
+size_t ggml_nbytes (const struct ggml_tensor * tensor);
+
+size_t ggml_type_size (enum ggml_type type);
+size_t ggml_element_size(const struct ggml_tensor * tensor);
+
+struct ggml_context * ggml_init(struct ggml_init_params params);
+void ggml_free(struct ggml_context * ctx);
+
+size_t ggml_used_mem(const struct ggml_context * ctx);
+
+struct ggml_tensor * ggml_new_tensor(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int n_dims,
+ const int *ne);
+
+struct ggml_tensor * ggml_new_tensor_1d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0);
+
+struct ggml_tensor * ggml_new_tensor_2d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1);
+
+struct ggml_tensor * ggml_new_tensor_3d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1,
+ int ne2);
+
+struct ggml_tensor * ggml_new_tensor_4d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int ne0,
+ int ne1,
+ int ne2,
+ int ne3);
+
+struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+
+struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
+struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
+
+struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
+struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+
+int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+
+ void * ggml_get_data (const struct ggml_tensor * tensor);
+float * ggml_get_data_f32(const struct ggml_tensor * tensor);
+
+//
+// operations on tensors with backpropagation
+//
+
+struct ggml_tensor * ggml_dup(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_add(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_sub(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_mul(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_div(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_sqr(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_sqrt(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// return scalar
+// TODO: compute sum along rows
+struct ggml_tensor * ggml_sum(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// mean along rows
+struct ggml_tensor * ggml_mean(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// if a is the same shape as b, and a is not parameter, return a
+// otherwise, return a new tensor: repeat(a) to fit in b
+struct ggml_tensor * ggml_repeat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_abs(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_sgn(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_neg(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_step(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_relu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// TODO: double-check this computation is correct
+struct ggml_tensor * ggml_gelu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// normalize along rows
+// TODO: eps is hardcoded to 1e-5 for now
+struct ggml_tensor * ggml_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// A: m rows, n columns
+// B: p rows, n columns (i.e. we transpose it internally)
+// result is m columns, p rows
+struct ggml_tensor * ggml_mul_mat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+//
+// operations on tensors without backpropagation
+//
+
+// in-place, returns view(a)
+struct ggml_tensor * ggml_scale(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+// a -> b, return view(b)
+struct ggml_tensor * ggml_cpy(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+// return view(a), b specifies the new shape
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+// return view(a)
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1);
+
+// return view(a)
+// TODO: when we start computing gradient, make a copy instead of view
+struct ggml_tensor * ggml_reshape_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ int ne2);
+
+// offset in bytes
+struct ggml_tensor * ggml_view_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ size_t offset);
+
+struct ggml_tensor * ggml_view_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ size_t nb1, // row stride in bytes
+ size_t offset);
+
+struct ggml_tensor * ggml_permute(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int axis0,
+ int axis1,
+ int axis2,
+ int axis3);
+
+// alias for ggml_permute(ctx, a, 1, 0, 2, 3)
+struct ggml_tensor * ggml_transpose(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_get_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+// set elements above the diagonal to -INF
+// in-place, returns view(a)
+struct ggml_tensor * ggml_diag_mask_inf(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past);
+
+// in-place, returns view(a)
+struct ggml_tensor * ggml_soft_max(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+// rotary position embedding
+// in-place, returns view(a)
+// if mode == 1, skip n_past elements
+// TODO: avoid creating a new tensor every time
+struct ggml_tensor * ggml_rope(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ int n_dims,
+ int mode);
+
+// padding = 1
+// TODO: we don't support extra parameters for now
+// that's why we are hard-coding the stride, padding, and dilation
+// not great ..
+struct ggml_tensor * ggml_conv_1d_1s(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_conv_1d_2s(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+struct ggml_tensor * ggml_flash_attn(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ bool masked);
+
+struct ggml_tensor * ggml_flash_ff(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b0,
+ struct ggml_tensor * b1,
+ struct ggml_tensor * c0,
+ struct ggml_tensor * c1);
+
+//
+// automatic differentiation
+//
+
+void ggml_set_param(
+ struct ggml_context * ctx,
+ struct ggml_tensor * tensor);
+
+void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+
+struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
+struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
+
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+void ggml_graph_reset (struct ggml_cgraph * cgraph);
+
+// print info and performance information for the graph
+void ggml_graph_print(const struct ggml_cgraph * cgraph);
+
+// dump the graph into a file using the dot format
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+
+//
+// optimization
+//
+
+// optimization methods
+enum ggml_opt_type {
+ GGML_OPT_ADAM,
+ GGML_OPT_LBFGS,
+};
+
+// linesearch methods
+enum ggml_linesearch {
+ GGML_LINESEARCH_DEFAULT = 1,
+
+ GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0,
+ GGML_LINESEARCH_BACKTRACKING_WOLFE = 1,
+ GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
+};
+
+// optimization return values
+enum ggml_opt_result {
+ GGML_OPT_OK = 0,
+ GGML_OPT_DID_NOT_CONVERGE,
+ GGML_OPT_NO_CONTEXT,
+ GGML_OPT_INVALID_WOLFE,
+ GGML_OPT_FAIL,
+
+ GGML_LINESEARCH_FAIL = -128,
+ GGML_LINESEARCH_MINIMUM_STEP,
+ GGML_LINESEARCH_MAXIMUM_STEP,
+ GGML_LINESEARCH_MAXIMUM_ITERATIONS,
+ GGML_LINESEARCH_INVALID_PARAMETERS,
+};
+
+// optimization parameters
+//
+// see ggml.c (ggml_opt_default_params) for default values
+//
+struct ggml_opt_params {
+ enum ggml_opt_type type;
+
+ int n_threads;
+
+ // delta-based convergence test
+ //
+ // if past == 0 - disabled
+ // if past > 0:
+ // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
+ //
+ int past;
+ float delta;
+
+ // maximum number of iterations without improvement
+ //
+ // if 0 - disabled
+ // if > 0:
+ // assume convergence if no cost improvement in this number of iterations
+ //
+ int max_no_improvement;
+
+ bool print_forward_graph;
+ bool print_backward_graph;
+
+ // ADAM parameters
+ struct {
+ int n_iter;
+
+ float alpha; // learning rate
+ float beta1;
+ float beta2;
+ float eps; // epsilon for numerical stability
+ float eps_f; // epsilon for convergence test
+ float eps_g; // epsilon for convergence test
+ } adam;
+
+ // LBFGS parameters
+ struct {
+ int m; // number of corrections to approximate the inv. Hessian
+ int n_iter;
+ int max_linesearch;
+
+ float eps; // convergence tolerance
+ float ftol; // line search tolerance
+ float wolfe;
+ float min_step;
+ float max_step;
+
+ enum ggml_linesearch linesearch;
+ } lbfgs;
+};
+
+struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
+
+// optimize the function defined by the tensor f
+enum ggml_opt_result ggml_opt(
+ struct ggml_context * ctx,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f);
+
+//
+// system info
+//
+
+int ggml_cpu_has_avx(void);
+int ggml_cpu_has_avx2(void);
+int ggml_cpu_has_avx512(void);
+int ggml_cpu_has_fma(void);
+int ggml_cpu_has_neon(void);
+int ggml_cpu_has_arm_fma(void);
+int ggml_cpu_has_f16c(void);
+int ggml_cpu_has_fp16_va(void);
+int ggml_cpu_has_wasm_simd(void);
+int ggml_cpu_has_blas(void);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/Whisper/source/whisper.cpp b/Whisper/source/whisper.cpp
new file mode 100644
index 0000000..268774d
--- /dev/null
+++ b/Whisper/source/whisper.cpp
@@ -0,0 +1,3601 @@
+#define WHISPER_BUILD
+#include "whisper.h"
+
+#include "ggml.h"
+
+#include <algorithm>
+#include <cassert>
+#define _USE_MATH_DEFINES
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <thread>
+#include <vector>
+#include <regex>
+
+#define USE_FLASH_ATTN
+//#define USE_FLASH_FF
+
+// available whisper models
+enum e_model {
+ MODEL_UNKNOWN,
+ MODEL_TINY,
+ MODEL_BASE,
+ MODEL_SMALL,
+ MODEL_MEDIUM,
+ MODEL_LARGE,
+};
+
+static const std::map<std::string, std::pair<int, std::string>> g_lang = {
+ { "en", { 0, "english", } },
+ { "zh", { 1, "chinese", } },
+ { "de", { 2, "german", } },
+ { "es", { 3, "spanish", } },
+ { "ru", { 4, "russian", } },
+ { "ko", { 5, "korean", } },
+ { "fr", { 6, "french", } },
+ { "ja", { 7, "japanese", } },
+ { "pt", { 8, "portuguese", } },
+ { "tr", { 9, "turkish", } },
+ { "pl", { 10, "polish", } },
+ { "ca", { 11, "catalan", } },
+ { "nl", { 12, "dutch", } },
+ { "ar", { 13, "arabic", } },
+ { "sv", { 14, "swedish", } },
+ { "it", { 15, "italian", } },
+ { "id", { 16, "indonesian", } },
+ { "hi", { 17, "hindi", } },
+ { "fi", { 18, "finnish", } },
+ { "vi", { 19, "vietnamese", } },
+ { "iw", { 20, "hebrew", } },
+ { "uk", { 21, "ukrainian", } },
+ { "el", { 22, "greek", } },
+ { "ms", { 23, "malay", } },
+ { "cs", { 24, "czech", } },
+ { "ro", { 25, "romanian", } },
+ { "da", { 26, "danish", } },
+ { "hu", { 27, "hungarian", } },
+ { "ta", { 28, "tamil", } },
+ { "no", { 29, "norwegian", } },
+ { "th", { 30, "thai", } },
+ { "ur", { 31, "urdu", } },
+ { "hr", { 32, "croatian", } },
+ { "bg", { 33, "bulgarian", } },
+ { "lt", { 34, "lithuanian", } },
+ { "la", { 35, "latin", } },
+ { "mi", { 36, "maori", } },
+ { "ml", { 37, "malayalam", } },
+ { "cy", { 38, "welsh", } },
+ { "sk", { 39, "slovak", } },
+ { "te", { 40, "telugu", } },
+ { "fa", { 41, "persian", } },
+ { "lv", { 42, "latvian", } },
+ { "bn", { 43, "bengali", } },
+ { "sr", { 44, "serbian", } },
+ { "az", { 45, "azerbaijani", } },
+ { "sl", { 46, "slovenian", } },
+ { "kn", { 47, "kannada", } },
+ { "et", { 48, "estonian", } },
+ { "mk", { 49, "macedonian", } },
+ { "br", { 50, "breton", } },
+ { "eu", { 51, "basque", } },
+ { "is", { 52, "icelandic", } },
+ { "hy", { 53, "armenian", } },
+ { "ne", { 54, "nepali", } },
+ { "mn", { 55, "mongolian", } },
+ { "bs", { 56, "bosnian", } },
+ { "kk", { 57, "kazakh", } },
+ { "sq", { 58, "albanian", } },
+ { "sw", { 59, "swahili", } },
+ { "gl", { 60, "galician", } },
+ { "mr", { 61, "marathi", } },
+ { "pa", { 62, "punjabi", } },
+ { "si", { 63, "sinhala", } },
+ { "km", { 64, "khmer", } },
+ { "sn", { 65, "shona", } },
+ { "yo", { 66, "yoruba", } },
+ { "so", { 67, "somali", } },
+ { "af", { 68, "afrikaans", } },
+ { "oc", { 69, "occitan", } },
+ { "ka", { 70, "georgian", } },
+ { "be", { 71, "belarusian", } },
+ { "tg", { 72, "tajik", } },
+ { "sd", { 73, "sindhi", } },
+ { "gu", { 74, "gujarati", } },
+ { "am", { 75, "amharic", } },
+ { "yi", { 76, "yiddish", } },
+ { "lo", { 77, "lao", } },
+ { "uz", { 78, "uzbek", } },
+ { "fo", { 79, "faroese", } },
+ { "ht", { 80, "haitian creole", } },
+ { "ps", { 81, "pashto", } },
+ { "tk", { 82, "turkmen", } },
+ { "nn", { 83, "nynorsk", } },
+ { "mt", { 84, "maltese", } },
+ { "sa", { 85, "sanskrit", } },
+ { "lb", { 86, "luxembourgish", } },
+ { "my", { 87, "myanmar", } },
+ { "bo", { 88, "tibetan", } },
+ { "tl", { 89, "tagalog", } },
+ { "mg", { 90, "malagasy", } },
+ { "as", { 91, "assamese", } },
+ { "tt", { 92, "tatar", } },
+ { "haw", { 93, "hawaiian", } },
+ { "ln", { 94, "lingala", } },
+ { "ha", { 95, "hausa", } },
+ { "ba", { 96, "bashkir", } },
+ { "jw", { 97, "javanese", } },
+ { "su", { 98, "sundanese", } },
+};
+
+static const size_t MB = 1024*1024;
+
+static const std::map<e_model, size_t> MEM_REQ_MODEL = {
+ { MODEL_TINY, 74ull*MB },
+ { MODEL_BASE, 142ull*MB },
+ { MODEL_SMALL, 466ull*MB },
+ { MODEL_MEDIUM, 1464ull*MB },
+ { MODEL_LARGE, 2952ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
+ { MODEL_TINY, 12ull*MB },
+ { MODEL_BASE, 24ull*MB },
+ { MODEL_SMALL, 70ull*MB },
+ { MODEL_MEDIUM, 184ull*MB },
+ { MODEL_LARGE, 306ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
+ { MODEL_TINY, 80ull*MB },
+ { MODEL_BASE, 128ull*MB },
+ { MODEL_SMALL, 300ull*MB },
+ { MODEL_MEDIUM, 680ull*MB },
+ { MODEL_LARGE, 1100ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
+ { MODEL_TINY, 104ull*MB },
+ { MODEL_BASE, 138ull*MB },
+ { MODEL_SMALL, 208ull*MB },
+ { MODEL_MEDIUM, 280ull*MB },
+ { MODEL_LARGE, 354ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_DECODE = {
+ { MODEL_TINY, 200ull*MB },
+ { MODEL_BASE, 202ull*MB },
+ { MODEL_SMALL, 204ull*MB },
+ { MODEL_MEDIUM, 206ull*MB },
+ { MODEL_LARGE, 208ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
+ { MODEL_TINY, 32ull*MB },
+ { MODEL_BASE, 44ull*MB },
+ { MODEL_SMALL, 64ull*MB },
+ { MODEL_MEDIUM, 84ull*MB },
+ { MODEL_LARGE, 110ull*MB },
+};
+
+struct whisper_mel {
+ int n_len;
+ int n_mel;
+
+ std::vector<float> data;
+};
+
+struct whisper_filters {
+ int32_t n_mel;
+ int32_t n_fft;
+
+ std::vector<float> data;
+};
+
+struct whisper_vocab {
+ using id = int32_t;
+ using token = std::string;
+
+ int n_vocab = 51864;
+
+ std::map<token, id> token_to_id;
+ std::map<id, token> id_to_token;
+
+ id token_eot = 50256;
+ id token_sot = 50257;
+ id token_prev = 50360;
+ id token_solm = 50361; // ??
+ id token_not = 50362; // no timestamps
+ id token_beg = 50363;
+
+ // available tasks
+ static const id token_translate = 50358;
+ static const id token_transcribe = 50359;
+
+ bool is_multilingual() const {
+ return n_vocab == 51865;
+ }
+};
+
+struct whisper_segment {
+ int64_t t0;
+ int64_t t1;
+
+ std::string text;
+
+ std::vector<whisper_token_data> tokens;
+};
+
+// medium
+// hparams: {
+// 'n_mels': 80,
+// 'n_vocab': 51864,
+// 'n_audio_ctx': 1500,
+// 'n_audio_state': 1024,
+// 'n_audio_head': 16,
+// 'n_audio_layer': 24,
+// 'n_text_ctx': 448,
+// 'n_text_state': 1024,
+// 'n_text_head': 16,
+// 'n_text_layer': 24
+// }
+//
+// default hparams (Whisper tiny)
+struct whisper_hparams {
+ int32_t n_vocab = 51864;
+ int32_t n_audio_ctx = 1500;
+ int32_t n_audio_state = 384;
+ int32_t n_audio_head = 6;
+ int32_t n_audio_layer = 4;
+ int32_t n_text_ctx = 448;
+ int32_t n_text_state = 384;
+ int32_t n_text_head = 6;
+ int32_t n_text_layer = 4;
+ int32_t n_mels = 80;
+ int32_t f16 = 1;
+};
+
+// audio encoding layer
+struct whisper_layer_encoder {
+ // encoder.blocks.*.attn_ln
+ struct ggml_tensor * attn_ln_0_w;
+ struct ggml_tensor * attn_ln_0_b;
+
+ // encoder.blocks.*.attn.out
+ struct ggml_tensor * attn_ln_1_w;
+ struct ggml_tensor * attn_ln_1_b;
+
+ // encoder.blocks.*.attn.query
+ struct ggml_tensor * attn_q_w;
+ struct ggml_tensor * attn_q_b;
+
+ // encoder.blocks.*.attn.key
+ struct ggml_tensor * attn_k_w;
+
+ // encoder.blocks.*.attn.value
+ struct ggml_tensor * attn_v_w;
+ struct ggml_tensor * attn_v_b;
+
+ // encoder.blocks.*.mlp_ln
+ struct ggml_tensor * mlp_ln_w;
+ struct ggml_tensor * mlp_ln_b;
+
+ // encoder.blocks.*.mlp.0
+ struct ggml_tensor * mlp_0_w;
+ struct ggml_tensor * mlp_0_b;
+
+ // encoder.blocks.*.mlp.2
+ struct ggml_tensor * mlp_1_w;
+ struct ggml_tensor * mlp_1_b;
+};
+
+// token decoding layer
+struct whisper_layer_decoder {
+ // decoder.blocks.*.attn_ln
+ struct ggml_tensor * attn_ln_0_w;
+ struct ggml_tensor * attn_ln_0_b;
+
+ // decoder.blocks.*.attn.out
+ struct ggml_tensor * attn_ln_1_w;
+ struct ggml_tensor * attn_ln_1_b;
+
+ // decoder.blocks.*.attn.query
+ struct ggml_tensor * attn_q_w;
+ struct ggml_tensor * attn_q_b;
+
+ // decoder.blocks.*.attn.key
+ struct ggml_tensor * attn_k_w;
+
+ // decoder.blocks.*.attn.value
+ struct ggml_tensor * attn_v_w;
+ struct ggml_tensor * attn_v_b;
+
+ // decoder.blocks.*.cross_attn_ln
+ struct ggml_tensor * cross_attn_ln_0_w;
+ struct ggml_tensor * cross_attn_ln_0_b;
+
+ // decoder.blocks.*.cross_attn.out
+ struct ggml_tensor * cross_attn_ln_1_w;
+ struct ggml_tensor * cross_attn_ln_1_b;
+
+ // decoder.blocks.*.cross_attn.query
+ struct ggml_tensor * cross_attn_q_w;
+ struct ggml_tensor * cross_attn_q_b;
+
+ // decoder.blocks.*.cross_attn.key
+ struct ggml_tensor * cross_attn_k_w;
+
+ // decoder.blocks.*.cross_attn.value
+ struct ggml_tensor * cross_attn_v_w;
+ struct ggml_tensor * cross_attn_v_b;
+
+ // decoder.blocks.*.mlp_ln
+ struct ggml_tensor * mlp_ln_w;
+ struct ggml_tensor * mlp_ln_b;
+
+ // decoder.blocks.*.mlp.0
+ struct ggml_tensor * mlp_0_w;
+ struct ggml_tensor * mlp_0_b;
+
+ // decoder.blocks.*.mlp.2
+ struct ggml_tensor * mlp_1_w;
+ struct ggml_tensor * mlp_1_b;
+};
+
+struct whisper_model {
+ e_model type = MODEL_UNKNOWN;
+
+ whisper_hparams hparams;
+ whisper_filters filters;
+
+ // encoder.positional_embedding
+ struct ggml_tensor * e_pe;
+
+ // encoder.conv1
+ struct ggml_tensor * e_conv_1_w;
+ struct ggml_tensor * e_conv_1_b;
+
+ // encoder.conv2
+ struct ggml_tensor * e_conv_2_w;
+ struct ggml_tensor * e_conv_2_b;
+
+ // encoder.ln_post
+ struct ggml_tensor * e_ln_w;
+ struct ggml_tensor * e_ln_b;
+
+ // decoder.positional_embedding
+ struct ggml_tensor * d_pe; // DD
+
+ // decoder.token_embedding
+ struct ggml_tensor * d_te; // DD
+
+ // decoder.ln
+ struct ggml_tensor * d_ln_w; // DD
+ struct ggml_tensor * d_ln_b; // DD
+
+ std::vector<whisper_layer_encoder> layers_encoder;
+ std::vector<whisper_layer_decoder> layers_decoder;
+
+ // key + value memory
+ struct ggml_tensor * memory_k;
+ struct ggml_tensor * memory_v;
+
+ struct ggml_tensor * memory_cross_k;
+ struct ggml_tensor * memory_cross_v;
+
+ // context
+ struct ggml_context * ctx;
+ struct ggml_context * ctx_mem;
+
+ // tensors
+ int n_loaded;
+ std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+struct whisper_context {
+ int64_t t_load_us = 0;
+ int64_t t_mel_us = 0;
+ int64_t t_sample_us = 0;
+ int64_t t_encode_us = 0;
+ int64_t t_decode_us = 0;
+ int64_t t_start_us = 0;
+
+ std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
+ std::vector<uint8_t> buf_memory;
+ std::vector<uint8_t> buf_compute;
+ std::vector<uint8_t> buf_compute_layer;
+
+ whisper_model model;
+ whisper_vocab vocab;
+
+ whisper_mel mel;
+
+ std::vector<float> probs;
+ std::vector<float> logits;
+
+ std::vector<whisper_segment> result_all;
+
+ std::vector<whisper_token> prompt_past;
+
+ // [EXPERIMENTAL] token-level timestamps data
+ int64_t t_beg;
+ int64_t t_last;
+ whisper_token tid_last;
+ std::vector<float> energy; // PCM signal energy
+
+ // [EXPERIMENTAL] speed-up techniques
+ int32_t exp_n_audio_ctx; // 0 - use default
+};
+
+template<typename T>
+static void read_safe(std::ifstream& fin, T& dest)
+{
+ fin.read((char*)& dest, sizeof(T));
+}
+
+// load the model from a ggml file
+//
+// file format:
+//
+// - hparams
+// - pre-computed mel filters
+// - vocab
+// - weights
+//
+// see the convert-pt-to-ggml.py script for details
+//
+static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
+ logDebug( u8"%s: loading model from '%s'", __func__, fname.c_str() );
+
+ auto & model = wctx.model;
+ auto & vocab = wctx.vocab;
+
+ auto fin = std::ifstream(fname, std::ios::binary);
+ if (!fin) {
+ logError( u8"%s: failed to open '%s'", __func__, fname.c_str() );
+ return false;
+ }
+
+ // verify magic
+ {
+ uint32_t magic;
+ read_safe(fin, magic);
+ if (magic != 0x67676d6c) {
+ logError( u8"%s: invalid model file '%s' (bad magic)", __func__, fname.c_str() );
+ return false;
+ }
+ }
+
+ //load hparams
+ {
+ auto & hparams = model.hparams;
+
+ read_safe(fin, hparams.n_vocab);
+ read_safe(fin, hparams.n_audio_ctx);
+ read_safe(fin, hparams.n_audio_state);
+ read_safe(fin, hparams.n_audio_head);
+ read_safe(fin, hparams.n_audio_layer);
+ read_safe(fin, hparams.n_text_ctx);
+ read_safe(fin, hparams.n_text_state);
+ read_safe(fin, hparams.n_text_head);
+ read_safe(fin, hparams.n_text_layer);
+ read_safe(fin, hparams.n_mels);
+ read_safe(fin, hparams.f16);
+
+ assert(hparams.n_text_state == hparams.n_audio_state);
+
+ if (hparams.n_audio_layer == 4) {
+ model.type = e_model::MODEL_TINY;
+ }
+
+ if (hparams.n_audio_layer == 6) {
+ model.type = e_model::MODEL_BASE;
+ }
+
+ if (hparams.n_audio_layer == 12) {
+ model.type = e_model::MODEL_SMALL;
+ }
+
+ if (hparams.n_audio_layer == 24) {
+ model.type = e_model::MODEL_MEDIUM;
+ }
+
+ if (hparams.n_audio_layer == 32) {
+ model.type = e_model::MODEL_LARGE;
+ }
+
+ logDebug( u8"%s: n_vocab = %d", __func__, hparams.n_vocab);
+ logDebug( u8"%s: n_audio_ctx = %d", __func__, hparams.n_audio_ctx);
+ logDebug( u8"%s: n_audio_state = %d", __func__, hparams.n_audio_state);
+ logDebug( u8"%s: n_audio_head = %d", __func__, hparams.n_audio_head);
+ logDebug( u8"%s: n_audio_layer = %d", __func__, hparams.n_audio_layer);
+ logDebug( u8"%s: n_text_ctx = %d", __func__, hparams.n_text_ctx);
+ logDebug( u8"%s: n_text_state = %d", __func__, hparams.n_text_state);
+ logDebug( u8"%s: n_text_head = %d", __func__, hparams.n_text_head);
+ logDebug( u8"%s: n_text_layer = %d", __func__, hparams.n_text_layer);
+ logDebug( u8"%s: n_mels = %d", __func__, hparams.n_mels);
+ logDebug( u8"%s: f16 = %d", __func__, hparams.f16);
+ logDebug( u8"%s: type = %d", __func__, model.type);
+
+ wctx.buf_model = new std::vector<uint8_t>();
+ wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
+ wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
+ wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+ wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
+ }
+
+ // load mel filters
+ {
+ auto & filters = wctx.model.filters;
+
+ read_safe(fin, filters.n_mel);
+ read_safe(fin, filters.n_fft);
+
+ filters.data.resize(filters.n_mel * filters.n_fft);
+ fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
+ }
+
+ // load vocab
+ {
+ int32_t n_vocab = 0;
+ read_safe(fin, n_vocab);
+
+ //if (n_vocab != model.hparams.n_vocab) {
+ // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+ // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+ // return false;
+ //}
+
+ std::string word;
+ std::vector<char> tmp;
+ for (int i = 0; i < n_vocab; i++) {
+ uint32_t len;
+ read_safe(fin, len);
+
+ if (len > 0) {
+ tmp.resize(len);
+ fin.read(&tmp[0], tmp.size()); // read to buffer
+ word.assign(&tmp[0], tmp.size());
+ } else {
+ // seems like we have an empty-string token in multi-language models (i = 50256)
+ //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
+ word = "";
+ }
+
+ vocab.token_to_id[word] = i;
+ vocab.id_to_token[i] = word;
+
+ //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
+ }
+
+ vocab.n_vocab = model.hparams.n_vocab;
+ if (vocab.is_multilingual()) {
+ vocab.token_eot++;
+ vocab.token_sot++;
+ vocab.token_prev++;
+ vocab.token_solm++;
+ vocab.token_not++;
+ vocab.token_beg++;
+ }
+
+ if (n_vocab < model.hparams.n_vocab) {
+ logDebug( u8"%s: adding %d extra tokens", __func__, model.hparams.n_vocab - n_vocab );
+ for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
+ if (i > vocab.token_beg) {
+ word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
+ } else if (i == vocab.token_eot) {
+ word = "[_EOT_]";
+ } else if (i == vocab.token_sot) {
+ word = "[_SOT_]";
+ } else if (i == vocab.token_prev) {
+ word = "[_PREV_]";
+ } else if (i == vocab.token_not) {
+ word = "[_NOT_]";
+ } else if (i == vocab.token_beg) {
+ word = "[_BEG_]";
+ } else {
+ word = "[_extra_token_" + std::to_string(i) + "]";
+ }
+ vocab.token_to_id[word] = i;
+ vocab.id_to_token[i] = word;
+ }
+ }
+ }
+
+ {
+ // this is the total memory required to run the inference
+ const size_t mem_required =
+ wctx.buf_model->size() +
+ wctx.buf_memory.size() +
+ wctx.buf_compute.size() +
+ wctx.buf_compute_layer.size();
+
+ logDebug( u8"%s: mem_required = %7.2f MB", __func__, mem_required / 1024.0 / 1024.0 );
+ }
+
+ // for the big tensors, we have the option to store the data in 16-bit floats
+ // in order to save memory and also to speed up the computation
+ const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+ size_t ctx_size = 0;
+
+ {
+ const auto & hparams = model.hparams;
+
+ const int n_vocab = hparams.n_vocab;
+
+ const int n_audio_ctx = hparams.n_audio_ctx;
+ const int n_audio_state = hparams.n_audio_state;
+ const int n_audio_layer = hparams.n_audio_layer;
+
+ const int n_text_ctx = hparams.n_text_ctx;
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+
+ const int n_mels = hparams.n_mels;
+
+ // encoder
+ {
+ // TODO: F16 .. maybe not?
+ ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
+
+ ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
+
+ ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
+
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
+ ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
+ }
+
+ // decoder
+ {
+ // TODO: F16 .. maybe not?
+ ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
+
+ ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
+
+ ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
+ ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
+ }
+
+ // encoder layers
+ {
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
+
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
+ ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
+
+ ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
+
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
+ ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
+
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
+
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
+
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
+
+ ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
+ ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
+ }
+
+ // decoder layers
+ {
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
+
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
+ ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
+
+ ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
+
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
+ //
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
+ ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
+
+ ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
+ ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
+ }
+
+ ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
+
+ logDebug( u8"%s: ggml ctx size = %7.2f MB", __func__, ctx_size / ( 1024.0 * 1024.0 ) );
+ }
+
+ // create the ggml context
+ {
+ struct ggml_init_params params;
+ params.mem_size = wctx.buf_model->size();
+ params.mem_buffer = wctx.buf_model->data();
+
+ model.ctx = ggml_init(params);
+ if (!model.ctx) {
+ logError( u8"%s: ggml_init() failed", __func__ );
+ return false;
+ }
+ }
+
+ // prepare memory for the weights
+ {
+ auto & ctx = model.ctx;
+
+ const auto & hparams = model.hparams;
+
+ const int n_vocab = hparams.n_vocab;
+
+ const int n_audio_ctx = hparams.n_audio_ctx;
+ const int n_audio_state = hparams.n_audio_state;
+ const int n_audio_layer = hparams.n_audio_layer;
+
+ const int n_text_ctx = hparams.n_text_ctx;
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+
+ const int n_mels = hparams.n_mels;
+
+ model.layers_encoder.resize(n_audio_layer);
+ model.layers_decoder.resize(n_text_layer);
+
+ // encoder
+ {
+ model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
+
+ model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
+ model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
+
+ model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
+ model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
+
+ model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ // map by name
+ model.tensors["encoder.positional_embedding"] = model.e_pe;
+
+ model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
+ model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
+
+ model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
+ model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
+
+ model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
+ model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
+
+ for (int i = 0; i < n_audio_layer; ++i) {
+ auto & layer = model.layers_encoder[i];
+
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
+
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+
+ // map by name
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
+
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
+ model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
+ }
+ }
+
+ // decoder
+ {
+ model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
+
+ model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
+
+ model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ // map by name
+ model.tensors["decoder.positional_embedding"] = model.d_pe;
+
+ model.tensors["decoder.token_embedding.weight"] = model.d_te;
+
+ model.tensors["decoder.ln.weight"] = model.d_ln_w;
+ model.tensors["decoder.ln.bias"] = model.d_ln_b;
+
+ for (int i = 0; i < n_text_layer; ++i) {
+ auto & layer = model.layers_decoder[i];
+
+ layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
+ layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
+
+ layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
+ layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+
+ layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+ layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+
+ layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
+ layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+
+ // map by name
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
+
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
+ model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
+ }
+ }
+ }
+
+ // create the ggml memory context
+ {
+ struct ggml_init_params params;
+ params.mem_size = wctx.buf_memory.size();
+ params.mem_buffer = wctx.buf_memory.data();
+
+ model.ctx_mem = ggml_init(params);
+ if (!model.ctx_mem) {
+ logError( u8"%s: ggml_init() failed", __func__ );
+ return false;
+ }
+ }
+
+ // key + value memory
+ {
+ auto & ctx = model.ctx_mem;
+
+ const auto & hparams = model.hparams;
+
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+ const int n_text_ctx = hparams.n_text_ctx;
+
+ // key/value memory for the self-attention layer
+ {
+ const int n_mem = n_text_layer*n_text_ctx;
+ const int n_elements = n_text_state*n_mem;
+
+ model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ }
+
+ // key/value memory for the cross-attention layer
+ {
+ const int n_audio_ctx = hparams.n_audio_ctx;
+
+ const int n_mem = n_text_layer*n_audio_ctx;
+ const int n_elements = n_text_state*n_mem;
+
+ model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ }
+
+ const size_t memory_size =
+ ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
+ ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
+
+ logDebug( u8"%s: memory size = %7.2f MB", __func__, memory_size/1024.0/1024.0);
+ }
+
+ // load weights
+ {
+ size_t total_size = 0;
+
+ model.n_loaded = 0;
+
+ while (true) {
+ int32_t n_dims;
+ int32_t length;
+ int32_t ftype;
+
+ read_safe(fin, n_dims);
+ read_safe(fin, length);
+ read_safe(fin, ftype);
+
+ if (fin.eof()) {
+ break;
+ }
+
+ int32_t nelements = 1;
+ int32_t ne[3] = { 1, 1, 1 };
+ for (int i = 0; i < n_dims; ++i) {
+ read_safe(fin, ne[i]);
+ nelements *= ne[i];
+ }
+
+ std::string name;
+ std::vector<char> tmp(length); // create a buffer
+ fin.read( &tmp[0], tmp.size() ); // read to buffer
+ name.assign(&tmp[0], tmp.size());
+
+ if (model.tensors.find(name) == model.tensors.end()) {
+ logError( u8"%s: unknown tensor '%s' in model file", __func__, name.data() );
+ return false;
+ }
+
+ auto tensor = model.tensors[name.data()];
+ if (ggml_nelements(tensor) != nelements) {
+ logError( u8"%s: tensor '%s' has wrong size in model file", __func__, name.data());
+ return false;
+ }
+
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+ logError( u8"%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]",
+ __func__, name.data(), tensor->ne[ 0 ], tensor->ne[ 1 ], tensor->ne[ 2 ], ne[ 0 ], ne[ 1 ], ne[ 2 ] );
+ return false;
+ }
+
+ const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
+
+ if (nelements*bpe != ggml_nbytes(tensor)) {
+ logError( u8"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+ __func__, name.data(), ggml_nbytes( tensor ), nelements* bpe );
+ return false;
+ }
+
+ fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+ //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+ total_size += ggml_nbytes(tensor);
+ model.n_loaded++;
+ }
+
+ logDebug( u8"%s: model size = %7.2f MB", __func__, total_size / 1024.0 / 1024.0 );
+
+ if (model.n_loaded == 0) {
+ logWarning( u8"%s: WARN no tensors loaded from model file - assuming empty model for testing", __func__);
+ } else if (model.n_loaded != (int) model.tensors.size()) {
+ logError( u8"%s: ERROR not all tensors loaded from model file - expected %zu, got %d", __func__, model.tensors.size(), model.n_loaded );
+ return false;
+ }
+ }
+
+ fin.close();
+
+ return true;
+}
+
+// evaluate the encoder
+//
+// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
+// part of the transformer model and returns the encoded features
+//
+// - model: the model
+// - n_threads: number of threads to use
+// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
+//
+static bool whisper_encode(
+ whisper_context & wctx,
+ const int n_threads,
+ const int mel_offset) {
+ const auto & model = wctx.model;
+ const auto & mel_inp = wctx.mel;
+ const auto & hparams = model.hparams;
+
+ const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+ const int n_state = hparams.n_audio_state;
+ const int n_head = hparams.n_audio_head;
+ const int n_layer = hparams.n_audio_layer;
+
+ const int n_mels = hparams.n_mels;
+ assert(mel_inp.n_mel == n_mels);
+
+ struct ggml_init_params params;
+ params.mem_size = wctx.buf_compute.size();
+ params.mem_buffer = wctx.buf_compute.data();
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
+ assert(mel->type == GGML_TYPE_F32);
+ {
+ float * dst = (float *) mel->data;
+ memset(dst, 0, ggml_nbytes(mel));
+
+ const int i0 = std::min(mel_offset, mel_inp.n_len);
+ const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
+
+ for (int j = 0; j < mel_inp.n_mel; ++j) {
+ for (int i = i0; i < i1; ++i) {
+ dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
+ }
+ }
+ }
+ Tracing::delayTensor( "enc.input", mel );
+
+ struct ggml_tensor * cur;
+
+ // convolution + gelu
+ {
+ cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
+ Tracing::delayTensor( "enc.conv1", cur );
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
+ model.e_conv_1_b,
+ cur),
+ cur);
+
+ cur = ggml_gelu(ctx0, cur);
+ Tracing::delayTensor( "enc.temp1", cur );
+
+ cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
+ model.e_conv_2_b,
+ cur),
+ cur);
+
+ cur = ggml_gelu(ctx0, cur);
+ }
+
+ // ===================================================================
+ // NOTE: experimenting with partial evaluation of the encoder (ignore)
+ //static int iter = -1;
+ //const int n_iter = 1500/n_ctx;
+
+ //iter = (iter + 1) % n_iter;
+
+ //if (iter == 0) {
+ // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
+ // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
+ //}
+
+ static int iter = 0;
+
+ const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
+ const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
+
+ struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
+
+ cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
+ // ===================================================================
+
+ // original:
+ //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
+
+ struct ggml_tensor * inpL = cur;
+
+ for (int il = 0; il < n_layer; ++il) {
+ const auto & layer = model.layers_encoder[il];
+
+ // create separate context for each layer to reduce memory usage
+
+ struct ggml_init_params paramsL;
+ paramsL.mem_size = wctx.buf_compute_layer.size();
+ paramsL.mem_buffer = wctx.buf_compute_layer.data();
+
+ struct ggml_context * ctxL = ggml_init(paramsL);
+
+ Tracing::delayTensor( { "enc.layer[ %i ].in", il }, inpL );
+
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpL);
+ if( il == 0 )
+ Tracing::delayTensor( "enc-norm", cur );
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
+ }
+
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ layer.attn_q_w,
+ cur);
+ if( il == 0 )
+ Tracing::delayTensor( "enc-Qcur", Qcur );
+
+ Qcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.attn_q_b,
+ Qcur),
+ Qcur);
+
+ //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ // note: no bias for Key
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
+ layer.attn_k_w,
+ cur);
+ if( il == 0 )
+ Tracing::delayTensor( "enc-Kcur", Kcur );
+
+ //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
+ layer.attn_v_w,
+ cur);
+ if( il == 0 )
+ Tracing::delayTensor( "enc-Vcur", Vcur );
+
+ Vcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.attn_v_b,
+ Vcur),
+ Vcur);
+
+ // ------
+
+#ifdef USE_FLASH_ATTN
+ struct ggml_tensor * Q =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Qcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Kcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * V =
+ ggml_cpy(ctxL,
+ ggml_permute(ctxL,
+ ggml_reshape_3d(ctxL,
+ Vcur,
+ n_state/n_head, n_head, n_ctx),
+ 1, 2, 0, 3),
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
+ );
+
+ struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
+ if( il == 0 )
+ Tracing::delayTensor( "enc-KQV", KQV );
+#else
+ struct ggml_tensor * Q =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Qcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Kcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+ 0, 2, 1, 3);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+
+ struct ggml_tensor * KQ_scaled =
+ ggml_scale(ctxL,
+ KQ,
+ ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ );
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
+
+ //struct ggml_tensor * V_trans =
+ // ggml_permute(ctxL,
+ // ggml_cpy(ctxL,
+ // Vcur,
+ // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+ // 1, 2, 0, 3);
+
+ //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+
+ struct ggml_tensor * V =
+ ggml_cpy(ctxL,
+ ggml_permute(ctxL,
+ ggml_reshape_3d(ctxL,
+ Vcur,
+ n_state/n_head, n_head, n_ctx),
+ 0, 2, 1, 3),
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
+ );
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
+#endif
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cpy(ctxL,
+ KQV_merged,
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
+ }
+
+ // projection
+ {
+ cur = ggml_mul_mat(ctxL,
+ layer.attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
+ cur);
+ }
+
+ // add the input
+ cur = ggml_add(ctxL, cur, inpL);
+
+ struct ggml_tensor * inpFF = cur;
+
+ // feed-forward network
+ {
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpFF);
+
+ // cur = mlp_ln_w*cur + mlp_ln_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.mlp_ln_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.mlp_ln_b, cur));
+ }
+
+#ifdef USE_FLASH_FF
+ cur = ggml_flash_ff(ctxL,
+ ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
+ layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
+#else
+ // fully connected
+ cur = ggml_mul_mat(ctxL,
+ layer.mlp_0_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.mlp_0_b, cur),
+ cur);
+
+ // GELU activation
+ cur = ggml_gelu(ctxL, cur);
+
+ // projection
+ cur = ggml_mul_mat(ctxL,
+ layer.mlp_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.mlp_1_b, cur),
+ cur);
+#endif
+ }
+
+ // output from this layer
+ struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
+
+ {
+ struct ggml_cgraph gf = {};
+ gf.n_threads = n_threads;
+
+ ggml_build_forward_expand(&gf, inpO);
+ ggml_graph_compute (ctxL, &gf);
+ Tracing::writeDelayedTensors();
+ //ggml_graph_print(&gf);
+ }
+
+ // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
+ // input for next layer (inpO -> inpL)
+ memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
+ inpL->op = GGML_OP_NONE;
+ inpL->src0 = nullptr;
+ inpL->src1 = nullptr;
+
+ //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
+
+ ggml_free(ctxL);
+ }
+ Tracing::tensor( "enc.layers", inpL );
+ cur = inpL;
+
+ // norm
+ {
+ cur = ggml_norm(ctx0, cur);
+
+ // cur = ln_f_g*cur + ln_f_b
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, model.e_ln_w, cur),
+ cur),
+ ggml_repeat(ctx0, model.e_ln_b, cur));
+ }
+
+ // run the computation
+ {
+ struct ggml_cgraph gf = {};
+ gf.n_threads = n_threads;
+
+ ggml_build_forward_expand(&gf, cur);
+ ggml_graph_compute (ctx0, &gf);
+
+ //ggml_graph_print(&gf);
+ }
+
+ Tracing::tensor( "encode-out", cur );
+
+ // cur
+ //{
+ // printf("ne0 = %d\n", cur->ne[0]);
+ // printf("ne1 = %d\n", cur->ne[1]);
+ // for (int i = 0; i < 10; ++i) {
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
+ // }
+ // printf("... ");
+ // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
+ // printf("%8.4f ", ((float *)(cur->data))[i]);
+ // }
+ // printf("\n");
+ //}
+
+ // pre-compute cross-attention memory
+ {
+ struct ggml_cgraph gf = {};
+ gf.n_threads = n_threads;
+
+ // TODO: hack to disconnect the encoded features from the previous graph
+ cur->op = GGML_OP_NONE;
+ cur->src0 = nullptr;
+ cur->src1 = nullptr;
+
+ for (int il = 0; il < model.hparams.n_text_layer; ++il) {
+ auto & layer = model.layers_decoder[il];
+
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
+ layer.cross_attn_k_w,
+ cur);
+
+ Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
+ layer.cross_attn_v_w,
+ cur);
+
+ Vcross = ggml_add(ctx0,
+ ggml_repeat(ctx0,
+ layer.cross_attn_v_b,
+ Vcross),
+ Vcross);
+
+ //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+ //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+ struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
+ struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
+
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
+ }
+
+ ggml_graph_compute(ctx0, &gf);
+ }
+
+ ////////////////////////////////////////////////////////////////////////////
+
+ //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);
+
+ ggml_free(ctx0);
+
+ return true;
+}
+
+// evaluate the decoder
+//
+// given text prompt + audio features -> predicts the probabilities for the next token
+//
+// - model: the model
+// - n_threads: number of threads to use
+// - tokens: text prompt
+// - n_tokens: number of tokens in the prompt
+// - n_past: number of past tokens to prefix the prompt with
+//
+static bool whisper_decode(
+ whisper_context & wctx,
+ const int n_threads,
+ const whisper_token * tokens,
+ const int n_tokens,
+ const int n_past) {
+ const auto & model = wctx.model;
+ const auto & hparams = model.hparams;
+
+ auto & logits_out = wctx.logits;
+ auto & probs_out = wctx.probs;
+
+ const int n_vocab = hparams.n_vocab;
+
+ const int n_ctx = hparams.n_text_ctx;
+ const int n_state = hparams.n_text_state;
+ const int n_head = hparams.n_text_head;
+ const int n_layer = hparams.n_text_layer;
+
+ const int N = n_tokens;
+ const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+
+ struct ggml_init_params params;
+ params.mem_size = wctx.buf_compute.size();
+ params.mem_buffer = wctx.buf_compute.data();
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ memcpy(embd->data, tokens, N*ggml_element_size(embd));
+
+ struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ for (int i = 0; i < N; ++i) {
+ ((int32_t *) position->data)[i] = n_past + i;
+ }
+
+ // token encoding + position encoding
+ struct ggml_tensor * cur =
+ ggml_add(ctx0,
+ ggml_get_rows(ctx0, model.d_te, embd),
+ ggml_get_rows(ctx0, model.d_pe, position));
+ Tracing::delayTensor( "dec-rows", cur );
+
+ struct ggml_tensor * inpL = cur;
+
+ for (int il = 0; il < n_layer; ++il) {
+ const auto & layer = model.layers_decoder[il];
+
+ struct ggml_init_params paramsL;
+ paramsL.mem_size = wctx.buf_compute_layer.size();
+ paramsL.mem_buffer = wctx.buf_compute_layer.data();
+
+ struct ggml_context * ctxL = ggml_init(paramsL);
+ struct ggml_cgraph gf = {};
+ gf.n_threads = n_threads;
+
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpL);
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
+ }
+
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ layer.attn_q_w,
+ cur);
+
+ Qcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.attn_q_b,
+ Qcur),
+ Qcur);
+
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ // note: no bias for Key
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
+ layer.attn_k_w,
+ cur);
+
+ Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
+ layer.attn_v_w,
+ cur);
+
+ Vcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.attn_v_b,
+ Vcur),
+ Vcur);
+
+ // store key and value to memory
+ {
+ struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
+ struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));
+
+ ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
+ ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
+ }
+
+ // ------
+
+ struct ggml_tensor * Q =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Qcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K =
+ ggml_permute(ctxL,
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+ n_state/n_head, n_head, n_past + N),
+ 0, 2, 1, 3);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+
+ //struct ggml_tensor * KQ_scaled =
+ // ggml_scale(ctxL,
+ // KQ,
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ // );
+
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
+ if( 0 == il ) Tracing::delayTensor( "dec-KQ", KQ_soft_max );
+
+ struct ggml_tensor * V_trans =
+ ggml_permute(ctxL,
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+ n_state/n_head, n_head, n_past + N),
+ 1, 2, 0, 3);
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+ if( 0 == il ) Tracing::delayTensor( "dec-KQV", KQV );
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cpy(ctxL,
+ KQV_merged,
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+ }
+
+ {
+ cur = ggml_mul_mat(ctxL,
+ layer.attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
+ cur);
+ }
+
+ // add the input
+ struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);
+
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
+
+ // cur = ln_0_w*cur + ln_0_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
+ }
+
+ // cross-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
+ layer.cross_attn_q_w,
+ cur);
+
+ Qcur = ggml_add(ctxL,
+ ggml_repeat(ctxL,
+ layer.cross_attn_q_b,
+ Qcur),
+ Qcur);
+
+ Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+
+ // Kcross is already scaled
+ struct ggml_tensor * Kcross =
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
+ n_state/n_head, n_head, M);
+
+ struct ggml_tensor * Vcross =
+ ggml_reshape_3d(ctxL,
+ ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
+ n_state/n_head, n_head, M);
+
+ // ------
+
+ struct ggml_tensor * Q =
+ ggml_permute(ctxL,
+ ggml_cpy(ctxL,
+ Qcur,
+ ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+ 0, 2, 1, 3);
+
+ struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+
+ //struct ggml_tensor * KQ_scaled =
+ // ggml_scale(ctxL,
+ // KQ,
+ // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+ // );
+
+ // no masking for cross-attention
+ //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
+
+ struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+ if( 0 == il ) Tracing::delayTensor( "dec-KQV", KQV );
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
+
+ // cur = KQV_merged.contiguous().view(n_state, N)
+ cur = ggml_cpy(ctxL,
+ KQV_merged,
+ ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+ }
+
+ // projection
+ {
+ cur = ggml_mul_mat(ctxL,
+ layer.cross_attn_ln_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
+ cur);
+ }
+
+ // add the input
+ cur = ggml_add(ctxL, cur, inpCA);
+
+ struct ggml_tensor * inpFF = cur;
+
+ // feed-forward network
+ {
+ // norm
+ {
+ cur = ggml_norm(ctxL, inpFF);
+
+ // cur = mlp_ln_w*cur + mlp_ln_b
+ cur = ggml_add(ctxL,
+ ggml_mul(ctxL,
+ ggml_repeat(ctxL, layer.mlp_ln_w, cur),
+ cur),
+ ggml_repeat(ctxL, layer.mlp_ln_b, cur));
+ }
+
+ // fully connected
+ cur = ggml_mul_mat(ctxL,
+ layer.mlp_0_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.mlp_0_b, cur),
+ cur);
+
+ // GELU activation
+ cur = ggml_gelu(ctxL, cur);
+
+ // projection
+ cur = ggml_mul_mat(ctxL,
+ layer.mlp_1_w,
+ cur);
+
+ cur = ggml_add(ctxL,
+ ggml_repeat(ctxL, layer.mlp_1_b, cur),
+ cur);
+ }
+
+ // output from this layer
+ struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
+
+ {
+ ggml_build_forward_expand(&gf, inpO);
+ ggml_graph_compute (ctxL, &gf);
+ Tracing::writeDelayedTensors();
+ //ggml_graph_print(&gf);
+ }
+
+ // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
+ // input for next layer (inpO -> inpL)
+ memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
+ inpL->op = GGML_OP_NONE;
+ inpL->src0 = nullptr;
+ inpL->src1 = nullptr;
+
+ if (N > 1) {
+ //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
+ }
+
+ ggml_free(ctxL);
+ }
+
+ cur = inpL;
+
+ // norm
+ {
+ cur = ggml_norm(ctx0, cur);
+
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, model.d_ln_w, cur),
+ cur),
+ ggml_repeat(ctx0, model.d_ln_b, cur));
+ }
+
+ struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
+
+ // logits -> probs
+ cur = ggml_dup(ctx0, logits);
+ cur = ggml_soft_max(ctx0, cur); // in-place
+
+ // run the computation
+ {
+ struct ggml_cgraph gf = {};
+ gf.n_threads = n_threads;
+
+ ggml_build_forward_expand(&gf, cur);
+ ggml_graph_compute (ctx0, &gf);
+ }
+
+ logits_out.resize(N*n_vocab);
+ memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+
+ probs_out.resize(N*n_vocab);
+ memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
+
+ if (N > 1) {
+ //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
+ //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
+ //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
+ }
+
+ ggml_free(ctx0);
+ // Hash::vector( "probs", probs_out );
+ Tracing::vector( "probs", probs_out );
+
+ return true;
+}
+
+// the most basic sampling scheme - select the top token
+static whisper_token_data whisper_sample_best(
+ const whisper_vocab & vocab,
+ const float * probs,
+ bool force_timestamp,
+ bool is_initial) {
+ whisper_token_data result = {
+ 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
+ };
+
+ int n_logits = vocab.id_to_token.size();
+
+ std::vector<std::pair<double, whisper_vocab::id>> probs_id;
+ probs_id.reserve(n_logits);
+
+ for (int i = 0; i < n_logits; i++) {
+ probs_id.emplace_back(probs[i], i);
+ }
+
+ {
+ double sum_ts = 0.0;
+ double max_ts = -1.0;
+ double max_tx = -1.0;
+
+ for (int i = 0; i < vocab.token_beg; i++) {
+ max_tx = std::max(max_tx, probs_id[i].first);
+ }
+
+ const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
+ const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
+
+ // the initial timestamp cannot be larger than 100
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
+ if (is_initial) {
+ for (int i = i0; i < n_logits; ++ i) {
+ probs_id[i].first = -INFINITY;
+ }
+ }
+
+ for (int i = vocab.token_beg; i < i1; i++) {
+ sum_ts += probs_id[i].first;
+ if (probs_id[i].first > max_ts) {
+ max_ts = probs_id[i].first;
+ result.tid = probs_id[i].second;
+ }
+ }
+
+ // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
+ // timestamp token
+ if (sum_ts > max_tx || force_timestamp) {
+ // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
+ for (int i = 0; i < vocab.token_beg; i++) {
+ probs_id[i].first = -INFINITY;
+ }
+ }
+
+ result.pt = max_ts/(sum_ts + 1e-10);
+ result.ptsum = sum_ts;
+ }
+
+ // find the top K tokens
+ const int top_k = 4;
+
+ std::partial_sort(
+ probs_id.begin(),
+ probs_id.begin() + top_k, probs_id.end(),
+ [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
+ return a.first > b.first;
+ });
+
+ probs_id.resize(top_k);
+
+ //printf("\n");
+ //for (int i = 0; i < (int) probs_id.size(); i++) {
+ // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
+ //}
+
+ int res = 0;
+ while ((probs_id[res].second == vocab.token_sot ||
+ probs_id[res].second == vocab.token_solm ||
+ probs_id[res].second == vocab.token_not) &&
+ res < (int) probs_id.size() - 1) {
+ res++;
+ }
+
+ result.id = probs_id[res].second;
+ result.p = probs_id[res].first;
+
+ return result;
+}
+
+// 500 -> 00:05.000
+// 6000 -> 01:00.000
+static std::string to_timestamp(int64_t t, bool comma = false) {
+ int64_t msec = t * 10;
+ int64_t hr = msec / (1000 * 60 * 60);
+ msec = msec - hr * (1000 * 60 * 60);
+ int64_t min = msec / (1000 * 60);
+ msec = msec - min * (1000 * 60);
+ int64_t sec = msec / 1000;
+ msec = msec - sec * 1000;
+
+ char buf[32];
+ snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
+
+ return std::string(buf);
+}
+
+// naive Discrete Fourier Transform
+// input is real-valued
+// output is complex-valued
+static void dft(const std::vector<float> & in, std::vector<float> & out) {
+ int N = in.size();
+
+ out.resize(N*2);
+
+ for (int k = 0; k < N; k++) {
+ float re = 0;
+ float im = 0;
+
+ for (int n = 0; n < N; n++) {
+ float angle = 2*M_PI*k*n/N;
+ re += in[n]*cos(angle);
+ im -= in[n]*sin(angle);
+ }
+
+ out[k*2 + 0] = re;
+ out[k*2 + 1] = im;
+ }
+}
+
+// Cooley-Tukey FFT
+// poor man's implementation - use something better
+// input is real-valued
+// output is complex-valued
+static void fft(const std::vector<float> & in, std::vector<float> & out) {
+ out.resize(in.size()*2);
+
+ int N = in.size();
+
+ if (N == 1) {
+ out[0] = in[0];
+ out[1] = 0;
+ return;
+ }
+
+ if (N%2 == 1) {
+ dft(in, out);
+ return;
+ }
+
+ std::vector<float> even;
+ std::vector<float> odd;
+
+ for (int i = 0; i < N; i++) {
+ if (i % 2 == 0) {
+ even.push_back(in[i]);
+ } else {
+ odd.push_back(in[i]);
+ }
+ }
+
+ std::vector<float> even_fft;
+ std::vector<float> odd_fft;
+
+ fft(even, even_fft);
+ fft(odd, odd_fft);
+
+ for (int k = 0; k < N/2; k++) {
+ float theta = 2*M_PI*k/N;
+
+ float re = cos(theta);
+ float im = -sin(theta);
+
+ float re_odd = odd_fft[2*k + 0];
+ float im_odd = odd_fft[2*k + 1];
+
+ out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
+ out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+
+ out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+ out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+ }
+}
+
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
+static bool log_mel_spectrogram(
+ const float * samples,
+ const int n_samples,
+ const int /*sample_rate*/,
+ const int fft_size,
+ const int fft_step,
+ const int n_mel,
+ const int n_threads,
+ const whisper_filters & filters,
+ const bool speed_up,
+ whisper_mel & mel) {
+
+ // Hanning window
+ std::vector<float> hann;
+ hann.resize(fft_size);
+ for (int i = 0; i < fft_size; i++) {
+ hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
+ }
+
+ mel.n_mel = n_mel;
+ mel.n_len = (n_samples)/fft_step;
+ mel.data.resize(mel.n_mel*mel.n_len);
+
+ const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
+
+ //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
+ //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
+
+ std::vector<std::thread> workers(n_threads);
+ for (int iw = 0; iw < n_threads; ++iw) {
+ workers[iw] = std::thread([&](int ith) {
+ std::vector<float> fft_in;
+ fft_in.resize(fft_size);
+ for (int i = 0; i < fft_size; i++) {
+ fft_in[i] = 0.0;
+ }
+
+ std::vector<float> fft_out;
+ fft_out.resize(2*fft_size);
+
+ for (int i = ith; i < mel.n_len; i += n_threads) {
+ const int offset = i*fft_step;
+
+ // apply Hanning window
+ for (int j = 0; j < fft_size; j++) {
+ if (offset + j < n_samples) {
+ fft_in[j] = hann[j]*samples[offset + j];
+ } else {
+ fft_in[j] = 0.0;
+ }
+ }
+
+ // FFT -> mag^2
+ fft(fft_in, fft_out);
+
+ for (int j = 0; j < fft_size; j++) {
+ fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
+ }
+ for (int j = 1; j < fft_size/2; j++) {
+ //if (i == 0) {
+ // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
+ //}
+ fft_out[j] += fft_out[fft_size - j];
+ }
+ if (i == 0) {
+ //for (int j = 0; j < fft_size; j++) {
+ // printf("%d: %e\n", j, fft_out[j]);
+ //}
+ }
+
+ if (speed_up) {
+ // scale down in the frequency domain results in a speed up in the time domain
+ for (int j = 0; j < n_fft; j++) {
+ fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
+ }
+ }
+
+ // mel spectrogram
+ for (int j = 0; j < mel.n_mel; j++) {
+ double sum = 0.0;
+
+ for (int k = 0; k < n_fft; k++) {
+ sum += fft_out[k]*filters.data[j*n_fft + k];
+ }
+ if (sum < 1e-10) {
+ sum = 1e-10;
+ }
+
+ sum = log10(sum);
+
+ mel.data[j*mel.n_len + i] = sum;
+ }
+ }
+ }, iw);
+ }
+
+ for (int iw = 0; iw < n_threads; ++iw) {
+ workers[iw].join();
+ }
+
+ // clamping and normalization
+ double mmax = -1e20;
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+ if (mel.data[i] > mmax) {
+ mmax = mel.data[i];
+ }
+ }
+ //printf("%s: max = %f\n", __func__, mmax);
+
+ mmax -= 8.0;
+
+ for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+ if (mel.data[i] < mmax) {
+ mel.data[i] = mmax;
+ }
+
+ mel.data[i] = (mel.data[i] + 4.0)/4.0;
+ }
+
+ return true;
+}
+
+// split text into tokens
+//
+// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+//
+// Regex (Python):
+// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+//
+// Regex (C++):
+// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
+//
+static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
+ std::vector<std::string> words;
+
+ // first split the text into words
+ {
+ std::string str = text;
+ std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+
+ std::regex re(pat);
+ std::smatch m;
+
+ while (std::regex_search(str, m, re)) {
+ for (auto x : m) {
+ words.push_back(x);
+ }
+ str = m.suffix();
+ }
+ }
+
+ // find the longest tokens that form the words:
+ std::vector<whisper_vocab::id> tokens;
+ for (const auto & word : words) {
+ if (word.empty()) continue;
+
+ int i = 0;
+ int n = word.size();
+ while (i < n) {
+ int j = n;
+ while (j > i) {
+ auto it = vocab.token_to_id.find(word.substr(i, j-i));
+ if (it != vocab.token_to_id.end()) {
+ tokens.push_back(it->second);
+ i = j;
+ break;
+ }
+ --j;
+ }
+ if (i == n) {
+ break;
+ }
+ if (j == i) {
+ auto sub = word.substr(i, 1);
+ if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
+ tokens.push_back(vocab.token_to_id.at(sub));
+ } else {
+ logWarning( u8"%s: unknown token '%s'", __func__, sub.data() );
+ }
+ ++i;
+ }
+ }
+ }
+
+ return tokens;
+}
+
+//
+// interface implementation
+//
+
+struct whisper_context * whisper_init(const char * path_model) {
+ ggml_time_init();
+
+ whisper_context * ctx = new whisper_context;
+
+ const int64_t t_start_us = ggml_time_us();
+
+ ctx->t_start_us = t_start_us;
+
+ if (!whisper_model_load(path_model, *ctx)) {
+ logError( u8"%s: failed to load model from '%s'", __func__, path_model );
+ delete ctx;
+ return nullptr;
+ }
+
+ ctx->t_load_us = ggml_time_us() - t_start_us;
+
+ return ctx;
+}
+
+void whisper_free(struct whisper_context * ctx) {
+ if (ctx) {
+ if (ctx->model.ctx) {
+ ggml_free(ctx->model.ctx);
+ }
+ if (ctx->model.ctx_mem) {
+ ggml_free(ctx->model.ctx_mem);
+ }
+ if (ctx->buf_model) {
+ delete ctx->buf_model;
+ }
+ delete ctx;
+ }
+}
+
+int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+ logError( u8"%s: failed to compute mel spectrogram", __func__ );
+ return -1;
+ }
+
+ ctx->t_mel_us = ggml_time_us() - t_start_us;
+
+ return 0;
+}
+
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
+ logError( u8"%s: failed to compute mel spectrogram", __func__ );
+ return -1;
+ }
+
+ ctx->t_mel_us = ggml_time_us() - t_start_us;
+
+ return 0;
+}
+
+int whisper_set_mel(
+ struct whisper_context * ctx,
+ const float * data,
+ int n_len,
+ int n_mel) {
+ if (n_mel != WHISPER_N_MEL) {
+ logError( u8"%s: invalid number of mel bands: %d (expected %d)", __func__, n_mel, WHISPER_N_MEL );
+ return -1;
+ }
+
+ ctx->mel.n_len = n_len;
+ ctx->mel.n_mel = n_mel;
+
+ ctx->mel.data.resize(n_len*n_mel);
+ memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
+
+ return 0;
+}
+
+int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!whisper_encode(*ctx, n_threads, offset)) {
+ logError( u8"%s: failed to eval", __func__ );
+ return -1;
+ }
+
+ ctx->t_encode_us += ggml_time_us() - t_start_us;
+
+ return 0;
+}
+
+int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
+ logError( u8"%s: failed to eval", __func__ );
+ return 1;
+ }
+
+ ctx->t_decode_us += ggml_time_us() - t_start_us;
+
+ return 0;
+}
+
+struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
+ const int64_t t_start_sample_us = ggml_time_us();
+
+ const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);
+
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+
+ return res;
+}
+
+struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
+ const int64_t t_start_sample_us = ggml_time_us();
+
+ const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
+
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+
+ return res;
+}
+
+int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
+ const auto res = tokenize(ctx->vocab, text);
+
+ if (n_max_tokens < (int) res.size()) {
+ logError( u8"%s: too many resulting tokens: %d (max %d)", __func__, (int)res.size(), n_max_tokens );
+ return -1;
+ }
+
+ for (int i = 0; i < (int) res.size(); i++) {
+ tokens[i] = res[i];
+ }
+
+ return res.size();
+}
+
+int whisper_lang_max_id() {
+ auto max_id = 0;
+ for (const auto & kv : g_lang) {
+ max_id = std::max(max_id, kv.second.first);
+ }
+
+ return max_id;
+}
+
+int whisper_lang_id(const char * lang) {
+ if (!g_lang.count(lang)) {
+ for (const auto & kv : g_lang) {
+ if (kv.second.second == lang) {
+ return kv.second.first;
+ }
+ }
+
+ logError( u8"%s: unknown language '%s'", __func__, lang );
+ return -1;
+ }
+
+ return g_lang.at(lang).first;
+}
+
+const char * whisper_lang_str(int id) {
+ for (const auto & kv : g_lang) {
+ if (kv.second.first == id) {
+ return kv.first.c_str();
+ }
+ }
+
+ logError( u8"%s: unknown language id %d", __func__, id );
+ return nullptr;
+}
+
+int whisper_lang_auto_detect(
+ struct whisper_context * ctx,
+ int offset_ms,
+ int n_threads,
+ float * lang_probs) {
+ const int seek = offset_ms/10;
+
+ if (seek < 0) {
+ logError( u8"%s: offset %dms is before the start of the audio", __func__, offset_ms );
+ return -1;
+ }
+
+ if (seek >= ctx->mel.n_len) {
+ logError( u8"%s: offset %dms is past the end of the audio (%dms)", __func__, offset_ms, ctx->mel.n_len * 10 );
+ return -2;
+ }
+
+ // run the encoder
+ if (whisper_encode(ctx, seek, n_threads) != 0) {
+ logError( u8"%s: failed to encode", __func__ );
+ return -6;
+ }
+
+ const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
+
+ if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+ logError( u8"%s: failed to decode", __func__ );
+ return -7;
+ }
+
+ std::vector<std::pair<float, int>> probs_id;
+ for (const auto & kv : g_lang) {
+ const auto token_lang = whisper_token_lang(ctx, kv.second.first);
+ probs_id.emplace_back( ctx->probs[token_lang], kv.second.first );
+ }
+
+ // sort descending
+ {
+ using pair_type = decltype(probs_id)::value_type;
+ std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
+ return a.first > b.first;
+ });
+ }
+
+ // softmax
+ {
+ float sum = 0;
+ for (const auto & kv : probs_id) {
+ sum += exp(kv.first);
+ }
+
+ for (auto & kv : probs_id) {
+ kv.first = exp(kv.first) / sum;
+ }
+ }
+
+ {
+ for (int i = 0; i < (int) probs_id.size(); i++) {
+ if (lang_probs) {
+ lang_probs[probs_id[i].second] = probs_id[i].first;
+ }
+
+ //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
+ }
+ }
+
+ return probs_id[0].second;
+}
+
+int whisper_n_len(struct whisper_context * ctx) {
+ return ctx->mel.n_len;
+}
+
+int whisper_n_vocab(struct whisper_context * ctx) {
+ return ctx->vocab.n_vocab;
+}
+
+int whisper_n_text_ctx(struct whisper_context * ctx) {
+ return ctx->model.hparams.n_text_ctx;
+}
+
+int whisper_is_multilingual(struct whisper_context * ctx) {
+ return ctx->vocab.is_multilingual() ? 1 : 0;
+}
+
+float * whisper_get_probs(struct whisper_context * ctx) {
+ return ctx->probs.data();
+}
+
+const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
+ return ctx->vocab.id_to_token.at(token).c_str();
+}
+
+whisper_token whisper_token_eot(struct whisper_context * ctx) {
+ return ctx->vocab.token_eot;
+}
+
+whisper_token whisper_token_sot(struct whisper_context * ctx) {
+ return ctx->vocab.token_sot;
+}
+
+whisper_token whisper_token_prev(struct whisper_context * ctx) {
+ return ctx->vocab.token_prev;
+}
+
+whisper_token whisper_token_solm(struct whisper_context * ctx) {
+ return ctx->vocab.token_solm;
+}
+
+whisper_token whisper_token_not(struct whisper_context * ctx) {
+ return ctx->vocab.token_not;
+}
+
+whisper_token whisper_token_beg(struct whisper_context * ctx) {
+ return ctx->vocab.token_beg;
+}
+
+whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
+ return whisper_token_sot(ctx) + 1 + lang_id;
+}
+
+whisper_token whisper_token_translate(void) {
+ return whisper_vocab::token_translate;
+}
+
+whisper_token whisper_token_transcribe(void) {
+ return whisper_vocab::token_transcribe;
+}
+
+void whisper_print_timings(struct whisper_context * ctx) {
+ const int64_t t_end_us = ggml_time_us();
+
+ logInfo( u8"%s: load time = %8.2f ms", __func__, ctx->t_load_us / 1000.0f );
+ logInfo( u8"%s: mel time = %8.2f ms", __func__, ctx->t_mel_us / 1000.0f );
+ logInfo( u8"%s: sample time = %8.2f ms", __func__, ctx->t_sample_us / 1000.0f );
+ logInfo( u8"%s: encode time = %8.2f ms / %.2f ms per layer", __func__,
+ ctx->t_encode_us / 1000.0f, ctx->t_encode_us / 1000.0f / ctx->model.hparams.n_audio_layer );
+ logInfo( u8"%s: decode time = %8.2f ms / %.2f ms per layer", __func__,
+ ctx->t_decode_us / 1000.0f, ctx->t_decode_us / 1000.0f / ctx->model.hparams.n_text_layer );
+ logInfo( u8"%s: total time = %8.2f ms", __func__, ( t_end_us - ctx->t_start_us ) / 1000.0f );
+}
+
+void whisper_reset_timings(struct whisper_context * ctx) {
+ ctx->t_sample_us = 0;
+ ctx->t_encode_us = 0;
+ ctx->t_decode_us = 0;
+}
+
+const char * whisper_print_system_info(void) {
+ static std::string s;
+
+ s = "";
+ s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
+ s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
+ s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
+ s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
+ s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
+ s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
+ s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
+ s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
+ s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
+ s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
+
+ return s.c_str();
+}
+
+////////////////////////////////////////////////////////////////////////////
+
+struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
+ struct whisper_full_params result;
+
+ switch (strategy) {
+ case WHISPER_SAMPLING_GREEDY:
+ {
+ result = {
+ /*.strategy =*/ WHISPER_SAMPLING_GREEDY,
+
+ /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+ /*.n_max_text_ctx =*/ 16384,
+ /*.offset_ms =*/ 0,
+ /*.duration_ms =*/ 0,
+
+ /*.translate =*/ false,
+ /*.no_context =*/ false,
+ /*.single_segment =*/ false,
+ /*.print_special =*/ false,
+ /*.print_progress =*/ true,
+ /*.print_realtime =*/ false,
+ /*.print_timestamps =*/ true,
+
+ /*.token_timestamps =*/ false,
+ /*.thold_pt =*/ 0.01f,
+ /*.thold_ptsum =*/ 0.01f,
+ /*.max_len =*/ 0,
+ /*.max_tokens =*/ 0,
+
+ /*.speed_up =*/ false,
+ /*.audio_ctx =*/ 0,
+
+ /*.prompt_tokens =*/ nullptr,
+ /*.prompt_n_tokens =*/ 0,
+
+ /*.language =*/ "en",
+
+ /*.greedy =*/ {
+ /*.n_past =*/ 0,
+ },
+
+ /*.beam_search =*/ {
+ /*.n_past =*/ -1,
+ /*.beam_width =*/ -1,
+ /*.n_best =*/ -1,
+ },
+
+ /*.new_segment_callback =*/ nullptr,
+ /*.new_segment_callback_user_data =*/ nullptr,
+
+ /*.encoder_begin_callback =*/ nullptr,
+ /*.encoder_begin_callback_user_data =*/ nullptr,
+ };
+ } break;
+ case WHISPER_SAMPLING_BEAM_SEARCH:
+ {
+ result = {
+ /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
+
+ /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
+ /*.n_max_text_ctx =*/ 16384,
+ /*.offset_ms =*/ 0,
+ /*.duration_ms =*/ 0,
+
+ /*.translate =*/ false,
+ /*.no_context =*/ false,
+ /*.single_segment =*/ false,
+ /*.print_special =*/ false,
+ /*.print_progress =*/ true,
+ /*.print_realtime =*/ false,
+ /*.print_timestamps =*/ true,
+
+ /*.token_timestamps =*/ false,
+ /*.thold_pt =*/ 0.01f,
+ /*.thold_ptsum =*/ 0.01f,
+ /*.max_len =*/ 0,
+ /*.max_tokens =*/ 0,
+
+ /*.speed_up =*/ false,
+ /*.audio_ctx =*/ 0,
+
+ /*.prompt_tokens =*/ nullptr,
+ /*.prompt_n_tokens =*/ 0,
+
+ /*.language =*/ "en",
+
+ /*.greedy =*/ {
+ /*.n_past =*/ -1,
+ },
+
+ /*.beam_search =*/ {
+ /*.n_past =*/ 0,
+ /*.beam_width =*/ 10,
+ /*.n_best =*/ 5,
+ },
+
+ /*.new_segment_callback =*/ nullptr,
+ /*.new_segment_callback_user_data =*/ nullptr,
+
+ /*.encoder_begin_callback =*/ nullptr,
+ /*.encoder_begin_callback_user_data =*/ nullptr,
+ };
+ } break;
+ }
+
+ return result;
+}
+
+// forward declarations
+static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
+static void whisper_exp_compute_token_level_timestamps(
+ struct whisper_context * ctx,
+ int i_segment,
+ float thold_pt,
+ float thold_ptsum);
+
+// wrap the last segment to max_len characters
+// returns the number of new segments
+static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
+ auto segment = ctx->result_all.back();
+
+ int res = 1;
+ int acc = 0;
+
+ std::string text;
+
+ for (int i = 0; i < (int) segment.tokens.size(); i++) {
+ const auto & token = segment.tokens[i];
+ if (token.id >= whisper_token_eot(ctx)) {
+ continue;
+ }
+
+ const auto txt = whisper_token_to_str(ctx, token.id);
+
+ const int cur = strlen(txt);
+
+ if (acc + cur > max_len && i > 0) {
+ // split here
+ ctx->result_all.back().text = std::move(text);
+ ctx->result_all.back().t1 = token.t0;
+ ctx->result_all.back().tokens.resize(i);
+
+ ctx->result_all.push_back({});
+ ctx->result_all.back().t0 = token.t0;
+ ctx->result_all.back().t1 = segment.t1;
+
+ // add tokens [i, end] to the new segment
+ ctx->result_all.back().tokens.insert(
+ ctx->result_all.back().tokens.end(),
+ segment.tokens.begin() + i,
+ segment.tokens.end());
+
+ acc = 0;
+ text = "";
+
+ segment = ctx->result_all.back();
+ i = -1;
+
+ res++;
+ } else {
+ acc += cur;
+ text += txt;
+ }
+ }
+
+ ctx->result_all.back().text = std::move(text);
+
+ return res;
+}
+
+int whisper_full(
+ struct whisper_context * ctx,
+ struct whisper_full_params params,
+ const float * samples,
+ int n_samples) {
+ // clear old results
+ auto & result_all = ctx->result_all;
+
+ result_all.clear();
+
+ // compute log mel spectrogram
+ if (params.speed_up) {
+ if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
+ logError( u8"%s: failed to compute log mel spectrogram", __func__ );
+ return -1;
+ }
+ } else {
+ if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
+ logError( u8"%s: failed to compute log mel spectrogram", __func__ );
+ return -2;
+ }
+ }
+
+ // auto-detect language if not specified
+ if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
+ std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+
+ const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+ if (lang_id < 0) {
+ logError( u8"%s: failed to auto-detect language", __func__ );
+ return -3;
+ }
+
+ params.language = whisper_lang_str(lang_id);
+
+ logInfo( u8"%s: auto-detected language: %s (p = %f)", __func__, params.language, probs[ whisper_lang_id( params.language ) ] );
+ }
+
+ if (params.token_timestamps) {
+ ctx->t_beg = 0;
+ ctx->t_last = 0;
+ ctx->tid_last = 0;
+ ctx->energy = get_signal_energy(samples, n_samples, 32);
+ }
+
+ const int seek_start = params.offset_ms/10;
+ const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
+
+ // if length of spectrogram is less than 1s (100 samples), then return
+ // basically don't process anything that is less than 1s
+ // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
+ if (seek_end < 100 + seek_start) {
+ return 0;
+ }
+
+ // the accumulated text context so far
+ auto & prompt_past = ctx->prompt_past;
+ if (params.no_context) {
+ prompt_past.clear();
+ }
+
+ // prepend the prompt tokens to the prompt_past
+ if (params.prompt_tokens && params.prompt_n_tokens > 0) {
+ // parse tokens from the pointer
+ for (int i = 0; i < params.prompt_n_tokens; i++) {
+ prompt_past.push_back(params.prompt_tokens[i]);
+ }
+ std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
+ }
+
+ // overwrite audio_ctx
+ ctx->exp_n_audio_ctx = params.audio_ctx;
+
+ // these tokens determine the task that will be performed
+ std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+ if (whisper_is_multilingual(ctx)) {
+ const int lang_id = whisper_lang_id(params.language);
+ prompt_init.push_back(whisper_token_lang(ctx, lang_id));
+ if (params.translate) {
+ prompt_init.push_back(whisper_token_translate());
+ } else {
+ prompt_init.push_back(whisper_token_transcribe());
+ }
+ }
+
+ int progress_prev = 0;
+ int progress_step = 5;
+
+ std::vector<whisper_token_data> tokens_cur;
+ tokens_cur.reserve(whisper_n_text_ctx(ctx));
+
+ std::vector<whisper_token> prompt;
+ prompt.reserve(whisper_n_text_ctx(ctx));
+
+ // main loop
+ int seek = seek_start;
+ while (true) {
+ const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
+ while (progress_cur >= progress_prev + progress_step) {
+ progress_prev += progress_step;
+ if (params.print_progress) {
+ logInfo( u8"%s: progress = %3d%%", __func__, progress_prev );
+ }
+ }
+
+ // of only 1 second left, then stop
+ if (seek + 100 >= seek_end) {
+ break;
+ }
+
+ // if there is a very short audio segment left to process, we remove any past prompt since it tends
+ // to confuse the decoder and often make it repeat or hallucinate stuff
+ if (seek > seek_start && seek + 500 >= seek_end) {
+ prompt_past.clear();
+ }
+
+ if (params.encoder_begin_callback) {
+ if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
+ logDebug( u8"%s: encoder_begin_callback returned false - aborting", __func__ );
+ break;
+ }
+ }
+
+ // encode audio features starting at offset seek
+ if (whisper_encode(ctx, seek, params.n_threads) != 0) {
+ logError( u8"%s: failed to encode", __func__ );
+ return -4;
+ }
+
+ int n_past = 0;
+ prompt.clear();
+
+ // if we have already generated some text, use it as a prompt to condition the next generation
+ if (!prompt_past.empty()) {
+ int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
+
+ prompt = { whisper_token_prev(ctx) };
+ prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+
+ prompt_past.clear();
+ prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
+ }
+
+ prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
+ int seek_delta = 100*WHISPER_CHUNK_SIZE;
+
+ // print the prompt
+ //printf("\n\n");
+ //for (int i = 0; i < prompt.size(); i++) {
+ // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
+ //}
+ //printf("\n\n");
+
+ // the accumulated transcription in the current interation
+ int result_len = 0;
+ tokens_cur.clear();
+
+ bool failed = false;
+ bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
+
+ for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
+ if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
+ logError( u8"%s: failed to decode", __func__ );
+ return -5;
+ }
+
+ n_past += prompt.size();
+ prompt.clear();
+
+ // very basic greedy sampling strategy:
+ //
+ // - always take the most probable token
+ //
+ // more sophisticated sampling strategies could be implemented here, but we keep it simple
+ // feel free to experiment!
+ //
+ {
+ const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);
+
+ // timestamp token - update sliding window
+ if (token.id > whisper_token_beg(ctx)) {
+ const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+
+ // do not allow to go back in time
+ if (has_ts && seek_delta > seek_delta_new && result_len < i) {
+ break;
+ }
+
+ seek_delta = seek_delta_new;
+ result_len = i + 1;
+ has_ts = true;
+ }
+
+ // add it to the context
+ prompt.push_back(token.id);
+ tokens_cur.push_back(token);
+
+ //{
+ // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
+ // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
+ //}
+
+ // end of segment
+ if (token.id == whisper_token_eot(ctx) || // end of text token
+ (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
+ (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
+ ) {
+ if (result_len == 0) {
+ if (seek + seek_delta + 100 >= seek_end) {
+ result_len = i + 1;
+ } else {
+ failed = true;
+ break;
+ }
+ }
+
+ if (params.single_segment) {
+ result_len = i + 1;
+ seek_delta = 100*WHISPER_CHUNK_SIZE;
+ }
+
+ break;
+ }
+
+ // TESTS: if no tensors are loaded, it means we are running tests
+ if (ctx->model.n_loaded == 0) {
+ seek_delta = 100*WHISPER_CHUNK_SIZE;
+ break;
+ }
+ }
+
+ // sometimes, the decoding can get stuck in a repetition loop
+ // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
+ // the sliding window by 1 second
+ if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+ failed = true;
+ break;
+ }
+ }
+
+ if (failed) {
+ // when we fail to sample timestamp token, retry by clearing the past prompt
+ // if it fails again, then we advance the window by 1 second
+ if (!prompt_past.empty()) {
+ prompt_past.clear();
+ } else {
+ logWarning( u8"%s: failed to generate timestamp token - skipping one second", __func__ );
+ seek += 100;
+ }
+ continue;
+ }
+
+ // shrink down to result_len
+ tokens_cur.resize(result_len);
+
+ for (const auto & r : tokens_cur) {
+ prompt_past.push_back(r.id);
+ }
+
+ // store the text from this iteration
+ if (!tokens_cur.empty()) {
+ int i0 = 0;
+ auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
+
+ std::string text;
+
+ for (int i = 0; i < (int) tokens_cur.size(); i++) {
+ //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
+ // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
+ // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
+
+ if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
+ } else {
+ text += whisper_token_to_str(ctx, tokens_cur[i].id);
+ }
+ if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
+ const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
+ if (!text.empty()) {
+ const auto tt0 = params.speed_up ? 2*t0 : t0;
+ const auto tt1 = params.speed_up ? 2*t1 : t1;
+
+ if (params.print_realtime) {
+ if (params.print_timestamps) {
+ printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
+ } else {
+ printf("%s", text.c_str());
+ fflush(stdout);
+ }
+ }
+
+ result_all.push_back({ tt0, tt1, text, {} });
+ for (int j = i0; j <= i; j++) {
+ result_all.back().tokens.push_back(tokens_cur[j]);
+ }
+
+ int n_new = 1;
+
+ if (params.token_timestamps) {
+ whisper_exp_compute_token_level_timestamps(
+ ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+ if (params.max_len > 0) {
+ n_new = whisper_wrap_segment(ctx, params.max_len);
+ }
+ }
+ if (params.new_segment_callback) {
+ params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+ }
+ }
+ text = "";
+ while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
+ i++;
+ }
+ i--;
+ t0 = t1;
+ i0 = i + 1;
+ }
+ }
+
+ if (!text.empty()) {
+ const auto t1 = seek + seek_delta;
+
+ const auto tt0 = params.speed_up ? 2*t0 : t0;
+ const auto tt1 = params.speed_up ? 2*t1 : t1;
+
+ if (params.print_realtime) {
+ if (params.print_timestamps) {
+ printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
+ } else {
+ printf("%s", text.c_str());
+ fflush(stdout);
+ }
+ }
+
+ result_all.push_back({ tt0, tt1, text, {} });
+ for (int j = i0; j < (int) tokens_cur.size(); j++) {
+ result_all.back().tokens.push_back(tokens_cur[j]);
+ }
+
+ int n_new = 1;
+
+ if (params.token_timestamps) {
+ whisper_exp_compute_token_level_timestamps(
+ ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+ if (params.max_len > 0) {
+ n_new = whisper_wrap_segment(ctx, params.max_len);
+ }
+ }
+ if (params.new_segment_callback) {
+ params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+ }
+ }
+ }
+
+ seek += seek_delta;
+ }
+
+ return 0;
+}
+
+int whisper_full_parallel(
+ struct whisper_context * ctx,
+ struct whisper_full_params params,
+ const float * samples,
+ int n_samples,
+ int n_processors) {
+ if (n_processors == 1) {
+ return whisper_full(ctx, params, samples, n_samples);
+ }
+
+ int ret = 0;
+
+ // prepare separate contexts for each thread
+ std::vector<struct whisper_context> ctxs(n_processors - 1);
+
+ for (int i = 0; i < n_processors - 1; ++i) {
+ ctxs[i] = *ctx;
+
+ auto & model = ctxs[i].model;
+
+ // create the ggml memory context
+ {
+ struct ggml_init_params params;
+ params.mem_size = ctxs[i].buf_memory.size();
+ params.mem_buffer = ctxs[i].buf_memory.data();
+
+ model.ctx_mem = ggml_init(params);
+ if (!model.ctx_mem) {
+ logError( u8"%s: ggml_init() failed", __func__ );
+ return false;
+ }
+ }
+
+ // separate key + value memory for each processor
+ {
+ auto & ctx = model.ctx_mem;
+
+ const auto & hparams = model.hparams;
+
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+ const int n_text_ctx = hparams.n_text_ctx;
+
+ // key/value memory for the self-attention layer
+ {
+ const int n_mem = n_text_layer*n_text_ctx;
+ const int n_elements = n_text_state*n_mem;
+
+ model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ }
+
+ // key/value memory for the cross-attention layer
+ {
+ const int n_audio_ctx = hparams.n_audio_ctx;
+
+ const int n_mem = n_text_layer*n_audio_ctx;
+ const int n_elements = n_text_state*n_mem;
+
+ model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ }
+ }
+ }
+
+ const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
+ const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
+
+ // the calling thread will process the first chunk
+ // while the other threads will process the remaining chunks
+
+ std::vector<std::thread> workers(n_processors - 1);
+ for (int i = 0; i < n_processors - 1; ++i) {
+ const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
+ const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
+
+ auto params_cur = params;
+
+ params_cur.offset_ms = 0;
+ params_cur.print_progress = false;
+ params_cur.print_realtime = false;
+
+ params_cur.new_segment_callback = nullptr;
+ params_cur.new_segment_callback_user_data = nullptr;
+
+ workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
+ }
+
+ {
+ auto params_cur = params;
+
+ ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
+ }
+
+ for (int i = 0; i < n_processors - 1; ++i) {
+ workers[i].join();
+ }
+
+ const int64_t offset_t = (int64_t) params.offset_ms/10.0;
+
+ // combine results into ctx->result_all
+ for (int i = 0; i < n_processors - 1; ++i) {
+ auto & results_i = ctxs[i].result_all;
+
+ for (int j = 0; j < (int) results_i.size(); ++j) {
+ // correct the segment timestamp taking into account the offset
+ results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+ results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+
+ // make sure that segments are not overlapping
+ if (!ctx->result_all.empty()) {
+ results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
+ }
+
+ ctx->result_all.push_back(std::move(results_i[j]));
+
+ // call the new_segment_callback for each segment
+ if (params.new_segment_callback) {
+ params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
+ }
+ }
+
+ ctx->t_mel_us += ctxs[i].t_mel_us;
+ ctx->t_sample_us += ctxs[i].t_sample_us;
+ ctx->t_encode_us += ctxs[i].t_encode_us;
+ ctx->t_decode_us += ctxs[i].t_decode_us;
+ }
+
+ // average the timings
+ ctx->t_mel_us /= n_processors;
+ ctx->t_sample_us /= n_processors;
+ ctx->t_encode_us /= n_processors;
+ ctx->t_decode_us /= n_processors;
+
+ // print information about the audio boundaries
+ logDebug( u8"%s: the audio has been split into %d chunks at the following times:", __func__, n_processors );
+ for( int i = 0; i < n_processors - 1; ++i )
+ logDebug( u8"%s: split %d - %s", __func__, ( i + 1 ), to_timestamp( 100 * ( ( i + 1 ) * n_samples_per_processor ) / WHISPER_SAMPLE_RATE + offset_t ).c_str() );
+ logDebug( u8"%s: the transcription quality may be degraded near these boundaries", __func__ );
+
+ return ret;
+}
+
+int whisper_full_n_segments(struct whisper_context * ctx) {
+ return ctx->result_all.size();
+}
+
+int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
+ return ctx->result_all[i_segment].t0;
+}
+
+int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
+ return ctx->result_all[i_segment].t1;
+}
+
+const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
+ return ctx->result_all[i_segment].text.c_str();
+}
+
+int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
+ return ctx->result_all[i_segment].tokens.size();
+}
+
+const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
+ return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
+}
+
+whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
+ return ctx->result_all[i_segment].tokens[i_token].id;
+}
+
+struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
+ return ctx->result_all[i_segment].tokens[i_token];
+}
+
+float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
+ return ctx->result_all[i_segment].tokens[i_token].p;
+}
+
+// =================================================================================================
+
+//
+// Experimental stuff below
+//
+// Not sure if these should be part of the library at all, because the quality of the results is not
+// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
+//
+
+// =================================================================================================
+
+//
+// token-level timestamps
+//
+
+static int timestamp_to_sample(int64_t t, int n_samples) {
+ return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
+}
+
+static int64_t sample_to_timestamp(int i_sample) {
+ return (100*i_sample)/WHISPER_SAMPLE_RATE;
+}
+
+// a cost-function / heuristic that is high for text that takes longer to pronounce
+// obviously, can be improved
+static float voice_length(const std::string & text) {
+ float res = 0.0f;
+
+ for (size_t i = 0; i < text.size(); ++i) {
+ if (text[i] == ' ') {
+ res += 0.01f;
+ } else if (text[i] == ',') {
+ res += 2.00f;
+ } else if (text[i] == '.') {
+ res += 3.00f;
+ } else if (text[i] == '!') {
+ res += 3.00f;
+ } else if (text[i] == '?') {
+ res += 3.00f;
+ } else if (text[i] >= '0' && text[i] <= '9') {
+ res += 3.00f;
+ } else {
+ res += 1.00f;
+ }
+ }
+
+ return res;
+}
+
+// average the fabs of the signal
+static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
+ const int hw = n_samples_per_half_window;
+
+ std::vector<float> result(n_samples);
+
+ for (int i = 0; i < n_samples; i++) {
+ float sum = 0;
+ for (int j = -hw; j <= hw; j++) {
+ if (i + j >= 0 && i + j < n_samples) {
+ sum += fabs(signal[i + j]);
+ }
+ }
+ result[i] = sum/(2*hw + 1);
+ }
+
+ return result;
+}
+
+static void whisper_exp_compute_token_level_timestamps(
+ struct whisper_context * ctx,
+ int i_segment,
+ float thold_pt,
+ float thold_ptsum) {
+ auto & segment = ctx->result_all[i_segment];
+ auto & tokens = segment.tokens;
+
+ const int n_samples = ctx->energy.size();
+
+ if (n_samples == 0) {
+ logWarning( u8"%s: no signal data available", __func__ );
+ return;
+ }
+
+ const int64_t t0 = segment.t0;
+ const int64_t t1 = segment.t1;
+
+ const int n = tokens.size();
+
+ if (n == 0) {
+ return;
+ }
+
+ if (n == 1) {
+ tokens[0].t0 = t0;
+ tokens[0].t1 = t1;
+
+ return;
+ }
+
+ auto & t_beg = ctx->t_beg;
+ auto & t_last = ctx->t_last;
+ auto & tid_last = ctx->tid_last;
+
+ for (int j = 0; j < n; ++j) {
+ auto & token = tokens[j];
+
+ if (j == 0) {
+ if (token.id == whisper_token_beg(ctx)) {
+ tokens[j ].t0 = t0;
+ tokens[j ].t1 = t0;
+ tokens[j + 1].t0 = t0;
+
+ t_beg = t0;
+ t_last = t0;
+ tid_last = whisper_token_beg(ctx);
+ } else {
+ tokens[j ].t0 = t_last;
+ }
+ }
+
+ const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
+
+ tokens[j].id = token.id;
+ tokens[j].tid = token.tid;
+ tokens[j].p = token.p;
+ tokens[j].pt = token.pt;
+ tokens[j].ptsum = token.ptsum;
+
+ tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
+
+ if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
+ if (j > 0) {
+ tokens[j - 1].t1 = tt;
+ }
+ tokens[j].t0 = tt;
+ tid_last = token.tid;
+ }
+ }
+
+ tokens[n - 2].t1 = t1;
+ tokens[n - 1].t0 = t1;
+ tokens[n - 1].t1 = t1;
+
+ t_last = t1;
+
+ // find intervals of tokens with unknown timestamps
+ // fill the timestamps by proportionally splitting the interval based on the token voice lengths
+ {
+ int p0 = 0;
+ int p1 = 0;
+
+ while (true) {
+ while (p1 < n && tokens[p1].t1 < 0) {
+ p1++;
+ }
+
+ if (p1 >= n) {
+ p1--;
+ }
+
+ if (p1 > p0) {
+ double psum = 0.0;
+ for (int j = p0; j <= p1; j++) {
+ psum += tokens[j].vlen;
+ }
+
+ //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
+
+ const double dt = tokens[p1].t1 - tokens[p0].t0;
+
+ // split the time proportionally to the voice length
+ for (int j = p0 + 1; j <= p1; j++) {
+ const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
+
+ tokens[j - 1].t1 = ct;
+ tokens[j ].t0 = ct;
+ }
+ }
+
+ p1++;
+ p0 = p1;
+ if (p1 >= n) {
+ break;
+ }
+ }
+ }
+
+ // fix up (just in case)
+ for (int j = 0; j < n - 1; j++) {
+ if (tokens[j].t1 < 0) {
+ tokens[j + 1].t0 = tokens[j].t1;
+ }
+
+ if (j > 0) {
+ if (tokens[j - 1].t1 > tokens[j].t0) {
+ tokens[j].t0 = tokens[j - 1].t1;
+ tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
+ }
+ }
+ }
+
+ // VAD
+ // expand or contract tokens based on voice activity
+ {
+ const int hw = WHISPER_SAMPLE_RATE/8;
+
+ for (int j = 0; j < n; j++) {
+ if (tokens[j].id >= whisper_token_eot(ctx)) {
+ continue;
+ }
+
+ int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
+ int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
+
+ const int ss0 = std::max(s0 - hw, 0);
+ const int ss1 = std::min(s1 + hw, n_samples);
+
+ const int ns = ss1 - ss0;
+
+ float sum = 0.0f;
+
+ for (int k = ss0; k < ss1; k++) {
+ sum += ctx->energy[k];
+ }
+
+ const float thold = 0.5*sum/ns;
+
+ {
+ int k = s0;
+ if (ctx->energy[k] > thold && j > 0) {
+ while (k > 0 && ctx->energy[k] > thold) {
+ k--;
+ }
+ tokens[j].t0 = sample_to_timestamp(k);
+ if (tokens[j].t0 < tokens[j - 1].t1) {
+ tokens[j].t0 = tokens[j - 1].t1;
+ } else {
+ s0 = k;
+ }
+ } else {
+ while (ctx->energy[k] < thold && k < s1) {
+ k++;
+ }
+ s0 = k;
+ tokens[j].t0 = sample_to_timestamp(k);
+ }
+ }
+
+ {
+ int k = s1;
+ if (ctx->energy[k] > thold) {
+ while (k < n_samples - 1 && ctx->energy[k] > thold) {
+ k++;
+ }
+ tokens[j].t1 = sample_to_timestamp(k);
+ if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
+ tokens[j].t1 = tokens[j + 1].t0;
+ } else {
+ s1 = k;
+ }
+ } else {
+ while (ctx->energy[k] < thold && k > s0) {
+ k--;
+ }
+ s1 = k;
+ tokens[j].t1 = sample_to_timestamp(k);
+ }
+ }
+ }
+ }
+
+ // fixed token expand (optional)
+ //{
+ // const int t_expand = 0;
+
+ // for (int j = 0; j < n; j++) {
+ // if (j > 0) {
+ // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
+ // }
+ // if (j < n - 1) {
+ // tokens[j].t1 = tokens[j].t1 + t_expand;
+ // }
+ // }
+ //}
+
+ // debug info
+ //for (int j = 0; j < n; ++j) {
+ // const auto & token = tokens[j];
+ // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
+ // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
+ // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
+
+ // if (tokens[j].id >= whisper_token_eot(ctx)) {
+ // continue;
+ // }
+ //}
+}
diff --git a/Whisper/source/whisper.h b/Whisper/source/whisper.h
new file mode 100644
index 0000000..92c14da
--- /dev/null
+++ b/Whisper/source/whisper.h
@@ -0,0 +1,330 @@
+#ifndef WHISPER_H
+#define WHISPER_H
+
+#include <stdint.h>
+#include <stdbool.h>
+
+#ifdef WHISPER_SHARED
+# ifdef _WIN32
+# ifdef WHISPER_BUILD
+# define WHISPER_API __declspec(dllexport)
+# else
+# define WHISPER_API __declspec(dllimport)
+# endif
+# else
+# define WHISPER_API __attribute__ ((visibility ("default")))
+# endif
+#else
+# define WHISPER_API
+#endif
+
+#define WHISPER_SAMPLE_RATE 16000
+#define WHISPER_N_FFT 400
+#define WHISPER_N_MEL 80
+#define WHISPER_HOP_LENGTH 160
+#define WHISPER_CHUNK_SIZE 30
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ //
+ // C interface
+ //
+ // The following interface is thread-safe as long as the sample whisper_context is not used by multiple threads
+ // concurrently.
+ //
+ // Basic usage:
+ //
+ // #include "whisper.h"
+ //
+ // ...
+ //
+ // struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin");
+ //
+ // if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+ // fprintf(stderr, "failed to process audio\n");
+ // return 7;
+ // }
+ //
+ // const int n_segments = whisper_full_n_segments(ctx);
+ // for (int i = 0; i < n_segments; ++i) {
+ // const char * text = whisper_full_get_segment_text(ctx, i);
+ // printf("%s", text);
+ // }
+ //
+ // whisper_free(ctx);
+ //
+ // ...
+ //
+ // This is a demonstration of the most straightforward usage of the library.
+ // "pcmf32" contains the RAW audio data in 32-bit floating point format.
+ //
+ // The interface also allows for more fine-grained control over the computation, but it requires a deeper
+ // understanding of how the model works.
+ //
+
+ struct whisper_context;
+
+ typedef int whisper_token;
+
+ typedef struct whisper_token_data {
+ whisper_token id; // token id
+ whisper_token tid; // forced timestamp token id
+
+ float p; // probability of the token
+ float pt; // probability of the timestamp token
+ float ptsum; // sum of probabilities of all timestamp tokens
+
+ // token-level timestamp data
+ // do not use if you haven't computed token-level timestamps
+ int64_t t0; // start time of the token
+ int64_t t1; // end time of the token
+
+ float vlen; // voice length of the token
+ } whisper_token_data;
+
+ // Allocates all memory needed for the model and loads the model from the given file.
+ // Returns NULL on failure.
+ WHISPER_API struct whisper_context * whisper_init(const char * path_model);
+
+ // Frees all memory allocated by the model.
+ WHISPER_API void whisper_free(struct whisper_context * ctx);
+
+ // Convert RAW PCM audio to log mel spectrogram.
+ // The resulting spectrogram is stored inside the provided whisper context.
+ // Returns 0 on success
+ WHISPER_API int whisper_pcm_to_mel(
+ struct whisper_context * ctx,
+ const float * samples,
+ int n_samples,
+ int n_threads);
+
+ // This can be used to set a custom log mel spectrogram inside the provided whisper context.
+ // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
+ // n_mel must be 80
+ // Returns 0 on success
+ WHISPER_API int whisper_set_mel(
+ struct whisper_context * ctx,
+ const float * data,
+ int n_len,
+ int n_mel);
+
+ // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
+ // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
+ // offset can be used to specify the offset of the first frame in the spectrogram.
+ // Returns 0 on success
+ WHISPER_API int whisper_encode(
+ struct whisper_context * ctx,
+ int offset,
+ int n_threads);
+
+ // Run the Whisper decoder to obtain the logits and probabilities for the next token.
+ // Make sure to call whisper_encode() first.
+ // tokens + n_tokens is the provided context for the decoder.
+ // n_past is the number of tokens to use from previous decoder calls.
+ // Returns 0 on success
+ WHISPER_API int whisper_decode(
+ struct whisper_context * ctx,
+ const whisper_token * tokens,
+ int n_tokens,
+ int n_past,
+ int n_threads);
+
+ // Token sampling methods.
+ // These are provided for convenience and can be used after each call to whisper_decode().
+ // You can also implement your own sampling method using the whisper_get_probs() function.
+ // whisper_sample_best() returns the token with the highest probability
+ // whisper_sample_timestamp() returns the most probable timestamp token
+ WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
+ WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
+
+ // Convert the provided text into tokens.
+ // The tokens pointer must be large enough to hold the resulting tokens.
+ // Returns the number of tokens on success, no more than n_max_tokens
+ // Returns -1 on failure
+ // TODO: not sure if correct
+ WHISPER_API int whisper_tokenize(
+ struct whisper_context * ctx,
+ const char * text,
+ whisper_token * tokens,
+ int n_max_tokens);
+
+ // Largest language id (i.e. number of available languages - 1)
+ WHISPER_API int whisper_lang_max_id();
+
+ // Return the id of the specified language, returns -1 if not found
+ // Examples:
+ // "de" -> 2
+ // "german" -> 2
+ WHISPER_API int whisper_lang_id(const char * lang);
+
+ // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
+ WHISPER_API const char * whisper_lang_str(int id);
+
+ // Use mel data at offset_ms to try and auto-detect the spoken language
+ // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
+ // Returns the top language id or negative on failure
+ // If not null, fills the lang_probs array with the probabilities of all languages
+ // The array must be whispe_lang_max_id() + 1 in size
+ // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
+ WHISPER_API int whisper_lang_auto_detect(
+ struct whisper_context * ctx,
+ int offset_ms,
+ int n_threads,
+ float * lang_probs);
+
+ WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
+ WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
+ WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
+ WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
+
+ // The probabilities for the next token
+ WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
+
+ // Token Id -> String. Uses the vocabulary in the provided context
+ WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
+
+ // Special tokens
+ WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
+
+ // Task tokens
+ WHISPER_API whisper_token whisper_token_translate (void);
+ WHISPER_API whisper_token whisper_token_transcribe(void);
+
+ // Performance information
+ WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
+ WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
+
+ // Print system information
+ WHISPER_API const char * whisper_print_system_info(void);
+
+ ////////////////////////////////////////////////////////////////////////////
+
+ // Available sampling strategies
+ enum whisper_sampling_strategy {
+ WHISPER_SAMPLING_GREEDY, // Always select the most probable token
+ WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
+ };
+
+ // Text segment callback
+ // Called on every newly generated text segment
+ // Use the whisper_full_...() functions to obtain the text segments
+ typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
+
+ // Encoder begin callback
+ // If not NULL, called before the encoder starts
+ // If it returns false, the computation is aborted
+ typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
+
+ // Parameters for the whisper_full() function
+ // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
+ // whisper_full_default_params()
+ struct whisper_full_params {
+ enum whisper_sampling_strategy strategy;
+
+ int n_threads;
+ int n_max_text_ctx;
+ int offset_ms; // start offset in ms
+ int duration_ms; // audio duration to process in ms
+
+ bool translate;
+ bool no_context;
+ bool single_segment; // force single segment output (useful for streaming)
+ bool print_special;
+ bool print_progress;
+ bool print_realtime;
+ bool print_timestamps;
+
+ // [EXPERIMENTAL] token-level timestamps
+ bool token_timestamps; // enable token-level timestamps
+ float thold_pt; // timestamp token probability threshold (~0.01)
+ float thold_ptsum; // timestamp token sum probability threshold (~0.01)
+ int max_len; // max segment length in characters
+ int max_tokens; // max tokens per segment (0 = no limit)
+
+ // [EXPERIMENTAL] speed-up techniques
+ bool speed_up; // speed-up the audio by 2x using Phase Vocoder
+ int audio_ctx; // overwrite the audio context size (0 = use default)
+
+ // tokens to provide the whisper model as initial prompt
+ // these are prepended to any existing text context from a previous call
+ const whisper_token * prompt_tokens;
+ int prompt_n_tokens;
+
+ // for auto-detection, set to nullptr, "" or "auto"
+ const char * language;
+
+ struct {
+ int n_past;
+ } greedy;
+
+ struct {
+ int n_past;
+ int beam_width;
+ int n_best;
+ } beam_search;
+
+ whisper_new_segment_callback new_segment_callback;
+ void * new_segment_callback_user_data;
+
+ whisper_encoder_begin_callback encoder_begin_callback;
+ void * encoder_begin_callback_user_data;
+ };
+
+ WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
+
+ // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+ // Uses the specified decoding strategy to obtain the text.
+ WHISPER_API int whisper_full(
+ struct whisper_context * ctx,
+ struct whisper_full_params params,
+ const float * samples,
+ int n_samples);
+
+ // Split the input audio in chunks and process each chunk separately using whisper_full()
+ // It seems this approach can offer some speedup in some cases.
+ // However, the transcription accuracy can be worse at the beginning and end of each chunk.
+ WHISPER_API int whisper_full_parallel(
+ struct whisper_context * ctx,
+ struct whisper_full_params params,
+ const float * samples,
+ int n_samples,
+ int n_processors);
+
+ // Number of generated text segments.
+ // A segment can be a few words, a sentence, or even a paragraph.
+ WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx);
+
+ // Get the start and end time of the specified segment.
+ WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment);
+ WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment);
+
+ // Get the text of the specified segment.
+ WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment);
+
+ // Get number of tokens in the specified segment.
+ WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment);
+
+ // Get the token text of the specified token in the specified segment.
+ WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
+ WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
+
+ // Get token data for the specified token in the specified segment.
+ // This contains probabilities, timestamps, etc.
+ WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
+
+ // Get the probability of the specified token in the specified segment.
+ WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/Whisper/stdafx.cpp b/Whisper/stdafx.cpp
new file mode 100644
index 0000000..1577c4e
--- /dev/null
+++ b/Whisper/stdafx.cpp
@@ -0,0 +1 @@
+#include "stdafx.h" \ No newline at end of file
diff --git a/Whisper/stdafx.h b/Whisper/stdafx.h
new file mode 100644
index 0000000..c84e10d
--- /dev/null
+++ b/Whisper/stdafx.h
@@ -0,0 +1,43 @@
+#pragma once
+#define _USE_MATH_DEFINES
+#include <stdint.h>
+#include <assert.h>
+#include <array>
+#include <vector>
+#include <algorithm>
+#include <emmintrin.h> // SSE 2
+#include <smmintrin.h> // SSE 4.1
+
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+// Setup Windows SDK to only enable features available since Windows 8.0
+#include <WinSDKVer.h>
+#define _WIN32_WINNT _WIN32_WINNT_WIN8
+#define NTDDI_VERSION NTDDI_WIN8
+#include <sdkddkver.h>
+
+#include <windows.h>
+#include <d3d11.h>
+#include <atlcomcli.h>
+#include "Utils/Logger.h"
+#include "Utils/miscUtils.h"
+
+// Build both legacy and DirectCompute implementations
+#define BUILD_BOTH_VERSIONS 0
+
+// Build hybrid model which uses DirectCompute only for the encode step of the algorithm, and decodes on CPU, using AVX SIMD and the Windows' built-in thread pool.
+#define BUILD_HYBRID_VERSION 0
+
+// Enable debug traces. Should be disabled in production, the feature comes with a huge performance overhead.
+// When enabled, while computing things it streams gigabytes of data into that binary file.
+// See Tools / compareTraces project for a command-line app to compare these traces.
+#define SAVE_DEBUG_TRACE 0
+
+// In addition to collecting total GPU times per compute shader, also collect and print performance data about individual invocations of some of the most expensive shaders
+// The feature is relatively cheap in terms of performance overhead, but pretty much useless in production, and clutters debug console with all these numbers
+#define PROFILER_COLLECT_TAGS 0
+
+// Reshape some of the tensors to a better VRAM layout while loading a model
+// So far, the feature is only used on AMD GPUs. On AMD Vega integrated GPUs it helps by up to 30%.
+// Should be enabled in production build
+#define RESHAPED_MATRIX_MULTIPLY 1 \ No newline at end of file
diff --git a/Whisper/whisper.def b/Whisper/whisper.def
new file mode 100644
index 0000000..69ada14
--- /dev/null
+++ b/Whisper/whisper.def
@@ -0,0 +1,7 @@
+LIBRARY
+EXPORTS setupLogger
+EXPORTS loadModel
+EXPORTS initMediaFoundation
+EXPORTS findLanguageKeyW
+EXPORTS findLanguageKeyA
+EXPORTS getSupportedLanguages \ No newline at end of file
diff --git a/Whisper/whisperCom.cpp b/Whisper/whisperCom.cpp
new file mode 100644
index 0000000..a0205ec
--- /dev/null
+++ b/Whisper/whisperCom.cpp
@@ -0,0 +1,1070 @@
+#include "stdafx.h"
+#include "ML/Tensor.h"
+#include "API/iMediaFoundation.cl.h"
+#include "API/iContext.cl.h"
+#include "API/sFullParams.h"
+#include "Utils/ReadStream.h"
+#include "ML/testUtils.h"
+#include "Utils/Trace/tracing.h"
+#include "modelFactory.h"
+#if BUILD_BOTH_VERSIONS
+
+namespace
+{
+ LPCTSTR traceFilePath = LR"(C:\Temp\2remove\Whisper\ref.bin)";
+ using ComLight::iReadStream;
+}
+
+struct whisper_context;
+struct ggml_tensor;
+
+class GpuEncTest
+{
+ DirectCompute::Tensor mel, gpuResult;
+
+ DirectCompute::Tensor tempGpu;
+ const ggml_tensor* tempRef = nullptr;
+public:
+ GpuEncTest( const whisper_context& wctx, const int mel_offset );
+ void compare( const ggml_tensor* expected ) const;
+ void compareMel( const ggml_tensor* expected ) const;
+};
+
+class GpuDecTest
+{
+ std::vector<float> logits, probs;
+ const ggml_tensor* tempRef = nullptr;
+
+public:
+
+ GpuDecTest( const whisper_context& wctx, const int* tokens, const int n_tokens, const int n_past );
+
+ void postpone( const ggml_tensor* t );
+ void comparePostponed();
+ void compare( const std::vector<float>& cpuLogits, const std::vector<float>& cpuProbs ) const;
+};
+
+static DirectCompute::Tensor gpuEncode( const whisper_context& wctx, const int mel_offset );
+
+#include "source/whisper.cpp"
+#include "API/iContext.cl.h"
+#include "../ComLightLib/comLightServer.h"
+#include "ML/mlStartup.h"
+#include "Whisper/WhisperContext.h"
+#include "Whisper/ModelLoader.h"
+#include "Whisper/WhisperModel.h"
+#include "source.compat/convertThings.h"
+
+namespace Whisper
+{
+ inline HRESULT isZero( int i )
+ {
+ return ( 0 == i ) ? S_OK : E_FAIL;
+ }
+
+ class Context : public ComLight::ObjectRoot<iContext>,
+ public iModel
+ {
+ virtual HRESULT COMLIGHTCALL isMultilingual() override final
+ {
+ return whisper_is_multilingual( &ctx ) ? S_OK : S_FALSE;
+ }
+ virtual const char* COMLIGHTCALL stringFromToken( whisper_token token ) override final
+ {
+ return whisper_token_to_str( &ctx, token );
+ }
+ virtual HRESULT COMLIGHTCALL getSpecialTokens( SpecialTokens& rdi )
+ {
+ rdi.TranscriptionEnd = whisper_token_eot( &ctx );
+ rdi.TranscriptionStart = whisper_token_sot( &ctx );
+ rdi.PreviousWord = whisper_token_prev( &ctx );
+ rdi.SentenceStart = whisper_token_solm( &ctx );
+ rdi.Not = whisper_token_not( &ctx );
+ rdi.TranscriptionBegin = whisper_token_beg( &ctx );
+ rdi.TaskTranslate = whisper_token_translate();
+ rdi.TaskTranscribe = whisper_token_transcribe();
+ return S_OK;
+ }
+
+ // Performance information
+ virtual HRESULT COMLIGHTCALL timingsPrint() override final
+ {
+ whisper_print_timings( &ctx );
+ return S_OK;
+ }
+ virtual HRESULT COMLIGHTCALL timingsReset() override final
+ {
+ whisper_reset_timings( &ctx );
+ return S_OK;
+ }
+
+ virtual HRESULT COMLIGHTCALL fullDefaultParams( eSamplingStrategy strategy, sFullParams* rdi )
+ {
+ static_assert( (int)eSamplingStrategy::Greedy == whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY );
+ static_assert( (int)eSamplingStrategy::BeamSearch == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH );
+ const whisper_sampling_strategy wss = (whisper_sampling_strategy)(int)strategy;
+ whisper_full_params wfp = whisper_full_default_params( wss );
+
+ *rdi = makeNewParams( wfp );
+ return S_OK;
+ }
+
+ HRESULT COMLIGHTCALL runFull( const sFullParams& params, const iAudioBuffer* buffer ) override final
+ {
+ whisper_full_params wfp = makeOldParams( params, this );
+ const float* const samples = buffer->getPcmMono();
+ const uint32_t n_samples = buffer->countSamples();
+ return isZero( whisper_full( &ctx, wfp, samples, (int)n_samples ) );
+ }
+
+ HRESULT COMLIGHTCALL runStreamed( const sFullParams& params, const sProgressSink& progress, const iAudioReader* reader ) override final
+ {
+ logError( u8"The CPU reference implementation doesn’t support streaming" );
+ return E_NOTIMPL;
+ }
+ HRESULT COMLIGHTCALL runCapture( const sFullParams& params, const sCaptureCallbacks& callbacks, const iAudioCapture* reader ) override final
+ {
+ logError( u8"The CPU reference implementation doesn’t support audio capture" );
+ return E_NOTIMPL;
+ }
+
+ HRESULT COMLIGHTCALL getResults( eResultFlags flags, iTranscribeResult** pp ) const override final
+ {
+ makeNewResults( &ctx, flags, pp );
+ return S_OK;
+ }
+
+ HRESULT loadImpl( iReadStream* stm );
+
+ virtual HRESULT COMLIGHTCALL createContext( iContext** pp ) override final
+ {
+ if( nullptr == pp )
+ return E_POINTER;
+ *pp = this;
+ ( *pp )->AddRef();
+ return S_OK;
+ }
+
+ virtual HRESULT COMLIGHTCALL getModel( iModel** pp ) override final
+ {
+ if( nullptr == pp )
+ return E_POINTER;
+ *pp = this;
+ ( *pp )->AddRef();
+ return S_OK;
+ }
+
+ public:
+
+ Context()
+ {
+ if( nullptr != traceFilePath )
+ Tracing::traceCreate( traceFilePath );
+ }
+
+ mutable whisper_context ctx;
+
+ HRESULT load( iReadStream* stm );
+
+ ~Context()
+ {
+ Tracing::traceClose();
+
+ if( ctx.model.ctx )
+ {
+ ggml_free( ctx.model.ctx );
+ ctx.model.ctx = nullptr;
+ }
+ if( ctx.model.ctx_mem )
+ {
+ ggml_free( ctx.model.ctx_mem );
+ ctx.model.ctx_mem = nullptr;
+ }
+ if( ctx.buf_model )
+ {
+ delete ctx.buf_model;
+ ctx.buf_model = nullptr;
+ }
+ }
+
+ BEGIN_COM_MAP()
+ COM_INTERFACE_ENTRY( iModel );
+ END_COM_MAP()
+ };
+
+ inline HRESULT readBytes( iReadStream* stm, void* rdi, size_t cb )
+ {
+ if( cb > INT_MAX )
+ return DISP_E_OVERFLOW;
+ if( cb == 0 )
+ return S_FALSE;
+ int n;
+ CHECK( stm->read( rdi, (int)cb, n ) );
+ if( n != (int)cb )
+ return E_EOF;
+ return S_OK;
+ }
+
+ template<typename T>
+ inline HRESULT readStruct( iReadStream* stm, T& dest )
+ {
+ return readBytes( stm, &dest, sizeof( T ) );
+ }
+ template<typename E>
+ inline HRESULT readVector( iReadStream* stm, std::vector<E>& vec )
+ {
+ const size_t cb = sizeof( E ) * vec.size();
+ if( cb > 0 )
+ return readBytes( stm, vec.data(), cb );
+ return S_FALSE;
+ }
+
+ inline HRESULT readString( iReadStream* stm, std::string& str )
+ {
+ uint32_t len;
+ CHECK( readStruct( stm, len ) );
+ if( len > 0 )
+ {
+ str.resize( len );
+ return readBytes( stm, str.data(), len );
+ }
+ else
+ {
+ str.clear();
+ return S_FALSE;
+ }
+ }
+
+ // load the model from a ggml file
+ // file format:
+ // - hparams
+ // - pre-computed mel filters
+ // - vocab
+ // - weights
+ // see the convert-pt-to-ggml.py script for details
+ HRESULT Context::loadImpl( iReadStream* stm )
+ {
+ // WhisperModel wm;
+ // return wm.load( stm );
+
+ // Copy-pasted from whisper_model_load() function
+ auto& model = ctx.model;
+ auto& vocab = ctx.vocab;
+
+ // verify magic
+ {
+ uint32_t magic;
+ int cbRead;
+ CHECK( stm->read( &magic, 4, cbRead ) );
+ if( magic != 0x67676d6c )
+ {
+ logError( u8"Invalid model file, bad magic" );
+ return E_INVALIDARG;
+ }
+ }
+
+ //load hparams
+ {
+ auto& hparams = model.hparams;
+ CHECK( readStruct( stm, hparams ) );
+ assert( hparams.n_text_state == hparams.n_audio_state );
+
+ if( hparams.n_audio_layer == 4 )
+ model.type = e_model::MODEL_TINY;
+ if( hparams.n_audio_layer == 6 )
+ model.type = e_model::MODEL_BASE;
+ if( hparams.n_audio_layer == 12 )
+ model.type = e_model::MODEL_SMALL;
+ if( hparams.n_audio_layer == 24 )
+ model.type = e_model::MODEL_MEDIUM;
+ if( hparams.n_audio_layer == 32 )
+ model.type = e_model::MODEL_LARGE;
+
+ logDebug( u8"%s: n_vocab = %d", __func__, hparams.n_vocab );
+ logDebug( u8"%s: n_audio_ctx = %d", __func__, hparams.n_audio_ctx );
+ logDebug( u8"%s: n_audio_state = %d", __func__, hparams.n_audio_state );
+ logDebug( u8"%s: n_audio_head = %d", __func__, hparams.n_audio_head );
+ logDebug( u8"%s: n_audio_layer = %d", __func__, hparams.n_audio_layer );
+ logDebug( u8"%s: n_text_ctx = %d", __func__, hparams.n_text_ctx );
+ logDebug( u8"%s: n_text_state = %d", __func__, hparams.n_text_state );
+ logDebug( u8"%s: n_text_head = %d", __func__, hparams.n_text_head );
+ logDebug( u8"%s: n_text_layer = %d", __func__, hparams.n_text_layer );
+ logDebug( u8"%s: n_mels = %d", __func__, hparams.n_mels );
+ logDebug( u8"%s: f16 = %d", __func__, hparams.f16 );
+ logDebug( u8"%s: type = %d", __func__, model.type );
+
+ ctx.buf_model = new std::vector<uint8_t>();
+ ctx.buf_model->resize( MEM_REQ_MODEL.at( model.type ) );
+ ctx.buf_memory.resize( MEM_REQ_MEMORY.at( model.type ) );
+ ctx.buf_compute.resize( std::max( MEM_REQ_ENCODE.at( model.type ), MEM_REQ_DECODE.at( model.type ) ) );
+ ctx.buf_compute_layer.resize( std::max( MEM_REQ_ENCODE_LAYER.at( model.type ), MEM_REQ_DECODE_LAYER.at( model.type ) ) );
+ }
+
+ // load mel filters
+ {
+ auto& filters = ctx.model.filters;
+ CHECK( readStruct( stm, filters.n_mel ) );
+ CHECK( readStruct( stm, filters.n_fft ) );
+ filters.data.resize( filters.n_mel * filters.n_fft );
+ CHECK( readVector( stm, filters.data ) );
+ }
+
+ // load vocab
+ {
+ int32_t n_vocab = 0;
+ CHECK( readStruct( stm, n_vocab ) );
+
+ //if (n_vocab != model.hparams.n_vocab) {
+ // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+ // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+ // return false;
+ //}
+
+ std::string word;
+ for( int i = 0; i < n_vocab; i++ )
+ {
+ CHECK( readString( stm, word ) );
+ vocab.token_to_id[ word ] = i;
+ vocab.id_to_token[ i ] = word;
+ }
+
+ vocab.n_vocab = model.hparams.n_vocab;
+ if( vocab.is_multilingual() )
+ {
+ vocab.token_eot++;
+ vocab.token_sot++;
+ vocab.token_prev++;
+ vocab.token_solm++;
+ vocab.token_not++;
+ vocab.token_beg++;
+ }
+
+ if( n_vocab < model.hparams.n_vocab )
+ {
+ logDebug( u8"%s: adding %d extra tokens", __func__, model.hparams.n_vocab - n_vocab );
+ for( int i = n_vocab; i < model.hparams.n_vocab; i++ )
+ {
+ if( i > vocab.token_beg )
+ word = "[_TT_" + std::to_string( i - vocab.token_beg ) + "]";
+ else if( i == vocab.token_eot )
+ word = "[_EOT_]";
+ else if( i == vocab.token_sot )
+ word = "[_SOT_]";
+ else if( i == vocab.token_prev )
+ word = "[_PREV_]";
+ else if( i == vocab.token_not )
+ word = "[_NOT_]";
+ else if( i == vocab.token_beg )
+ word = "[_BEG_]";
+ else
+ word = "[_extra_token_" + std::to_string( i ) + "]";
+
+ vocab.token_to_id[ word ] = i;
+ vocab.id_to_token[ i ] = word;
+ }
+ }
+ }
+
+ {
+ // this is the total memory required to run the inference
+ const size_t mem_required =
+ ctx.buf_model->size() +
+ ctx.buf_memory.size() +
+ ctx.buf_compute.size() +
+ ctx.buf_compute_layer.size();
+ logDebug( u8"%s: mem_required = %7.2f MB", __func__, mem_required / 1024.0 / 1024.0 );
+ }
+
+ // for the big tensors, we have the option to store the data in 16-bit floats
+ // in order to save memory and also to speed up the computation
+ const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+ size_t ctx_size = 0;
+ size_t ctx_mem_size = 0;
+
+ {
+ const auto& hparams = model.hparams;
+
+ const int n_vocab = hparams.n_vocab;
+
+ const int n_audio_ctx = hparams.n_audio_ctx;
+ const int n_audio_state = hparams.n_audio_state;
+ const int n_audio_layer = hparams.n_audio_layer;
+
+ const int n_text_ctx = hparams.n_text_ctx;
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+
+ const int n_mels = hparams.n_mels;
+
+ // encoder
+ {
+ // TODO: F16 .. maybe not?
+ ctx_size += n_audio_ctx * n_audio_state * ggml_type_size( GGML_TYPE_F32 ); // e_pe;
+
+ ctx_size += 3 * n_mels * n_audio_state * ggml_type_size( wtype ); // e_conv_1_w
+ ctx_size += n_audio_state * ggml_type_size( GGML_TYPE_F32 ); // e_conv_1_b
+
+ ctx_size += 3 * n_audio_state * n_audio_state * ggml_type_size( wtype ); // e_conv_2_w
+ ctx_size += n_audio_state * ggml_type_size( GGML_TYPE_F32 ); // e_conv_2_b
+
+ ctx_size += n_audio_state * ggml_type_size( GGML_TYPE_F32 ); // e_ln_w;
+ ctx_size += n_audio_state * ggml_type_size( GGML_TYPE_F32 ); // e_ln_b;
+ }
+
+ // decoder
+ {
+ // TODO: F16 .. maybe not?
+ ctx_size += n_text_ctx * n_text_state * ggml_type_size( GGML_TYPE_F32 ); // d_pe;
+
+ ctx_size += n_vocab * n_text_state * ggml_type_size( wtype ); // d_te;
+
+ ctx_size += n_text_state * ggml_type_size( GGML_TYPE_F32 ); // d_ln_w;
+ ctx_size += n_text_state * ggml_type_size( GGML_TYPE_F32 ); // d_ln_b;
+ }
+
+ // encoder layers
+ {
+ ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_ln_w
+ ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_ln_b
+
+ ctx_size += n_audio_layer * ( 4 * n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // mlp_0_w
+ ctx_size += n_audio_layer * ( 4 * n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_0_b
+
+ ctx_size += n_audio_layer * ( 4 * n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // mlp_1_w
+ ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_1_b
+
+ ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_0_w
+ ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_0_b
+
+ ctx_size += n_audio_layer * ( n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // attn_q_w
+ ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_q_b
+
+ ctx_size += n_audio_layer * ( n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // attn_k_w
+
+ ctx_size += n_audio_layer * ( n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // attn_v_w
+ ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_v_b
+
+ ctx_size += n_audio_layer * ( n_audio_state * n_audio_state * ggml_type_size( wtype ) ); // attn_ln_1_w
+ ctx_size += n_audio_layer * ( n_audio_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_1_b
+ }
+
+ // decoder layers
+ {
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_ln_w
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_ln_b
+
+ ctx_size += n_text_layer * ( 4 * n_text_state * n_text_state * ggml_type_size( wtype ) ); // mlp_0_w
+ ctx_size += n_text_layer * ( 4 * n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_0_b
+
+ ctx_size += n_text_layer * ( 4 * n_text_state * n_text_state * ggml_type_size( wtype ) ); // mlp_1_w
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // mlp_1_b
+
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_0_w
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_0_b
+
+ ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // attn_q_w
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_q_b
+
+ ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // attn_k_w
+
+ ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // attn_v_w
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_v_b
+
+ ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // attn_ln_1_w
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // attn_ln_1_b
+ //
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // cross_attn_ln_0_w
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // cross_attn_ln_0_b
+
+ ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // cross_attn_q_w
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // cross_attn_q_b
+
+ ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // cross_attn_k_w
+
+ ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // cross_attn_v_w
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // cross_attn_v_b
+
+ ctx_size += n_text_layer * ( n_text_state * n_text_state * ggml_type_size( wtype ) ); // cross_attn_ln_1_w
+ ctx_size += n_text_layer * ( n_text_state * ggml_type_size( GGML_TYPE_F32 ) ); // cross_attn_ln_1_b
+ }
+
+ ctx_mem_size += n_text_layer * n_text_ctx * n_text_state * ggml_type_size( GGML_TYPE_F16 ); // memory_k
+ ctx_mem_size += n_text_layer * n_text_ctx * n_text_state * ggml_type_size( GGML_TYPE_F16 ); // memory_v
+
+ ctx_mem_size += n_text_layer * n_audio_ctx * n_text_state * ggml_type_size( GGML_TYPE_F16 ); // memory_cross_k
+ ctx_mem_size += n_text_layer * n_audio_ctx * n_text_state * ggml_type_size( GGML_TYPE_F16 ); // memory_cross_v
+
+ ctx_size += ( 15 + 15 * n_audio_layer + 24 * n_text_layer ) * 256; // object overhead
+
+ logDebug( u8"%s: ggml ctx size = %7.2f MB", __func__, ctx_size / ( 1024.0 * 1024.0 ) );
+ }
+
+ // create the ggml context
+ {
+ struct ggml_init_params params;
+ params.mem_size = ctx.buf_model->size();
+ params.mem_buffer = ctx.buf_model->data();
+
+ model.ctx = ggml_init( params );
+ if( !model.ctx )
+ {
+ logError( u8"%s: ggml_init() failed", __func__ );
+ return E_INVALIDARG;
+ }
+ }
+
+ std::map<std::string, struct ggml_tensor*> tensors;
+ DirectCompute::ModelLoader loader{ model.hparams.n_audio_layer, model.hparams.n_text_layer };
+
+ // prepare memory for the weights
+ {
+ auto& ctx = model.ctx;
+ const auto& hparams = model.hparams;
+ const int n_vocab = hparams.n_vocab;
+
+ const int n_audio_ctx = hparams.n_audio_ctx;
+ const int n_audio_state = hparams.n_audio_state;
+ const int n_audio_layer = hparams.n_audio_layer;
+
+ const int n_text_ctx = hparams.n_text_ctx;
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+
+ const int n_mels = hparams.n_mels;
+
+ model.layers_encoder.resize( n_audio_layer );
+ model.layers_decoder.resize( n_text_layer );
+
+ // encoder
+ {
+ model.e_pe = ggml_new_tensor_2d( ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx );
+ loader.add( model.e_pe, loader.model.enc.positionalEmbedding );
+
+ model.e_conv_1_w = ggml_new_tensor_3d( ctx, wtype, 3, n_mels, n_audio_state );
+ model.e_conv_1_b = ggml_new_tensor_2d( ctx, GGML_TYPE_F32, 1, n_audio_state );
+ loader.add( model.e_conv_1_w, model.e_conv_1_b, loader.model.enc.conv1 );
+
+ model.e_conv_2_w = ggml_new_tensor_3d( ctx, wtype, 3, n_audio_state, n_audio_state );
+ model.e_conv_2_b = ggml_new_tensor_2d( ctx, GGML_TYPE_F32, 1, n_audio_state );
+ loader.add( model.e_conv_2_w, model.e_conv_2_b, loader.model.enc.conv2 );
+
+ model.e_ln_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state );
+ model.e_ln_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state );
+ loader.add( model.e_ln_w, model.e_ln_b, loader.model.enc.lnPost );
+
+ // map by name
+ tensors[ "encoder.positional_embedding" ] = model.e_pe;
+
+ tensors[ "encoder.conv1.weight" ] = model.e_conv_1_w;
+ tensors[ "encoder.conv1.bias" ] = model.e_conv_1_b;
+
+ tensors[ "encoder.conv2.weight" ] = model.e_conv_2_w;
+ tensors[ "encoder.conv2.bias" ] = model.e_conv_2_b;
+
+ tensors[ "encoder.ln_post.weight" ] = model.e_ln_w;
+ tensors[ "encoder.ln_post.bias" ] = model.e_ln_b;
+
+ for( int i = 0; i < n_audio_layer; ++i )
+ {
+ auto& layer = model.layers_encoder[ i ];
+ auto& gpu = loader.model.enc.layers[ i ];
+
+ layer.mlp_ln_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state );
+ layer.mlp_ln_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state );
+ loader.add( layer.mlp_ln_w, layer.mlp_ln_b, gpu.mlpLn );
+
+ layer.mlp_0_w = ggml_new_tensor_2d( ctx, wtype, n_audio_state, 4 * n_audio_state );
+ layer.mlp_0_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, 4 * n_audio_state );
+ loader.add( layer.mlp_0_w, layer.mlp_0_b, gpu.mlp0 );
+
+ layer.mlp_1_w = ggml_new_tensor_2d( ctx, wtype, 4 * n_audio_state, n_audio_state );
+ layer.mlp_1_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state );
+ loader.add( layer.mlp_1_w, layer.mlp_1_b, gpu.mlp1 );
+
+ layer.attn_ln_0_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state );
+ layer.attn_ln_0_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state );
+ loader.add( layer.attn_ln_0_w, layer.attn_ln_0_b, gpu.attnLn0 );
+
+ layer.attn_q_w = ggml_new_tensor_2d( ctx, wtype, n_audio_state, n_audio_state );
+ layer.attn_q_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state );
+ loader.add( layer.attn_q_w, layer.attn_q_b, gpu.attnQuery );
+
+ layer.attn_k_w = ggml_new_tensor_2d( ctx, wtype, n_audio_state, n_audio_state );
+ loader.add( layer.attn_k_w, gpu.attnKey );
+
+ layer.attn_v_w = ggml_new_tensor_2d( ctx, wtype, n_audio_state, n_audio_state );
+ layer.attn_v_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state );
+ loader.add( layer.attn_v_w, layer.attn_v_b, gpu.attnValue );
+
+ layer.attn_ln_1_w = ggml_new_tensor_2d( ctx, wtype, n_audio_state, n_audio_state );
+ layer.attn_ln_1_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_audio_state );
+ loader.add( layer.attn_ln_1_w, layer.attn_ln_1_b, gpu.attnLn1 );
+
+ // map by name
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp_ln.weight" ] = layer.mlp_ln_w;
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp_ln.bias" ] = layer.mlp_ln_b;
+
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp.0.weight" ] = layer.mlp_0_w;
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp.0.bias" ] = layer.mlp_0_b;
+
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp.2.weight" ] = layer.mlp_1_w;
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".mlp.2.bias" ] = layer.mlp_1_b;
+
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".attn_ln.weight" ] = layer.attn_ln_0_w;
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".attn_ln.bias" ] = layer.attn_ln_0_b;
+
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.query.weight" ] = layer.attn_q_w;
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.query.bias" ] = layer.attn_q_b;
+
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.key.weight" ] = layer.attn_k_w;
+
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.value.weight" ] = layer.attn_v_w;
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.value.bias" ] = layer.attn_v_b;
+
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.out.weight" ] = layer.attn_ln_1_w;
+ tensors[ "encoder.blocks." + std::to_string( i ) + ".attn.out.bias" ] = layer.attn_ln_1_b;
+ }
+ }
+
+ // decoder
+ {
+ model.d_pe = ggml_new_tensor_2d( ctx, GGML_TYPE_F32, n_text_state, n_text_ctx );
+ loader.add( model.d_pe, loader.model.dec.positionalEmbedding );
+
+ model.d_te = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_vocab );
+ loader.add( model.d_te, loader.model.dec.tokenEmbedding );
+
+ model.d_ln_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ model.d_ln_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( model.d_ln_w, model.d_ln_b, loader.model.dec.ln );
+
+ // map by name
+ tensors[ "decoder.positional_embedding" ] = model.d_pe;
+
+ tensors[ "decoder.token_embedding.weight" ] = model.d_te;
+
+ tensors[ "decoder.ln.weight" ] = model.d_ln_w;
+ tensors[ "decoder.ln.bias" ] = model.d_ln_b;
+
+ for( int i = 0; i < n_text_layer; ++i ) {
+ auto& layer = model.layers_decoder[ i ];
+ auto& gpu = loader.model.dec.layers[ i ];
+
+ layer.mlp_ln_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ layer.mlp_ln_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( layer.mlp_ln_w, layer.mlp_ln_b, gpu.mlpLn );
+
+ layer.mlp_0_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, 4 * n_text_state );
+ layer.mlp_0_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, 4 * n_text_state );
+ loader.add( layer.mlp_0_w, layer.mlp_0_b, gpu.mlp0 );
+
+ layer.mlp_1_w = ggml_new_tensor_2d( ctx, wtype, 4 * n_text_state, n_text_state );
+ layer.mlp_1_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( layer.mlp_1_w, layer.mlp_1_b, gpu.mlp1 );
+
+ layer.attn_ln_0_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ layer.attn_ln_0_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( layer.attn_ln_0_w, layer.attn_ln_0_b, gpu.attnLn0 );
+
+ layer.attn_q_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state );
+ layer.attn_q_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( layer.attn_q_w, layer.attn_q_b, gpu.attnQuery );
+
+ layer.attn_k_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state );
+ loader.add( layer.attn_k_w, gpu.attnKey );
+
+ layer.attn_v_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state );
+ layer.attn_v_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( layer.attn_v_w, layer.attn_v_b, gpu.attnValue );
+
+ layer.attn_ln_1_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state );
+ layer.attn_ln_1_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( layer.attn_ln_1_w, layer.attn_ln_1_b, gpu.attnLn1 );
+
+ layer.cross_attn_ln_0_w = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ layer.cross_attn_ln_0_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( layer.cross_attn_ln_0_w, layer.cross_attn_ln_0_b, gpu.crossAttnLn0 );
+
+ layer.cross_attn_q_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state );
+ layer.cross_attn_q_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( layer.cross_attn_q_w, layer.cross_attn_q_b, gpu.crossAttnQuery );
+
+ layer.cross_attn_k_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state );
+ loader.add( layer.cross_attn_k_w, gpu.crossAttnKey );
+
+ layer.cross_attn_v_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state );
+ layer.cross_attn_v_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( layer.cross_attn_v_w, layer.cross_attn_v_b, gpu.crossAttnValue );
+
+ layer.cross_attn_ln_1_w = ggml_new_tensor_2d( ctx, wtype, n_text_state, n_text_state );
+ layer.cross_attn_ln_1_b = ggml_new_tensor_1d( ctx, GGML_TYPE_F32, n_text_state );
+ loader.add( layer.cross_attn_ln_1_w, layer.cross_attn_ln_1_b, gpu.crossAttnLn1 );
+
+ // map by name
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp_ln.weight" ] = layer.mlp_ln_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp_ln.bias" ] = layer.mlp_ln_b;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp.0.weight" ] = layer.mlp_0_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp.0.bias" ] = layer.mlp_0_b;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp.2.weight" ] = layer.mlp_1_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".mlp.2.bias" ] = layer.mlp_1_b;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".attn_ln.weight" ] = layer.attn_ln_0_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".attn_ln.bias" ] = layer.attn_ln_0_b;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.query.weight" ] = layer.attn_q_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.query.bias" ] = layer.attn_q_b;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.key.weight" ] = layer.attn_k_w;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.value.weight" ] = layer.attn_v_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.value.bias" ] = layer.attn_v_b;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.out.weight" ] = layer.attn_ln_1_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".attn.out.bias" ] = layer.attn_ln_1_b;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn_ln.weight" ] = layer.cross_attn_ln_0_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn_ln.bias" ] = layer.cross_attn_ln_0_b;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.query.weight" ] = layer.cross_attn_q_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.query.bias" ] = layer.cross_attn_q_b;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.key.weight" ] = layer.cross_attn_k_w;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.value.weight" ] = layer.cross_attn_v_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.value.bias" ] = layer.cross_attn_v_b;
+
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.out.weight" ] = layer.cross_attn_ln_1_w;
+ tensors[ "decoder.blocks." + std::to_string( i ) + ".cross_attn.out.bias" ] = layer.cross_attn_ln_1_b;
+ }
+ }
+ }
+
+ // create the ggml memory context
+ {
+ struct ggml_init_params params;
+ params.mem_size = ctx.buf_memory.size();
+ params.mem_buffer = ctx.buf_memory.data();
+ model.ctx_mem = ggml_init( params );
+ if( !model.ctx_mem )
+ {
+ logError( u8"%s: ggml_init() failed", __func__ );
+ return E_INVALIDARG;
+ }
+ }
+
+ // key + value memory
+ {
+ auto& ctx = model.ctx_mem;
+
+ const auto& hparams = model.hparams;
+
+ const int n_text_state = hparams.n_text_state;
+ const int n_text_layer = hparams.n_text_layer;
+ const int n_text_ctx = hparams.n_text_ctx;
+
+ // key/value memory for the self-attention layer
+ {
+ const int n_mem = n_text_layer * n_text_ctx;
+ const int n_elements = n_text_state * n_mem;
+
+ model.memory_k = ggml_new_tensor_1d( ctx, GGML_TYPE_F16, n_elements );
+ model.memory_v = ggml_new_tensor_1d( ctx, GGML_TYPE_F16, n_elements );
+ }
+
+ // key/value memory for the cross-attention layer
+ {
+ const int n_audio_ctx = hparams.n_audio_ctx;
+
+ const int n_mem = n_text_layer * n_audio_ctx;
+ const int n_elements = n_text_state * n_mem;
+
+ model.memory_cross_k = ggml_new_tensor_1d( ctx, GGML_TYPE_F16, n_elements );
+ model.memory_cross_v = ggml_new_tensor_1d( ctx, GGML_TYPE_F16, n_elements );
+ }
+
+ const size_t memory_size =
+ ggml_nbytes( model.memory_k ) + ggml_nbytes( model.memory_v ) +
+ ggml_nbytes( model.memory_cross_k ) + ggml_nbytes( model.memory_cross_v );
+
+ logDebug( u8"%s: memory size = %7.2f MB", __func__, memory_size / 1024.0 / 1024.0 );
+ }
+
+ // load weights
+ {
+ size_t total_size = 0;
+ int n_loaded = 0;
+ std::string name;
+
+ while( true )
+ {
+ int32_t n_dims;
+ int32_t length;
+ int32_t ftype;
+
+ HRESULT hr = readStruct( stm, n_dims );
+ if( hr == E_EOF )
+ break;
+ CHECK( hr );
+ CHECK( readStruct( stm, length ) );
+ CHECK( readStruct( stm, ftype ) );
+
+ int32_t nelements = 1;
+ int32_t ne[ 3 ] = { 1, 1, 1 };
+ for( int i = 0; i < n_dims; ++i )
+ {
+ CHECK( readStruct( stm, ne[ i ] ) );
+ nelements *= ne[ i ];
+ }
+
+ name.resize( length );
+ CHECK( readBytes( stm, name.data(), length ) );
+
+ if( tensors.find( name.data() ) == tensors.end() )
+ {
+ logError( u8"%s: unknown tensor '%s' in model file", __func__, name.data() );
+ return E_INVALIDARG;
+ }
+
+ auto tensor = tensors[ name.data() ];
+ if( ggml_nelements( tensor ) != nelements )
+ {
+ logError( u8"%s: tensor '%s' has wrong size in model file", __func__, name.data() );
+ return E_INVALIDARG;
+ }
+
+ if( tensor->ne[ 0 ] != ne[ 0 ] || tensor->ne[ 1 ] != ne[ 1 ] || tensor->ne[ 2 ] != ne[ 2 ] )
+ {
+ logError( u8"%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]",
+ __func__, name.data(), tensor->ne[ 0 ], tensor->ne[ 1 ], tensor->ne[ 2 ], ne[ 0 ], ne[ 1 ], ne[ 2 ] );
+ return E_INVALIDARG;
+ }
+
+ const size_t bpe = ( ftype == 0 ) ? sizeof( float ) : sizeof( ggml_fp16_t );
+
+ if( nelements * bpe != ggml_nbytes( tensor ) )
+ {
+ logError( u8"%s: tensor '%s' has wrong size in model file: got %zu, expected %zu",
+ __func__, name.data(), ggml_nbytes( tensor ), nelements * bpe );
+ return E_INVALIDARG;
+ }
+
+ CHECK( readBytes( stm, tensor->data, ggml_nbytes( tensor ) ) );
+
+ //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+ total_size += ggml_nbytes( tensor );
+ n_loaded++;
+ // loader.tryLoad( tensor );
+ }
+
+ logDebug( u8"%s: model size = %7.2f MB", __func__, total_size / 1024.0 / 1024.0 );
+ if( n_loaded == 0 )
+ {
+ logError( u8"%s: no tensors loaded from model file", __func__ );
+ return E_INVALIDARG;
+ }
+ else if( n_loaded != (int)tensors.size() )
+ {
+ logError( u8"%s: not all tensors loaded from model file - expected %zu, got %d", __func__, tensors.size(), n_loaded );
+ return E_INVALIDARG;
+ }
+ model.n_loaded = n_loaded;
+ }
+
+ return S_OK;
+ }
+
+ HRESULT Context::load( iReadStream* stm )
+ {
+ const int64_t t_start_us = ggml_time_us();
+ ctx.t_start_us = t_start_us;
+ HRESULT hr = loadImpl( stm );
+ ctx.t_load_us = ggml_time_us() - t_start_us;
+ return hr;
+ }
+
+ HRESULT __stdcall loadReferenceCpuModel( const wchar_t* path, iModel** pp )
+ {
+ if( nullptr == path || nullptr == pp )
+ return E_POINTER;
+
+ ComLight::Object<ReadStream> stream;
+ CHECK( stream.open( path ) );
+
+ ggml_time_init();
+ ComLight::CComPtr<ComLight::Object<Context>> obj;
+ CHECK( ComLight::Object<Context>::create( obj ) );
+ CHECK( obj->load( &stream ) );
+ obj.detach( pp );
+ return S_OK;
+ }
+}
+
+#include "Whisper/WhisperContext.h"
+#include "Whisper/ModelBuffers.h"
+#include "ML/testUtils.h"
+using namespace DirectCompute;
+
+static DirectCompute::Tensor gpuEncode( const whisper_context& wctx, const int mel_offset )
+{
+ return DirectCompute::Tensor{};
+#if 0
+ using namespace DirectCompute;
+ WhisperContext& ctx = WhisperContext::current();
+
+ Tensor cur;
+ sEncodeParams whisperParams;
+ const auto& mel_inp = wctx.mel;
+ {
+ const auto& model = wctx.model;
+ const auto& hparams = model.hparams;
+ whisperParams.n_len = (uint32_t)mel_inp.n_len;
+ whisperParams.n_mel = (uint32_t)mel_inp.n_mel;
+
+ const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
+ assert( n_ctx > 0 );
+ whisperParams.n_ctx = (uint32_t)n_ctx;
+
+ const int n_mels = hparams.n_mels;
+ assert( n_mels > 0 );
+ whisperParams.n_mels = (uint32_t)n_mels;
+
+ assert( mel_offset >= 0 );
+ whisperParams.mel_offset = (uint32_t)mel_offset;
+
+ const int layersCount = hparams.n_audio_layer;
+ assert( layersCount > 0 );
+ whisperParams.layersCount = (uint32_t)layersCount;
+
+ const int n_state = hparams.n_audio_state;
+ const int n_head = hparams.n_audio_head;
+ assert( n_state >= 0 );
+ assert( n_head >= 0 );
+
+ whisperParams.n_state = (uint32_t)n_state;
+ whisperParams.n_head = (uint32_t)n_head;
+
+ int n_audio_ctx = hparams.n_audio_ctx;
+ assert( n_audio_ctx > 0 );
+ whisperParams.n_audio_ctx = (uint32_t)n_audio_ctx;
+
+ int n_text_state = hparams.n_text_state;
+ assert( n_text_state > 0 );
+ whisperParams.n_text_state = (uint32_t)n_text_state;
+
+ int n_text_layer = hparams.n_text_layer;
+ assert( n_text_layer > 0 );
+ whisperParams.n_text_layer = (uint32_t)n_text_layer;
+
+ int n_text_ctx = hparams.n_text_ctx;
+ assert( n_text_ctx > 0 );
+ whisperParams.n_text_ctx = (uint32_t)n_text_ctx;
+ }
+
+ return ctx.encode( mel_inp.data, whisperParams );
+#endif
+}
+
+GpuEncTest::GpuEncTest( const whisper_context& wctx, const int mel_offset )
+{
+ return;
+ gpuResult = gpuEncode( wctx, mel_offset );
+}
+
+void GpuEncTest::compare( const ggml_tensor* expected ) const
+{
+ return;
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.dbgPrintDifference( expected, gpuResult, "GpuEncTest.compare", false );
+}
+
+void GpuEncTest::compareMel( const ggml_tensor* expected ) const
+{
+ return;
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.dbgPrintDifference( expected, mel, "GpuEncTest.compareMel", false );
+}
+
+/*
+void GpuEncTest::comparePostponed()
+{
+ if( nullptr == tempRef )
+ return;
+
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.dbgPrintDifference( tempRef, tempGpu, "comparePostponed" );
+ tempRef = nullptr;
+} */
+
+__declspec( noinline ) GpuDecTest::GpuDecTest( const whisper_context& wctx, const int* tokens, const int n_tokens, const int n_past )
+{
+#if 1
+ return;
+#else
+ sDecodeParams dp;
+ {
+ WhisperContext& ctx = WhisperContext::current();
+ const auto& model = wctx.model;
+ const auto& hparams = model.hparams;
+ dp.n_state = hparams.n_text_state;
+ dp.n_head = hparams.n_text_head;
+ dp.n_ctx = hparams.n_text_ctx;
+ dp.n_past = n_past;
+ dp.M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
+ dp.n_text_layer = hparams.n_text_layer;
+ dp.n_vocab = hparams.n_vocab;
+ }
+
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.decode( tokens, n_tokens, dp, logits, probs );
+#endif
+}
+
+void __declspec( noinline ) GpuDecTest::compare( const std::vector<float>& cpuLogits, const std::vector<float>& cpuProbs ) const
+{
+ return;
+
+ if( cpuLogits.size() != logits.size() )
+ {
+ printf( "GpuDecTest.compare fail, size different\n" );
+ return;
+ }
+
+ computeDiff( logits.data(), cpuLogits.data(), logits.size() ).print( "GpuDecTest.compare logits" );
+ computeDiff( probs.data(), cpuProbs.data(), probs.size() ).print( "GpuDecTest.compare probs" );
+}
+
+void __declspec( noinline ) GpuDecTest::postpone( const ggml_tensor* t )
+{
+ return;
+
+ if( nullptr != tempRef )
+ return;
+ tempRef = t;
+}
+
+void __declspec( noinline ) GpuDecTest::comparePostponed()
+{
+#if 1
+ return;
+#else
+ if( nullptr == tempRef )
+ return;
+ WhisperContext& ctx = WhisperContext::current();
+ ID3D11ShaderResourceView* srv = ctx.dbgDecodeTest;
+ if( nullptr == srv )
+ return;
+
+ ctx.dbgPrintDifference( tempRef, ctx.dbgDecodeTest, "GpuDecTest.comparePostponed" );
+ tempRef = nullptr;
+#endif
+}
+#else
+HRESULT __stdcall Whisper::loadReferenceCpuModel( const wchar_t* path, Whisper::iModel** pp )
+{
+ logError( u8"This build of the DLL doesn’t implement the reference CPU-running Whisper model." );
+ return E_NOTIMPL;
+}
+#endif \ No newline at end of file
diff --git a/WhisperCpp.sln b/WhisperCpp.sln
new file mode 100644
index 0000000..4c2ae51
--- /dev/null
+++ b/WhisperCpp.sln
@@ -0,0 +1,110 @@
+
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio Version 17
+VisualStudioVersion = 17.4.33122.133
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ComLightLib", "ComLightLib\ComLightLib.vcxproj", "{52F486E7-830C-45D8-BE47-E76B5AAB2772}"
+EndProject
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Whisper", "Whisper\Whisper.vcxproj", "{701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}"
+EndProject
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ComputeShaders", "ComputeShaders\ComputeShaders.vcxproj", "{1C39D386-96D0-47A1-BBFA-68BBDB24439C}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "WhisperNet", "WhisperNet\WhisperNet.csproj", "{F213558F-FEA2-4F66-A07F-69727E9EC81D}"
+ ProjectSection(ProjectDependencies) = postProject
+ {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71} = {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}
+ EndProjectSection
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TranscribeCS", "Examples\TranscribeCS\TranscribeCS.csproj", "{0533B86C-D0E8-4190-9717-7DBD9EC8C11F}"
+EndProject
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "OldMain", "Examples\OldMain\OldMain.vcxproj", "{596F9770-9AEB-49D3-86CA-4200197DF12B}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CompressShaders", "Tools\CompressShaders\CompressShaders.csproj", "{4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2}"
+ ProjectSection(ProjectDependencies) = postProject
+ {1C39D386-96D0-47A1-BBFA-68BBDB24439C} = {1C39D386-96D0-47A1-BBFA-68BBDB24439C}
+ EndProjectSection
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Tools", "Tools", "{90D16EBB-08A4-4C9B-9991-B1B2E036838C}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Examples", "Examples", "{B988C132-115D-4157-99FE-0D891CE45A82}"
+EndProject
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "main", "Examples\main\main.vcxproj", "{4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MicrophoneCS", "Examples\MicrophoneCS\MicrophoneCS.csproj", "{A49305C0-7022-45A6-89B4-4BD33138C98A}"
+ ProjectSection(ProjectDependencies) = postProject
+ {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71} = {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}
+ EndProjectSection
+EndProject
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "compareTraces", "Tools\compareTraces\compareTraces.vcxproj", "{8478A77C-D851-4C63-9511-1770CC82D33E}"
+EndProject
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "WhisperDesktop", "Examples\WhisperDesktop\WhisperDesktop.vcxproj", "{CD9E49F0-75A3-4F91-AC71-336109EE39C6}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{17835CA3-D7F6-4BEF-9471-12C015764A2C}"
+ ProjectSection(SolutionItems) = preProject
+ Readme.md = Readme.md
+ EndProjectSection
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|x64 = Debug|x64
+ Release|x64 = Release|x64
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {52F486E7-830C-45D8-BE47-E76B5AAB2772}.Debug|x64.ActiveCfg = Debug|x64
+ {52F486E7-830C-45D8-BE47-E76B5AAB2772}.Debug|x64.Build.0 = Debug|x64
+ {52F486E7-830C-45D8-BE47-E76B5AAB2772}.Release|x64.ActiveCfg = Release|x64
+ {52F486E7-830C-45D8-BE47-E76B5AAB2772}.Release|x64.Build.0 = Release|x64
+ {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}.Debug|x64.ActiveCfg = Debug|x64
+ {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}.Debug|x64.Build.0 = Debug|x64
+ {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}.Release|x64.ActiveCfg = Release|x64
+ {701DF8C8-E4A5-43EC-9C6B-747BBF4D8E71}.Release|x64.Build.0 = Release|x64
+ {1C39D386-96D0-47A1-BBFA-68BBDB24439C}.Debug|x64.ActiveCfg = Debug|x64
+ {1C39D386-96D0-47A1-BBFA-68BBDB24439C}.Debug|x64.Build.0 = Debug|x64
+ {1C39D386-96D0-47A1-BBFA-68BBDB24439C}.Release|x64.ActiveCfg = Release|x64
+ {1C39D386-96D0-47A1-BBFA-68BBDB24439C}.Release|x64.Build.0 = Release|x64
+ {F213558F-FEA2-4F66-A07F-69727E9EC81D}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {F213558F-FEA2-4F66-A07F-69727E9EC81D}.Debug|x64.Build.0 = Debug|Any CPU
+ {F213558F-FEA2-4F66-A07F-69727E9EC81D}.Release|x64.ActiveCfg = Release|Any CPU
+ {F213558F-FEA2-4F66-A07F-69727E9EC81D}.Release|x64.Build.0 = Release|Any CPU
+ {0533B86C-D0E8-4190-9717-7DBD9EC8C11F}.Debug|x64.ActiveCfg = Debug|x64
+ {0533B86C-D0E8-4190-9717-7DBD9EC8C11F}.Debug|x64.Build.0 = Debug|x64
+ {0533B86C-D0E8-4190-9717-7DBD9EC8C11F}.Release|x64.ActiveCfg = Release|x64
+ {596F9770-9AEB-49D3-86CA-4200197DF12B}.Debug|x64.ActiveCfg = Debug|x64
+ {596F9770-9AEB-49D3-86CA-4200197DF12B}.Debug|x64.Build.0 = Debug|x64
+ {596F9770-9AEB-49D3-86CA-4200197DF12B}.Release|x64.ActiveCfg = Release|x64
+ {596F9770-9AEB-49D3-86CA-4200197DF12B}.Release|x64.Build.0 = Release|x64
+ {4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2}.Debug|x64.Build.0 = Debug|Any CPU
+ {4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2}.Release|x64.ActiveCfg = Release|Any CPU
+ {4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2}.Release|x64.Build.0 = Release|Any CPU
+ {4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2}.Debug|x64.ActiveCfg = Debug|x64
+ {4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2}.Debug|x64.Build.0 = Debug|x64
+ {4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2}.Release|x64.ActiveCfg = Release|x64
+ {4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2}.Release|x64.Build.0 = Release|x64
+ {A49305C0-7022-45A6-89B4-4BD33138C98A}.Debug|x64.ActiveCfg = Debug|x64
+ {A49305C0-7022-45A6-89B4-4BD33138C98A}.Debug|x64.Build.0 = Debug|x64
+ {A49305C0-7022-45A6-89B4-4BD33138C98A}.Release|x64.ActiveCfg = Release|x64
+ {8478A77C-D851-4C63-9511-1770CC82D33E}.Debug|x64.ActiveCfg = Debug|x64
+ {8478A77C-D851-4C63-9511-1770CC82D33E}.Debug|x64.Build.0 = Debug|x64
+ {8478A77C-D851-4C63-9511-1770CC82D33E}.Release|x64.ActiveCfg = Release|x64
+ {8478A77C-D851-4C63-9511-1770CC82D33E}.Release|x64.Build.0 = Release|x64
+ {CD9E49F0-75A3-4F91-AC71-336109EE39C6}.Debug|x64.ActiveCfg = Debug|x64
+ {CD9E49F0-75A3-4F91-AC71-336109EE39C6}.Debug|x64.Build.0 = Debug|x64
+ {CD9E49F0-75A3-4F91-AC71-336109EE39C6}.Release|x64.ActiveCfg = Release|x64
+ {CD9E49F0-75A3-4F91-AC71-336109EE39C6}.Release|x64.Build.0 = Release|x64
+ EndGlobalSection
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
+ EndGlobalSection
+ GlobalSection(NestedProjects) = preSolution
+ {0533B86C-D0E8-4190-9717-7DBD9EC8C11F} = {B988C132-115D-4157-99FE-0D891CE45A82}
+ {596F9770-9AEB-49D3-86CA-4200197DF12B} = {B988C132-115D-4157-99FE-0D891CE45A82}
+ {4E85005E-D2C7-4B28-A5F2-8BC92DAF6BA2} = {90D16EBB-08A4-4C9B-9991-B1B2E036838C}
+ {4CCA7042-EB15-4F7A-B77B-5CAFD2DF47B2} = {B988C132-115D-4157-99FE-0D891CE45A82}
+ {A49305C0-7022-45A6-89B4-4BD33138C98A} = {B988C132-115D-4157-99FE-0D891CE45A82}
+ {8478A77C-D851-4C63-9511-1770CC82D33E} = {90D16EBB-08A4-4C9B-9991-B1B2E036838C}
+ {CD9E49F0-75A3-4F91-AC71-336109EE39C6} = {B988C132-115D-4157-99FE-0D891CE45A82}
+ EndGlobalSection
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {07D5F1CF-1FAD-4F40-806A-B148CD609961}
+ EndGlobalSection
+EndGlobal
diff --git a/WhisperNet/API/CaptureDeviceId.cs b/WhisperNet/API/CaptureDeviceId.cs
new file mode 100644
index 0000000..9636e53
--- /dev/null
+++ b/WhisperNet/API/CaptureDeviceId.cs
@@ -0,0 +1,24 @@
+using Whisper.Internal;
+
+namespace Whisper
+{
+ /// <summary>Identifiers for an audio capture device</summary>
+ public record struct CaptureDeviceId
+ {
+ /// <summary>The display name is suitable for showing to the user, but might not be unique.</summary>
+ public string displayName;
+
+ /// <summary>Endpoint ID for an audio capture device.<br/>
+ /// It uniquely identifies the device on the system, but is not a readable string.</summary>
+ public string endpoint;
+
+ internal CaptureDeviceId( in sCaptureDevice rsi )
+ {
+ displayName = rsi.displayName ?? "<display name unavailable>";
+ endpoint = rsi.endpoint ?? throw new ApplicationException( "The device has no endpoint ID" );
+ }
+
+ /// <summary>Returns a String which represents the object instance</summary>
+ public override string ToString() => $"Capture device: \"{displayName}\"";
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/Parameters.cs b/WhisperNet/API/Parameters.cs
new file mode 100644
index 0000000..d2b53f9
--- /dev/null
+++ b/WhisperNet/API/Parameters.cs
@@ -0,0 +1,95 @@
+// Missing XML comment for publicly visible type or member
+// TODO: remove this line and document them.
+#pragma warning disable CS1591
+
+namespace Whisper
+{
+ /// <summary>Available sampling strategies</summary>
+ public enum eSamplingStrategy: int
+ {
+ /// <summary>Always select the most probable token</summary>
+ Greedy,
+ /// <summary>TODO: not implemented yet!</summary>
+ BeamSearch,
+ };
+
+ [Flags]
+ public enum eFullParamsFlags: uint
+ {
+ None = 0,
+ Translate = 1,
+ NoContext = 2,
+ SingleSegment = 4,
+ PrintSpecial = 8,
+ PrintProgress = 0x10,
+ PrintRealtime = 0x20,
+ PrintTimestamps = 0x40,
+
+ // Experimental
+ TokenTimestamps = 0x100,
+ SpeedupAudio = 0x200,
+ };
+
+ /// <summary>Transcribe parameters</summary>
+ public struct Parameters
+ {
+ /// <summary>Sampling strategy</summary>
+ public eSamplingStrategy strategy;
+
+ /// <summary>Count of CPU worker threads to use</summary>
+ /// <remarks>So far, the GPU model only uses CPU threads for MEL spectrograms</remarks>
+ public int cpuThreads;
+
+ public int n_max_text_ctx;
+ /// <summary>start offset in ms</summary>
+ public int offset_ms;
+ /// <summary>audio duration to process in ms</summary>
+ public int duration_ms;
+ public eFullParamsFlags flags;
+
+ /// <summary>Set or clear the specified flag in the <see cref="flags" /> field of this structure</summary>
+ public void setFlag( eFullParamsFlags flag, bool set )
+ {
+ if( flag != eFullParamsFlags.None )
+ {
+ if( set )
+ flags |= flag;
+ else
+ flags &= ~flag;
+ return;
+ }
+ throw new ArgumentException();
+ }
+
+ /// <summary>Language</summary>
+ public eLanguage language;
+
+ // [EXPERIMENTAL] token-level timestamps
+ /// <summary>timestamp token probability threshold (~0.01)</summary>
+ public float thold_pt;
+ /// <summary>timestamp token sum probability threshold (~0.01)</summary>
+ public float thold_ptsum;
+ /// <summary>max segment length in characters</summary>
+ public int max_len;
+ /// <summary>max tokens per segment (0 = no limit)</summary>
+ public int max_tokens;
+
+ public struct sGreedy
+ {
+ public int n_past;
+ }
+ public sGreedy greedy;
+
+ public struct sBeamSearch
+ {
+ public int n_past;
+ public int beam_width;
+ public int n_best;
+ }
+ public sBeamSearch beamSearch;
+
+ // [EXPERIMENTAL] speed-up techniques
+ /// <summary>overwrite the audio context size (0 = use default)</summary>
+ public int audioContextSize;
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/SpecialTokens.cs b/WhisperNet/API/SpecialTokens.cs
new file mode 100644
index 0000000..d672369
--- /dev/null
+++ b/WhisperNet/API/SpecialTokens.cs
@@ -0,0 +1,23 @@
+namespace Whisper
+{
+ /// <summary>Special tokens defined in the model</summary>
+ public readonly struct SpecialTokens
+ {
+ /// <summary>The end of a transcription</summary>
+ public readonly int TranscriptionEnd; // token_eot
+ /// <summary>Start of a transcription</summary>
+ public readonly int TranscriptionStart; // token_sot
+ /// <summary>Represents the previous word in the transcription. It is used to help the model predict the current word based on the context of the words that came before it.</summary>
+ public readonly int PreviousWord; // token_prev
+ /// <summary>Start of a sentence</summary>
+ public readonly int SentenceStart; // token_solm
+ /// <summary>Represents the word "not" in the transcription</summary>
+ public readonly int Not; // token_not
+ /// <summary>New transcription</summary>
+ public readonly int TranscriptionBegin; // token_beg
+ /// <summary>token_translate</summary>
+ public readonly int TaskTranslate;
+ /// <summary>token_transcribe</summary>
+ public readonly int TaskTranscribe;
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/eCaptureStatus.cs b/WhisperNet/API/eCaptureStatus.cs
new file mode 100644
index 0000000..41f05fb
--- /dev/null
+++ b/WhisperNet/API/eCaptureStatus.cs
@@ -0,0 +1,19 @@
+namespace Whisper
+{
+ /// <summary>Status of the voice capture</summary>
+ [Flags]
+ public enum eCaptureStatus: byte
+ {
+ /// <summary>Doing nothing</summary>
+ None = 0,
+ /// <summary>Capturing the audio</summary>
+ Listening = 1,
+ /// <summary>A voice is detected in the captured audio, recording</summary>
+ Voice = 2,
+ /// <summary>Transcribing a recorded piece of the audio</summary>
+ Transcribing = 4,
+ /// <summary>The computer is unable to transcribe the audio quickly enough,<br/>
+ /// and the capture is dropping the incoming audio samples.</summary>
+ Stalled = 0x80,
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/eLanguage.cs b/WhisperNet/API/eLanguage.cs
new file mode 100644
index 0000000..1241077
--- /dev/null
+++ b/WhisperNet/API/eLanguage.cs
@@ -0,0 +1,206 @@
+// This file is generated by a tool, from the `languageCodez.tsv` file in this repository
+namespace Whisper
+{
+ /// <summary>Supported languages</summary>
+ public enum eLanguage: uint
+ {
+ /// <summary>Afrikaans</summary>
+ Afrikaans = 0x6661,
+ /// <summary>Albanian</summary>
+ Albanian = 0x7173,
+ /// <summary>Amharic</summary>
+ Amharic = 0x6D61,
+ /// <summary>Arabic</summary>
+ Arabic = 0x7261,
+ /// <summary>Armenian</summary>
+ Armenian = 0x7968,
+ /// <summary>Assamese</summary>
+ Assamese = 0x7361,
+ /// <summary>Azerbaijani</summary>
+ Azerbaijani = 0x7A61,
+ /// <summary>Bashkir</summary>
+ Bashkir = 0x6162,
+ /// <summary>Basque</summary>
+ Basque = 0x7565,
+ /// <summary>Belarusian</summary>
+ Belarusian = 0x6562,
+ /// <summary>Bengali</summary>
+ Bengali = 0x6E62,
+ /// <summary>Bosnian</summary>
+ Bosnian = 0x7362,
+ /// <summary>Breton</summary>
+ Breton = 0x7262,
+ /// <summary>Bulgarian</summary>
+ Bulgarian = 0x6762,
+ /// <summary>Catalan</summary>
+ Catalan = 0x6163,
+ /// <summary>Chinese</summary>
+ Chinese = 0x687A,
+ /// <summary>Croatian</summary>
+ Croatian = 0x7268,
+ /// <summary>Czech</summary>
+ Czech = 0x7363,
+ /// <summary>Danish</summary>
+ Danish = 0x6164,
+ /// <summary>Dutch</summary>
+ Dutch = 0x6C6E,
+ /// <summary>English</summary>
+ English = 0x6E65,
+ /// <summary>Estonian</summary>
+ Estonian = 0x7465,
+ /// <summary>Faroese</summary>
+ Faroese = 0x6F66,
+ /// <summary>Finnish</summary>
+ Finnish = 0x6966,
+ /// <summary>French</summary>
+ French = 0x7266,
+ /// <summary>Galician</summary>
+ Galician = 0x6C67,
+ /// <summary>Georgian</summary>
+ Georgian = 0x616B,
+ /// <summary>German</summary>
+ German = 0x6564,
+ /// <summary>Greek</summary>
+ Greek = 0x6C65,
+ /// <summary>Gujarati</summary>
+ Gujarati = 0x7567,
+ /// <summary>Haitian Creole</summary>
+ HaitianCreole = 0x7468,
+ /// <summary>Hausa</summary>
+ Hausa = 0x6168,
+ /// <summary>Hawaiian</summary>
+ Hawaiian = 0x776168,
+ /// <summary>Hebrew</summary>
+ Hebrew = 0x7769,
+ /// <summary>Hindi</summary>
+ Hindi = 0x6968,
+ /// <summary>Hungarian</summary>
+ Hungarian = 0x7568,
+ /// <summary>Icelandic</summary>
+ Icelandic = 0x7369,
+ /// <summary>Indonesian</summary>
+ Indonesian = 0x6469,
+ /// <summary>Italian</summary>
+ Italian = 0x7469,
+ /// <summary>Japanese</summary>
+ Japanese = 0x616A,
+ /// <summary>Javanese</summary>
+ Javanese = 0x776A,
+ /// <summary>Kannada</summary>
+ Kannada = 0x6E6B,
+ /// <summary>Kazakh</summary>
+ Kazakh = 0x6B6B,
+ /// <summary>Khmer</summary>
+ Khmer = 0x6D6B,
+ /// <summary>Korean</summary>
+ Korean = 0x6F6B,
+ /// <summary>Lao</summary>
+ Lao = 0x6F6C,
+ /// <summary>Latin</summary>
+ Latin = 0x616C,
+ /// <summary>Latvian</summary>
+ Latvian = 0x766C,
+ /// <summary>Lingala</summary>
+ Lingala = 0x6E6C,
+ /// <summary>Lithuanian</summary>
+ Lithuanian = 0x746C,
+ /// <summary>Luxembourgish</summary>
+ Luxembourgish = 0x626C,
+ /// <summary>Macedonian</summary>
+ Macedonian = 0x6B6D,
+ /// <summary>Malagasy</summary>
+ Malagasy = 0x676D,
+ /// <summary>Malay</summary>
+ Malay = 0x736D,
+ /// <summary>Malayalam</summary>
+ Malayalam = 0x6C6D,
+ /// <summary>Maltese</summary>
+ Maltese = 0x746D,
+ /// <summary>Maori</summary>
+ Maori = 0x696D,
+ /// <summary>Marathi</summary>
+ Marathi = 0x726D,
+ /// <summary>Mongolian</summary>
+ Mongolian = 0x6E6D,
+ /// <summary>Myanmar</summary>
+ Myanmar = 0x796D,
+ /// <summary>Nepali</summary>
+ Nepali = 0x656E,
+ /// <summary>Norwegian</summary>
+ Norwegian = 0x6F6E,
+ /// <summary>Nynorsk</summary>
+ Nynorsk = 0x6E6E,
+ /// <summary>Occitan</summary>
+ Occitan = 0x636F,
+ /// <summary>Pashto</summary>
+ Pashto = 0x7370,
+ /// <summary>Persian</summary>
+ Persian = 0x6166,
+ /// <summary>Polish</summary>
+ Polish = 0x6C70,
+ /// <summary>Portuguese</summary>
+ Portuguese = 0x7470,
+ /// <summary>Punjabi</summary>
+ Punjabi = 0x6170,
+ /// <summary>Romanian</summary>
+ Romanian = 0x6F72,
+ /// <summary>Russian</summary>
+ Russian = 0x7572,
+ /// <summary>Sanskrit</summary>
+ Sanskrit = 0x6173,
+ /// <summary>Serbian</summary>
+ Serbian = 0x7273,
+ /// <summary>Shona</summary>
+ Shona = 0x6E73,
+ /// <summary>Sindhi</summary>
+ Sindhi = 0x6473,
+ /// <summary>Sinhala</summary>
+ Sinhala = 0x6973,
+ /// <summary>Slovak</summary>
+ Slovak = 0x6B73,
+ /// <summary>Slovenian</summary>
+ Slovenian = 0x6C73,
+ /// <summary>Somali</summary>
+ Somali = 0x6F73,
+ /// <summary>Spanish</summary>
+ Spanish = 0x7365,
+ /// <summary>Sundanese</summary>
+ Sundanese = 0x7573,
+ /// <summary>Swahili</summary>
+ Swahili = 0x7773,
+ /// <summary>Swedish</summary>
+ Swedish = 0x7673,
+ /// <summary>Tagalog</summary>
+ Tagalog = 0x6C74,
+ /// <summary>Tajik</summary>
+ Tajik = 0x6774,
+ /// <summary>Tamil</summary>
+ Tamil = 0x6174,
+ /// <summary>Tatar</summary>
+ Tatar = 0x7474,
+ /// <summary>Telugu</summary>
+ Telugu = 0x6574,
+ /// <summary>Thai</summary>
+ Thai = 0x6874,
+ /// <summary>Tibetan</summary>
+ Tibetan = 0x6F62,
+ /// <summary>Turkish</summary>
+ Turkish = 0x7274,
+ /// <summary>Turkmen</summary>
+ Turkmen = 0x6B74,
+ /// <summary>Ukrainian</summary>
+ Ukrainian = 0x6B75,
+ /// <summary>Urdu</summary>
+ Urdu = 0x7275,
+ /// <summary>Uzbek</summary>
+ Uzbek = 0x7A75,
+ /// <summary>Vietnamese</summary>
+ Vietnamese = 0x6976,
+ /// <summary>Welsh</summary>
+ Welsh = 0x7963,
+ /// <summary>Yiddish</summary>
+ Yiddish = 0x6979,
+ /// <summary>Yoruba</summary>
+ Yoruba = 0x6F79,
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/eLogLevel.cs b/WhisperNet/API/eLogLevel.cs
new file mode 100644
index 0000000..ae494d4
--- /dev/null
+++ b/WhisperNet/API/eLogLevel.cs
@@ -0,0 +1,34 @@
+namespace Whisper
+{
+ /// <summary>Message log level</summary>
+ public enum eLogLevel: byte
+ {
+ /// <summary>Error message</summary>
+ Error = 0,
+ /// <summary>Warning message</summary>
+ Warning = 1,
+ /// <summary>Informational message</summary>
+ Info = 2,
+ /// <summary>Debug message</summary>
+ Debug = 3
+ }
+
+ /// <summary>A delegate to receive log messages from the library</summary>
+ public delegate void pfnLogMessage( eLogLevel level, string message );
+
+ /// <summary>Log destination flags</summary>
+ [Flags]
+ public enum eLoggerFlags: byte
+ {
+ /// <summary>No special flags</summary>
+ None = 0,
+
+ /// <summary>In addition to calling the delegate, print messaged to standard error</summary>
+ UseStandardError = 1,
+
+ /// <summary>Don’t format error codes into messages</summary>
+ /// <remarks>It’s recommended to use this flag in .NET.<br/>
+ /// The standard library already formats these messages automatically, as needed.</remarks>
+ SkipFormatMessage = 2,
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/eModelImplementation.cs b/WhisperNet/API/eModelImplementation.cs
new file mode 100644
index 0000000..1b0a079
--- /dev/null
+++ b/WhisperNet/API/eModelImplementation.cs
@@ -0,0 +1,25 @@
+namespace Whisper
+{
+ /// <summary>Implementation value for the <see cref="Library.loadModel(string, eModelImplementation)" /> factory function</summary>
+ public enum eModelImplementation: uint
+ {
+ /// <summary>GPGPU implementation based on Direct3D 11.0 compute shaders</summary>
+ GPU = 1,
+
+ /// <summary>A hybrid implementation which uses DirectCompute for encode, and decodes on CPU</summary>
+ /// <remarks>
+ /// <para>The build of the native DLL included into this nuget package doesn’t implement this version.<br/>
+ /// To enable, edit <c>stdafx.h</c> in Whisper project, change the value of <c>BUILD_HYBRID_VERSION</c> macro from zero to one, and build.</para>
+ /// <para>This implementation requires a CPU with AVX1, FMA3, F16C and BMI1 instruction set extensions.</para>
+ /// </remarks>
+ Hybrid = 2,
+
+ /// <summary>A reference implementation which uses the original GGML CPU-running code.</summary>
+ /// <remarks>
+ /// <para>The build of the native DLL included into this nuget package doesn’t implement this version either.<br/>
+ /// To enable, edit <c>stdafx.h</c> in Whisper project, change the value of <c>BUILD_BOTH_VERSIONS</c> macro from zero to one, and build the project.</para>
+ /// <para>This implementation requires a CPU with AVX1, FMA3, and F16C instruction set extensions.</para>
+ /// </remarks>
+ Reference = 3,
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/eResultFlags.cs b/WhisperNet/API/eResultFlags.cs
new file mode 100644
index 0000000..1de61ab
--- /dev/null
+++ b/WhisperNet/API/eResultFlags.cs
@@ -0,0 +1,21 @@
+namespace Whisper
+{
+ /// <summary>Flags for <see cref="Context.results(eResultFlags)" /> method</summary>
+ [Flags]
+ public enum eResultFlags: uint
+ {
+ /// <summary>No flags</summary>
+ None = 0,
+
+ /// <summary>Return individual tokens in addition to the segments</summary>
+ Tokens = 1,
+
+ /// <summary>Return timestamps</summary>
+ Timestamps = 2,
+
+ /// <summary>Create a new COM object for the results.</summary>
+ /// <remarks>Without this flag, the context returns a pointer to the COM object stored in the context.<br/>
+ /// The content of that object is replaced every time you call <see cref="Internal.iContext.getResults(eResultFlags)" /> method.</remarks>
+ NewObject = 0x100,
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/iAudioBuffer.cs b/WhisperNet/API/iAudioBuffer.cs
new file mode 100644
index 0000000..1b35621
--- /dev/null
+++ b/WhisperNet/API/iAudioBuffer.cs
@@ -0,0 +1,27 @@
+using ComLight;
+using System.Runtime.InteropServices;
+
+namespace Whisper
+{
+ /// <summary>A buffer with a chunk of audio.</summary>
+ /// <remarks>Note the interface supports both marshaling directions.<br/>
+ /// I have not tested, but you should be able to implement this interface in C#, to supply PCM audio data to the native code</remarks>
+ [ComInterface( "013583aa-c9eb-42bc-83db-633c2c317051", eMarshalDirection.BothWays )]
+ public interface iAudioBuffer: IDisposable
+ {
+ /// <summary>Count of samples in the buffer</summary>
+ int countSamples();
+
+ /// <summary>Unmanaged pointer to the internal buffer containing single-channel FP32 samples.</summary>
+ /// <remarks>If you implementing this interface in C# and your audio data is on the managed heap, use <see cref="GCHandle" /> to make sure it doesn't move.<br/>
+ /// Or better yet, move the data to unmanaged buffer allocated with <see cref="Marshal.AllocHGlobal(int)" /> or <see cref="Marshal.AllocCoTaskMem(int)" /> method.</remarks>
+ IntPtr getPcmMono();
+
+ /// <summary>Unmanaged pointer to the internal buffer containing stereo FP32 samples.</summary>
+ /// <remarks>When the buffer doesn’t have stereo data, the method gonna return <see cref="IntPtr.Zero" />.</remarks>
+ IntPtr getPcmStereo();
+
+ /// <summary>Start time of the buffer, relative to the start of the media</summary>
+ void getTime( out TimeSpan time );
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/iAudioReader.cs b/WhisperNet/API/iAudioReader.cs
new file mode 100644
index 0000000..68cf916
--- /dev/null
+++ b/WhisperNet/API/iAudioReader.cs
@@ -0,0 +1,23 @@
+using ComLight;
+
+namespace Whisper
+{
+ /// <summary>Audio stream reader object</summary>
+ /// <remarks>The implementation is forward-only, and these objects ain’t reusable.<br/>
+ /// To read a source file multiple time, dispose and re-create the reader.</remarks>
+ [ComInterface( "35b988da-04a6-476a-a193-d8891d5dc390", eMarshalDirection.ToManaged )]
+ public interface iAudioReader: IDisposable
+ {
+ /// <summary>Get duration of the media file</summary>
+ [RetValIndex]
+ TimeSpan getDuration();
+ }
+
+ /// <summary>Audio capture reader object</summary>
+ /// <remarks>This interface has no public methods callable from C#.<br/>
+ /// It’s only here to pass data between different functions implemented in C++.</remarks>
+ [ComInterface( "747752c2-d9fd-40df-8847-583c781bf013", eMarshalDirection.ToManaged )]
+ public interface iAudioCapture: IDisposable
+ {
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/iMediaFoundation.cs b/WhisperNet/API/iMediaFoundation.cs
new file mode 100644
index 0000000..535f904
--- /dev/null
+++ b/WhisperNet/API/iMediaFoundation.cs
@@ -0,0 +1,36 @@
+using ComLight;
+using System.Runtime.InteropServices;
+using Whisper.Internal;
+
+namespace Whisper
+{
+ /// <summary>Exposes a small subset of MS Media Foundation framework.</summary>
+ /// <remarks>That framework is a part of Windows OS, since Vista.</remarks>
+ /// <seealso href="https://learn.microsoft.com/en-us/windows/win32/medfound/microsoft-media-foundation-sdk" />
+ [ComInterface( "fb9763a5-d77d-4b6e-aff8-f494813cebd8", eMarshalDirection.ToManaged ), CustomConventions( typeof( NativeLogger ) )]
+ public interface iMediaFoundation: IDisposable
+ {
+ /// <summary>Decode complete audio file into a new memory buffer.</summary>
+ /// <returns>
+ /// Under the hood, the method asks MF to resample and convert audio into the suitable type for the Whisper model.<br/>
+ /// If the path is a video file, the method will decode the first audio track.
+ /// </returns>
+ [RetValIndex( 2 )]
+ iAudioBuffer loadAudioFile( [MarshalAs( UnmanagedType.LPWStr )] string path, [MarshalAs( UnmanagedType.U1 )] bool stereo = false );
+
+ /// <summary>Create a reader to stream the audio file from disk</summary>
+ /// <returns>
+ /// Under the hood, the method asks MF to resample and convert audio into the suitable type for the Whisper model.<br/>
+ /// If the path is a video file, the method will decode the first audio track.
+ /// </returns>
+ [RetValIndex( 2 )]
+ iAudioReader openAudioFile( [MarshalAs( UnmanagedType.LPWStr )] string path, [MarshalAs( UnmanagedType.U1 )] bool stereo = false );
+
+ /// <summary>List capture devices</summary>
+ void listCaptureDevices( [MarshalAs( UnmanagedType.FunctionPtr )] pfnFoundCaptureDevices pfn, IntPtr pv );
+
+ /// <summary>Open audio capture device</summary>
+ [RetValIndex( 2 )]
+ iAudioCapture openCaptureDevice( [MarshalAs( UnmanagedType.LPWStr )] string endpoint, [In] ref sCaptureParams captureParams );
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/iModel.cs b/WhisperNet/API/iModel.cs
new file mode 100644
index 0000000..8ec6d17
--- /dev/null
+++ b/WhisperNet/API/iModel.cs
@@ -0,0 +1,27 @@
+using ComLight;
+using System.ComponentModel;
+
+namespace Whisper
+{
+ /// <summary>A model in VRAM, loaded from GGML file.</summary>
+ /// <remarks>This objetc doesn't keep any mutable state, and can be safely used from multiple threads concurrently</remarks>
+ [ComInterface( "abefb4c9-e8d8-46a3-8747-5afbadef1adb", eMarshalDirection.ToManaged ), CustomConventions( typeof( Internal.NativeLogger ) )]
+ public interface iModel: IDisposable
+ {
+ /// <summary>Create a context to transcribe audio with this model</summary>
+ /// <remarks>Don't call this method, use <see cref="ExtensionMethods.createContext(iModel)" /> instead.</remarks>
+ [RetValIndex, EditorBrowsable( EditorBrowsableState.Never )]
+ Internal.iContext createContextInternal();
+
+ /// <summary>True if this model is multi-lingual</summary>
+ bool isMultilingual();
+
+ /// <summary>Retrieve integer IDs of the special tokens defined by the model</summary>
+ [RetValIndex]
+ SpecialTokens getSpecialTokens();
+
+ /// <summary>Try to resolve integer token ID into string.</summary>
+ /// <remarks>Don't call this method, use <see cref="ExtensionMethods.stringFromToken(iModel, int)" /> instead.</remarks>
+ IntPtr stringFromTokenInternal( int id );
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/sCaptureParams.cs b/WhisperNet/API/sCaptureParams.cs
new file mode 100644
index 0000000..7595a69
--- /dev/null
+++ b/WhisperNet/API/sCaptureParams.cs
@@ -0,0 +1,37 @@
+namespace Whisper
+{
+ /// <summary>Flags for the audio capture</summary>
+ [Flags]
+ public enum eCaptureFlags: uint
+ {
+ /// <summary>No special flags</summary>
+ None = 0,
+ /// <summary>When the capture device supports stereo, keep stereo PCM samples in addition to mono</summary>
+ Stereo = 1,
+ }
+
+ /// <summary>Parameters for audio capture</summary>
+ public struct sCaptureParams
+ {
+ /// <summary>Minimum transcribe duration in seconds</summary>
+ public float minDuration;
+ /// <summary>Maximum transcribe duration in seconds</summary>
+ public float maxDuration;
+ /// <summary></summary>
+ public float dropStartSilence;
+ /// <summary></summary>
+ public float pauseDuration;
+ /// <summary>Flags for the audio capture</summary>
+ public eCaptureFlags flags;
+
+ /// <summary>Initialize the structure with some reasonable default values</summary>
+ public sCaptureParams()
+ {
+ minDuration = 7.0f; // 7 seconds
+ maxDuration = 11.0f; // 11 seconds
+ dropStartSilence = 0.25f; // 250 ms
+ pauseDuration = 0.333f; // 333 ms
+ flags = eCaptureFlags.None;
+ }
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Callbacks.cs b/WhisperNet/Callbacks.cs
new file mode 100644
index 0000000..db38718
--- /dev/null
+++ b/WhisperNet/Callbacks.cs
@@ -0,0 +1,44 @@
+using Whisper.Internal;
+
+namespace Whisper
+{
+ /// <summary>Implement this abstract class to receive callbacks from the native code</summary>
+ public abstract class Callbacks
+ {
+ /// <summary>The callback is called before every encoder run.</summary>
+ /// <remarks>If it returns false, the processing is aborted.</remarks>
+ protected virtual bool onEncoderBegin( Context sender ) { return true; }
+
+ /// <summary>This callback is called on each new segment</summary>
+ protected virtual void onNewSegment( Context sender, int countNew ) { }
+
+ const int S_OK = 0;
+ const int S_FALSE = 1;
+ internal int encoderBegin( Context sender )
+ {
+ try
+ {
+ return onEncoderBegin( sender ) ? S_OK : S_FALSE;
+ }
+ catch( Exception ex )
+ {
+ NativeLogger.captureException( ex );
+ return ex.HResult;
+ }
+ }
+
+ internal int newSegment( Context sender, int countNew )
+ {
+ try
+ {
+ onNewSegment( sender, countNew );
+ return S_OK;
+ }
+ catch( Exception ex )
+ {
+ NativeLogger.captureException( ex );
+ return ex.HResult;
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/CaptureCallbacks.cs b/WhisperNet/CaptureCallbacks.cs
new file mode 100644
index 0000000..26013f9
--- /dev/null
+++ b/WhisperNet/CaptureCallbacks.cs
@@ -0,0 +1,49 @@
+using Whisper.Internal;
+
+namespace Whisper
+{
+ /// <summary>Implement this abstract class to provide callbacks for audio capture method</summary>
+ public abstract class CaptureCallbacks
+ {
+ /// <summary>Override this method to support cancellation</summary>
+ protected virtual bool shouldCancel( Context sender ) { return false; }
+
+ /// <summary>Override this method to get notified about status changes</summary>
+ protected virtual void captureStatusChanged( Context sender, eCaptureStatus status ) { }
+
+ internal pfnShouldCancel cancel( Context sender )
+ {
+ const int S_OK = 0;
+ const int S_FALSE = 1;
+ return delegate ( IntPtr pv )
+ {
+ try
+ {
+ return shouldCancel( sender ) ? S_OK : S_FALSE;
+ }
+ catch( Exception ex )
+ {
+ NativeLogger.captureException( ex );
+ return ex.HResult;
+ }
+ };
+ }
+
+ internal pfnCaptureStatus status( Context sender )
+ {
+ return delegate ( IntPtr pv, eCaptureStatus status )
+ {
+ try
+ {
+ captureStatusChanged( sender, status );
+ return 0;
+ }
+ catch( Exception ex )
+ {
+ NativeLogger.captureException( ex );
+ return ex.HResult;
+ }
+ };
+ }
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Context.cs b/WhisperNet/Context.cs
new file mode 100644
index 0000000..6c6a737
--- /dev/null
+++ b/WhisperNet/Context.cs
@@ -0,0 +1,201 @@
+using System.Diagnostics;
+using Whisper.Internal;
+using Whisper.Internals;
+
+namespace Whisper
+{
+ /// <summary>Stateful context, contains methods to transcribe audio</summary>
+ public sealed class Context: IDisposable
+ {
+ iContext context;
+ // Caching the results object here saves time spent in ComLight library creating these callable proxies over and over again for the same underlying C++ object
+ readonly iTranscribeResult transcribeResult;
+ sFullParams fullParams;
+ sProgressSink progressSink;
+ bool disposed = false;
+ readonly Action<object> pfnBuffer, pfnStream;
+
+ internal Context( Internal.iContext context )
+ {
+ this.context = context;
+ transcribeResult = context.getResults( eResultFlags.None );
+ fullParams = context.fullDefaultParams( eSamplingStrategy.Greedy );
+ pfnBuffer = processBuffer;
+ pfnStream = processStream;
+ progressSink = default;
+ }
+
+ void IDisposable.Dispose()
+ {
+ if( disposed )
+ return;
+ disposed = true;
+ context?.Dispose();
+ GC.SuppressFinalize( this );
+ }
+
+ /// <summary>Adjustable parameters</summary>
+ public ref Parameters parameters => ref fullParams.publicParams;
+
+ void processBuffer( object buffer )
+ {
+ context.runFull( ref fullParams, (iAudioBuffer)buffer );
+ }
+ void processStream( object reader )
+ {
+ context.runStreamed( ref fullParams, ref progressSink, (iAudioReader)reader );
+ }
+
+ void runImpl( object source, Callbacks? callbacks, ReadOnlySpan<int> promptTokens, Action<object> pfn )
+ {
+ if( null != callbacks )
+ {
+ // TODO [very low, performance]: the following code creates 2 new GC-allocated objects on each call.
+ // Possible to optimize by caching these function pointers in static readonly fields, and use another [ThreadStatic] field for the callbacks object
+ fullParams.newSegmentCallback = delegate ( IntPtr ctx, int countNew, IntPtr userData )
+ {
+ return callbacks.newSegment( this, countNew );
+ };
+
+ fullParams.encoderBeginCallback = delegate ( IntPtr ctx, IntPtr userData )
+ {
+ return callbacks.encoderBegin( this );
+ };
+ }
+
+ try
+ {
+ if( promptTokens.IsEmpty )
+ {
+ pfn( source );
+ return;
+ }
+ unsafe
+ {
+ fixed( int* tokens = promptTokens )
+ {
+ fullParams.prompt_tokens = (IntPtr)tokens;
+ fullParams.prompt_n_tokens = promptTokens.Length;
+ pfn( source );
+ }
+ }
+ }
+ finally
+ {
+ // Reset these delegates.
+ // Otherwise, this class will retain the callbacks object preventing it from being garbage collected.
+ fullParams.newSegmentCallback = null;
+ fullParams.encoderBeginCallback = null;
+
+ fullParams.prompt_tokens = IntPtr.Zero;
+ fullParams.prompt_n_tokens = 0;
+ }
+ }
+
+ /// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary>
+ public void runFull( iAudioBuffer buffer, Callbacks? callbacks, ReadOnlySpan<int> promptTokens )
+ {
+ runImpl( buffer, callbacks, promptTokens, pfnBuffer );
+ }
+ /// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary>
+ public void runFull( iAudioBuffer buffer, Callbacks? callbacks = null ) =>
+ runFull( buffer, callbacks, ReadOnlySpan<int>.Empty );
+ /// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary>
+ public void runFull( iAudioBuffer buffer, Callbacks? callbacks, int[]? promptTokens ) =>
+ runFull( buffer, callbacks, promptTokens ?? ReadOnlySpan<int>.Empty );
+
+ /// <summary>Run the entire model, streaming audio from the provided reader object</summary>
+ public void runFull( iAudioReader reader, Callbacks? callbacks, Action<double>? pfnProgress, ReadOnlySpan<int> promptTokens )
+ {
+ if( null != pfnProgress )
+ {
+ progressSink.pfn = delegate ( double value, IntPtr context, IntPtr pv )
+ {
+ try
+ {
+ pfnProgress.Invoke( value );
+ return 0;
+ }
+ catch( Exception ex )
+ {
+ return ex.HResult;
+ }
+ };
+ }
+ try
+ {
+ runImpl( reader, callbacks, promptTokens, pfnStream );
+ }
+ finally
+ {
+ progressSink.pfn = null;
+ }
+ }
+
+ /// <summary>Run the entire model, streaming audio from the provided reader object</summary>
+ public void runFull( iAudioReader reader, Action<double>? pfnProgress = null, Callbacks? callbacks = null ) =>
+ runFull( reader, callbacks, pfnProgress, ReadOnlySpan<int>.Empty );
+
+ /// <summary>Run the entire model, streaming audio from the provided reader object</summary>
+ public void runFull( iAudioReader reader, Callbacks? callbacks, Action<double>? pfnProgress, int[]? promptTokens ) =>
+ runFull( reader, callbacks, pfnProgress, promptTokens ?? ReadOnlySpan<int>.Empty );
+
+ /// <summary>Get text results out of the context</summary>
+ public TranscribeResult results( eResultFlags flags = eResultFlags.None )
+ {
+ if( flags.HasFlag( eResultFlags.NewObject ) )
+ throw new ArgumentException();
+
+ iTranscribeResult res = context.getResults( flags );
+ Debug.Assert( ReferenceEquals( res, transcribeResult ) );
+ return new TranscribeResult( res );
+ }
+
+ /// <summary>Print timing data</summary>
+ public void timingsPrint() => context.timingsPrint();
+
+ /// <summary>Reset timing data</summary>
+ public void timingsReset() => context.timingsReset();
+
+ /// <summary>Continuously process audio from microphone or a similar capture device</summary>
+ /// <remarks>It’s recommended to call this method on a background thread.</remarks>
+ public void runCapture( iAudioCapture capture, Callbacks? callbacks, CaptureCallbacks? captureCallbacks )
+ {
+ if( null != callbacks )
+ {
+ // TODO [very low, performance]: the following code creates 2 new GC-allocated objects on each call.
+ // Possible to optimize by caching these function pointers in static readonly fields, and use another [ThreadStatic] field for the callbacks object
+ fullParams.newSegmentCallback = delegate ( IntPtr ctx, int countNew, IntPtr userData )
+ {
+ return callbacks.newSegment( this, countNew );
+ };
+
+ fullParams.encoderBeginCallback = delegate ( IntPtr ctx, IntPtr userData )
+ {
+ return callbacks.encoderBegin( this );
+ };
+ }
+
+ try
+ {
+ sCaptureCallbacks cc = default;
+ if( captureCallbacks != null )
+ {
+ cc.shouldCancel = captureCallbacks.cancel( this );
+ cc.captureStatus = captureCallbacks.status( this );
+ }
+ context.runCapture( ref fullParams, ref cc, capture );
+ }
+ finally
+ {
+ // Reset these delegates.
+ // Otherwise, this class will retain the callbacks object preventing it from being garbage collected.
+ fullParams.newSegmentCallback = null;
+ fullParams.encoderBeginCallback = null;
+
+ fullParams.prompt_tokens = IntPtr.Zero;
+ fullParams.prompt_n_tokens = 0;
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/ExtensionMethods.cs b/WhisperNet/ExtensionMethods.cs
new file mode 100644
index 0000000..4380ece
--- /dev/null
+++ b/WhisperNet/ExtensionMethods.cs
@@ -0,0 +1,69 @@
+using System.Runtime.InteropServices;
+using Whisper.Internal;
+
+namespace Whisper
+{
+ /// <summary>Extension methods of these COM interfaces</summary>
+ public static class ExtensionMethods
+ {
+ /// <summary>Create a context to transcribe audio with this model</summary>
+ public static Context createContext( this iModel model )
+ {
+ iContext ctx = model.createContextInternal();
+ return new Context( ctx );
+ }
+
+ /// <summary>Convert language into a short ID string, like <c>"en"</c></summary>
+ public static string getCode( this eLanguage lang )
+ {
+ unsafe
+ {
+ sbyte* ptr = stackalloc sbyte[ 5 ];
+ *(uint*)ptr = (uint)lang;
+ ptr[ 4 ] = 0;
+ return new string( ptr );
+ }
+ }
+
+ /// <summary>Resolve integer token ID into string.</summary>
+ /// <remarks>If the token ID was not found in the model, the method returns null without raising exceptions.</remarks>
+ public static string? stringFromToken( this iModel model, int idToken ) =>
+ Marshal.PtrToStringUTF8( model.stringFromTokenInternal( idToken ) );
+
+ /// <summary>List capture devices</summary>
+ public static CaptureDeviceId[]? listCaptureDevices( this iMediaFoundation mf )
+ {
+ List<CaptureDeviceId>? list = null;
+
+ pfnFoundCaptureDevices pfn = delegate ( int len, sCaptureDevice[]? arr, IntPtr pv )
+ {
+ try
+ {
+ if( len == 0 || arr == null )
+ return 1;
+
+ list = new List<CaptureDeviceId>( len );
+ foreach( var i in arr )
+ list.Add( new CaptureDeviceId( i ) );
+ return 0;
+ }
+ catch( Exception ex )
+ {
+ NativeLogger.captureException( ex );
+ return ex.HResult;
+ }
+ };
+
+ mf.listCaptureDevices( pfn, IntPtr.Zero );
+
+ return list?.ToArray();
+ }
+
+ /// <summary>Open audio capture device</summary>
+ public static iAudioCapture openCaptureDevice( this iMediaFoundation mf, in CaptureDeviceId id, sCaptureParams? cp = null )
+ {
+ sCaptureParams captureParams = cp ?? new sCaptureParams();
+ return mf.openCaptureDevice( id.endpoint, ref captureParams );
+ }
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Internal/AssemblyInfo.cs b/WhisperNet/Internal/AssemblyInfo.cs
new file mode 100644
index 0000000..29ec638
--- /dev/null
+++ b/WhisperNet/Internal/AssemblyInfo.cs
@@ -0,0 +1,8 @@
+using System.Reflection;
+using System.Runtime.InteropServices;
+[assembly: AssemblyTitle( "WhisperNet" )]
+[assembly: AssemblyCopyright( "Copyright © const.me, 2022" )]
+[assembly: ComVisible( false )]
+[assembly: Guid( "ced6cdb7-e040-4398-bae8-3417e5fa35f1" )]
+[assembly: AssemblyVersion( "1.0.0.0" )]
+[assembly: AssemblyDescription( "DirectCompute port of whisper.cpp library, C# bindings" )] \ No newline at end of file
diff --git a/WhisperNet/Internal/NativeLogger.cs b/WhisperNet/Internal/NativeLogger.cs
new file mode 100644
index 0000000..b4b4eb2
--- /dev/null
+++ b/WhisperNet/Internal/NativeLogger.cs
@@ -0,0 +1,138 @@
+using System.Runtime.CompilerServices;
+using System.Runtime.ExceptionServices;
+using System.Runtime.InteropServices;
+
+namespace Whisper.Internal
+{
+ /// <summary>Utility class to supply logging function pointer to the C++ library,<br/>
+ /// and provide custom calling conventions to ComLight runtime to convert error messages printed in C++ into .NET exception messages</summary>
+ public static class NativeLogger
+ {
+ internal static void startup() { }
+
+ static NativeLogger()
+ {
+ sink = logSink;
+ sLoggerSetup setup = default;
+ setup.sink = sink;
+ setup.level = eLogLevel.Warning;
+ Library.setupLogger( ref setup );
+ }
+
+ internal static void setup( eLogLevel lvl, eLoggerFlags flags, pfnLogMessage? pfn )
+ {
+ logMessage = pfn;
+
+ sLoggerSetup setup = default;
+ setup.sink = sink;
+ setup.level = lvl;
+ setup.flags = flags;
+ Library.setupLogger( ref setup );
+ }
+
+ // This field is here to protect the function pointer from being collected by the GC
+ static readonly pfnLoggerSink sink;
+
+ static void logSink( IntPtr context, eLogLevel lvl, string message )
+ {
+ if( lvl == eLogLevel.Error )
+ state.setText( message );
+ logMessage?.Invoke( lvl, message );
+ }
+
+ sealed class ThreadState
+ {
+ string? errorText = null;
+ ExceptionDispatchInfo? dispatchInfo = null;
+
+ public void setText( string text ) => errorText = text;
+ public void capture( Exception ex ) => dispatchInfo = ExceptionDispatchInfo.Capture( ex );
+
+ public void clear()
+ {
+ errorText = null;
+ dispatchInfo = null;
+ }
+
+ public void Deconstruct( out string? text, out ExceptionDispatchInfo? edi )
+ {
+ text = errorText;
+ edi = dispatchInfo;
+ errorText = null;
+ dispatchInfo = null;
+ }
+ }
+
+ [ThreadStatic]
+ static ThreadState state = new ThreadState();
+
+ internal static void captureException( Exception ex ) =>
+ state.capture( ex );
+
+ static pfnLogMessage? logMessage = null;
+
+ /// <summary>Called internally by ComLight runtime</summary>
+ [MethodImpl( MethodImplOptions.AggressiveInlining )]
+ public static void prologue()
+ {
+ // https://stackoverflow.com/a/2043505/126995
+ if( null != state )
+ state.clear();
+ else
+ createState();
+ }
+
+ [MethodImpl( MethodImplOptions.NoInlining )]
+ static void createState()
+ {
+ state = new ThreadState();
+ }
+
+ /// <summary>Epilogue implementation for unsuccessful status codes</summary>
+ [MethodImpl( MethodImplOptions.NoInlining )]
+ static void throwException( int hr )
+ {
+ // Move state from the thread local object into local variables, and clear that object
+ (string? text, ExceptionDispatchInfo? edi) = state;
+
+ if( null != edi && edi.SourceException.HResult == hr )
+ {
+ // The error comes from a callback, and we have original context of that exception.
+ // Re-throw the original exception.
+ // This uses the original error message, and even correctly deals with the stack trace.
+ edi.Throw();
+ }
+
+ if( null != text )
+ {
+ // C++ code has printed an error on the current thread, between prologue and epilogue.
+ // Use that text for the exception message.
+ Exception? ex = Marshal.GetExceptionForHR( hr );
+ throw new ApplicationException( text, ex );
+ }
+
+ // We don’t have any additional info about the exception.
+ // Throw an exception from just the HRESULT code.
+ Marshal.ThrowExceptionForHR( hr );
+ }
+
+ /// <summary>Called internally by ComLight runtime</summary>
+ [MethodImpl( MethodImplOptions.AggressiveInlining )]
+ public static void throwForHR( int hr )
+ {
+ if( hr >= 0 )
+ return; // SUCCEEDED
+ throwException( hr );
+ }
+
+ /// <summary>Called internally by ComLight runtime</summary>
+ [MethodImpl( MethodImplOptions.AggressiveInlining )]
+ public static bool throwAndReturnBool( int hr )
+ {
+ if( hr >= 0 )
+ return 0 == hr;
+ throwException( hr );
+ return false;
+ }
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Internal/iContext.cs b/WhisperNet/Internal/iContext.cs
new file mode 100644
index 0000000..6adf8c5
--- /dev/null
+++ b/WhisperNet/Internal/iContext.cs
@@ -0,0 +1,37 @@
+using ComLight;
+using System.Runtime.InteropServices;
+using Whisper.Internals;
+
+namespace Whisper.Internal
+{
+ /// <summary>Stateful context, contains methods to transcribe audio</summary>
+ [ComInterface( "b9956374-3b18-4943-90f2-2ab18a404537", eMarshalDirection.ToManaged ), CustomConventions( typeof( NativeLogger ) )]
+ public interface iContext: IDisposable
+ {
+ /// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary>
+ void runFull( [In] ref sFullParams @params, iAudioBuffer buffer );
+
+ /// <summary>Run the entire model, streaming audio from the provided reader object</summary>
+ void runStreamed( [In] ref sFullParams @params, [In] ref sProgressSink progressSink, iAudioReader reader );
+
+ /// <summary>Continuously process audio from microphone or a similar capture device</summary>
+ void runCapture( [In] ref sFullParams @params, [In] ref sCaptureCallbacks callbacks, iAudioCapture reader );
+
+ /// <summary>Get text results out of the context</summary>
+ [RetValIndex( 1 )]
+ iTranscribeResult getResults( eResultFlags flags );
+
+ /// <summary>Get the model which was used to create this context</summary>
+ [RetValIndex]
+ iModel getModel();
+
+ /// <summary>Full the default parameters of the model, for the specified sampling strategy</summary>
+ [RetValIndex( 1 )]
+ sFullParams fullDefaultParams( eSamplingStrategy strategy );
+
+ /// <summary>Print timing data</summary>
+ void timingsPrint();
+ /// <summary>Reset timing data</summary>
+ void timingsReset();
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Internal/iTranscribeResult.cs b/WhisperNet/Internal/iTranscribeResult.cs
new file mode 100644
index 0000000..cbf49dd
--- /dev/null
+++ b/WhisperNet/Internal/iTranscribeResult.cs
@@ -0,0 +1,122 @@
+#pragma warning disable CS0649 // Field is never assigned to
+using ComLight;
+using System.ComponentModel;
+using System.Runtime.InteropServices;
+
+namespace Whisper.Internal
+{
+ /// <summary>Size of the buffers owned by the <see cref="iTranscribeResult" /> object</summary>
+ public readonly struct sTranscribeLength
+ {
+ /// <summary>Count of segments</summary>
+ public readonly int countSegments;
+ /// <summary>Total count of tokens, for all segments combined</summary>
+ public readonly int countTokens;
+ }
+
+ /// <summary>Output data from the model</summary>
+ [ComInterface( "2871a73f-5ce3-48f8-8779-6582ee11935e", eMarshalDirection.ToManaged ), CustomConventions( typeof( NativeLogger ) )]
+ public interface iTranscribeResult
+ {
+ /// <summary>Get size of the buffers</summary>
+ [RetValIndex, EditorBrowsable( EditorBrowsableState.Never )]
+ public sTranscribeLength getSize();
+
+ /// <summary>Pointer to segment data, a vector of <see cref="sSegment" /> structures</summary>
+ [EditorBrowsable( EditorBrowsableState.Never )]
+ public IntPtr getSegments();
+
+ /// <summary>Pointer to tokens data, a vector of <see cref="sToken" /> structures</summary>
+ [EditorBrowsable( EditorBrowsableState.Never )]
+ public IntPtr getTokens();
+ }
+}
+
+namespace Whisper
+{
+ /// <summary>Start and end times of a segment or token</summary>
+ /// <remarks>The times are relative to the start of the media</remarks>
+ public readonly struct sTimeInterval
+ {
+ /// <summary>Start time</summary>
+ public readonly TimeSpan begin;
+ /// <summary>End time</summary>
+ public readonly TimeSpan end;
+ }
+
+ /// <summary>Segment data</summary>
+ public readonly struct sSegment
+ {
+ internal readonly IntPtr m_text;
+ /// <summary>Segment text</summary>
+ public string? text => Marshal.PtrToStringUTF8( m_text );
+ /// <summary>Start and end times of the segment</summary>
+ public readonly sTimeInterval time;
+ /// <summary>Slice of the tokens</summary>
+ public readonly int firstToken, countTokens;
+ }
+
+ /// <summary>Token flags</summary>
+ [Flags]
+ public enum eTokenFlags: uint
+ {
+ /// <summary>The token is special</summary>
+ Special = 1,
+ }
+
+ /// <summary>Token data</summary>
+ public readonly struct sToken
+ {
+ internal readonly IntPtr m_text;
+ /// <summary>Token text</summary>
+ public string? text => Marshal.PtrToStringUTF8( m_text );
+ /// <summary>Start and end times of the token</summary>
+ public readonly sTimeInterval time;
+ /// <summary>Probability of the token</summary>
+ public readonly float probability;
+ /// <summary>Probability of the timestamp token</summary>
+ public readonly float probabilityTimestamp;
+ /// <summary>Sum of probabilities of all timestamp tokens</summary>
+ public readonly float ptsum;
+ /// <summary>Voice length of the token</summary>
+ public readonly float vlen;
+ /// <summary>Token id</summary>
+ public readonly int id;
+ /// <summary>Token flags</summary>
+ readonly eTokenFlags flags;
+ /// <summary>True if the token flags has the specified bit set</summary>
+ public bool hasFlag( eTokenFlags bit ) => flags.HasFlag( bit );
+ }
+
+ /// <summary>Output data from the model</summary>
+ public readonly ref struct TranscribeResult
+ {
+ /// <summary>Segments in the results</summary>
+ public readonly ReadOnlySpan<sSegment> segments;
+ /// <summary>Tokens in the results, for all segments</summary>
+ public readonly ReadOnlySpan<sToken> tokens;
+
+ internal TranscribeResult( Internal.iTranscribeResult i )
+ {
+ Internal.sTranscribeLength len = i.getSize();
+ unsafe
+ {
+ // This does not copy the buffers to managed memory.
+ // Instead, the C# spans directly reference the native memory stored in these std::vectors
+ if( len.countSegments > 0 )
+ segments = new ReadOnlySpan<sSegment>( (void*)i.getSegments(), len.countSegments );
+ else
+ segments = ReadOnlySpan<sSegment>.Empty;
+
+ if( len.countTokens > 0 )
+ tokens = new ReadOnlySpan<sToken>( (void*)i.getTokens(), len.countTokens );
+ else
+ tokens = ReadOnlySpan<sToken>.Empty;
+ }
+ }
+
+ /// <summary>Get tokens for the specified segment</summary>
+ public ReadOnlySpan<sToken> getTokens( in sSegment seg ) =>
+ tokens.Slice( seg.firstToken, seg.countTokens );
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Internal/sCaptureCallbacks.cs b/WhisperNet/Internal/sCaptureCallbacks.cs
new file mode 100644
index 0000000..483c2f2
--- /dev/null
+++ b/WhisperNet/Internal/sCaptureCallbacks.cs
@@ -0,0 +1,23 @@
+using System.Runtime.InteropServices;
+
+namespace Whisper.Internal
+{
+ /// <summary>Unmanaged code calls this to check for cancellation</summary>
+ [UnmanagedFunctionPointer( CallingConvention.StdCall )]
+ public delegate int pfnShouldCancel( IntPtr pv );
+
+ /// <summary>Unmanaged code calls this to notify about the status</summary>
+ [UnmanagedFunctionPointer( CallingConvention.StdCall )]
+ public delegate int pfnCaptureStatus( IntPtr pv, eCaptureStatus status );
+
+ /// <summary>Capture callbacks for unmanaged code</summary>
+ public struct sCaptureCallbacks
+ {
+ /// <summary>Cancellation function pointer</summary>
+ public pfnShouldCancel shouldCancel;
+ /// <summary>Capture status function pointer</summary>
+ public pfnCaptureStatus captureStatus;
+ /// <summary>COntext pointer, only needed for C++ compatibility</summary>
+ public IntPtr pv;
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Internal/sCaptureDevice.cs b/WhisperNet/Internal/sCaptureDevice.cs
new file mode 100644
index 0000000..e2d524d
--- /dev/null
+++ b/WhisperNet/Internal/sCaptureDevice.cs
@@ -0,0 +1,22 @@
+#pragma warning disable CS0649 // Field is never assigned to
+using System.Runtime.InteropServices;
+
+namespace Whisper.Internal
+{
+ /// <summary>Identifiers for an audio capture device</summary>
+ public struct sCaptureDevice
+ {
+ readonly IntPtr m_displayName;
+ /// <summary>The display name is suitable for showing to the user, but might not be unique.</summary>
+ public string? displayName => Marshal.PtrToStringUni( m_displayName );
+
+ readonly IntPtr m_endpoint;
+ /// <summary>Endpoint ID for an audio capture device.<br/>
+ /// It uniquely identifies the device on the system, but is not a readable string.</summary>
+ public string? endpoint => Marshal.PtrToStringUni( m_endpoint );
+ }
+
+ /// <summary>Function pointer to consume a list of audio capture device IDs</summary>
+ [UnmanagedFunctionPointer( CallingConvention.StdCall )]
+ public delegate int pfnFoundCaptureDevices( int len, [In, MarshalAs( UnmanagedType.LPArray, SizeParamIndex = 0 )] sCaptureDevice[]? arr, IntPtr pv );
+} \ No newline at end of file
diff --git a/WhisperNet/Internal/sFullParams.cs b/WhisperNet/Internal/sFullParams.cs
new file mode 100644
index 0000000..7347afe
--- /dev/null
+++ b/WhisperNet/Internal/sFullParams.cs
@@ -0,0 +1,40 @@
+#pragma warning disable CS0649 // Field is never assigned to
+
+// Missing XML comment for publicly visible type or member
+// TODO: remove this line and document them.
+#pragma warning disable CS1591
+
+using System.Runtime.InteropServices;
+
+namespace Whisper.Internals
+{
+ /// <summary>This callback is called on each new segment</summary>
+ [UnmanagedFunctionPointer( CallingConvention.Cdecl )]
+ delegate int pfnNewSegment( IntPtr ctx, int countNew, IntPtr userData );
+
+ /// <summary>The callback is called before every encoder run. If it returns S_FALSE, the processing is aborted.</summary>
+ [UnmanagedFunctionPointer( CallingConvention.Cdecl )]
+ delegate int pfnEncoderBegin( IntPtr ctx, IntPtr userData );
+
+ /// <summary>Transcribe parameters</summary>
+ public struct sFullParams
+ {
+ internal Parameters publicParams;
+ // The rest of these parameters are not exposed to the user-friendly public API of this DLL
+
+ internal IntPtr prompt_tokens;
+ internal int prompt_n_tokens;
+
+ /// <summary>This callback is called on each new segment</summary>
+ [MarshalAs( UnmanagedType.FunctionPtr )]
+ internal pfnNewSegment? newSegmentCallback;
+ /// <summary>Parameter for the above, not needed in C#</summary>
+ internal IntPtr newSegmentCallbackData;
+
+ /// <summary>The callback is called before every encoder run. If it returns false, the processing is aborted</summary>
+ [MarshalAs( UnmanagedType.FunctionPtr )]
+ internal pfnEncoderBegin? encoderBeginCallback;
+ /// <summary>Parameter for the above, not needed in C#</summary>
+ internal IntPtr encoderBeginCallbackData;
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Internal/sLoadModelCallbacks.cs b/WhisperNet/Internal/sLoadModelCallbacks.cs
new file mode 100644
index 0000000..07f5199
--- /dev/null
+++ b/WhisperNet/Internal/sLoadModelCallbacks.cs
@@ -0,0 +1,64 @@
+using System.Runtime.InteropServices;
+
+namespace Whisper.Internal
+{
+ /// <summary>Function pointer to report model loading progress</summary>
+ [UnmanagedFunctionPointer( CallingConvention.StdCall )]
+ delegate int pfnLoadProgress( double progress, IntPtr pv );
+
+ /// <summary>Function pointer to implement cooperative cancellation</summary>
+ [UnmanagedFunctionPointer( CallingConvention.StdCall )]
+ delegate int pfnCancel( IntPtr pv );
+
+ /// <summary>Callback functions for loading models</summary>
+ public struct sLoadModelCallbacks
+ {
+ /// <summary>Function pointer to report model loading progress</summary>
+ [MarshalAs( UnmanagedType.FunctionPtr )]
+ pfnLoadProgress? progress;
+
+ /// <summary>Function pointer to implement cooperative cancellation</summary>
+ [MarshalAs( UnmanagedType.FunctionPtr )]
+ pfnCancel? cancel;
+
+ // Not needed in C#, delegates can capture things
+ IntPtr pv;
+
+ /// <summary>Wrap idiomatic C# things into these low-level C callbacks</summary>
+ internal sLoadModelCallbacks( CancellationToken cancelToken, Action<double>? pfnProgress )
+ {
+ if( cancelToken != CancellationToken.None )
+ {
+ cancel = delegate ( IntPtr pv )
+ {
+ if( cancelToken.IsCancellationRequested )
+ return 1; // S_FALSE
+ return 0; // S_OK
+ };
+ }
+ else
+ cancel = null;
+
+ if( null != pfnProgress )
+ {
+ progress = delegate ( double val, IntPtr pv )
+ {
+ try
+ {
+ pfnProgress( val );
+ return 0; // S_OK
+ }
+ catch( Exception ex )
+ {
+ NativeLogger.captureException( ex );
+ return ex.HResult;
+ }
+ };
+ }
+ else
+ progress = null;
+
+ pv = IntPtr.Zero;
+ }
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Internal/sLoggerSetup.cs b/WhisperNet/Internal/sLoggerSetup.cs
new file mode 100644
index 0000000..ed1baa4
--- /dev/null
+++ b/WhisperNet/Internal/sLoggerSetup.cs
@@ -0,0 +1,16 @@
+using System.Runtime.InteropServices;
+
+namespace Whisper.Internal
+{
+ [UnmanagedFunctionPointer( CallingConvention.StdCall )]
+ delegate void pfnLoggerSink( IntPtr context, eLogLevel lvl, [MarshalAs( UnmanagedType.LPUTF8Str )] string message );
+
+ struct sLoggerSetup
+ {
+ [MarshalAs( UnmanagedType.FunctionPtr )]
+ public pfnLoggerSink sink;
+ IntPtr context;
+ public eLogLevel level;
+ public eLoggerFlags flags;
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Internal/sProgressSink.cs b/WhisperNet/Internal/sProgressSink.cs
new file mode 100644
index 0000000..9155677
--- /dev/null
+++ b/WhisperNet/Internal/sProgressSink.cs
@@ -0,0 +1,20 @@
+#pragma warning disable CS0649 // Field is never assigned to
+using System.Runtime.InteropServices;
+
+namespace Whisper.Internal
+{
+ /// <summary>A callback to get notified about the progress</summary>
+ [UnmanagedFunctionPointer( CallingConvention.StdCall )]
+ delegate int pfnReportProgress( double value, IntPtr context, IntPtr pv );
+
+ /// <summary>C structure with a progress reporting function pointer</summary>
+ public struct sProgressSink
+ {
+ /// <summary>A callback to get notified about the progress</summary>
+ [MarshalAs( UnmanagedType.FunctionPtr )]
+ internal pfnReportProgress? pfn;
+
+ /// <summary>Last parameter to the callback</summary>
+ internal IntPtr pv;
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Library.cs b/WhisperNet/Library.cs
new file mode 100644
index 0000000..72ecb6e
--- /dev/null
+++ b/WhisperNet/Library.cs
@@ -0,0 +1,110 @@
+using ComLight;
+using System.Runtime.InteropServices;
+using System.Runtime.Intrinsics.X86;
+using Whisper.Internal;
+
+namespace Whisper
+{
+ /// <summary>Factory methods implemented by the C++ DLL</summary>
+ public static class Library
+ {
+ static Library()
+ {
+ if( Environment.OSVersion.Platform != PlatformID.Win32NT )
+ throw new ApplicationException( "This library requires Windows OS" );
+ if( !Environment.Is64BitProcess )
+ throw new ApplicationException( "This library only works in 64-bit processes" );
+ if( RuntimeInformation.ProcessArchitecture != Architecture.X64 )
+ throw new ApplicationException( "This library requires a processor with AMD64 instruction set" );
+ if( !Sse41.IsSupported )
+ throw new ApplicationException( "This library requires a CPU with SSE 4.1 support" );
+ NativeLogger.startup();
+ }
+
+ const string dll = "Whisper.dll";
+
+ [DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = false )]
+ internal static extern void setupLogger( [In] ref sLoggerSetup setup );
+
+ [DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = true )]
+ static extern int loadModel( [MarshalAs( UnmanagedType.LPWStr )] string path, eModelImplementation impl,
+ [In] ref sLoadModelCallbacks callbacks,
+ [MarshalAs( UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof( Marshaler<iModel> ) )] out iModel model );
+
+ /// <summary>Load Whisper model from GGML file on disk</summary>
+ /// <remarks>Models are large, depending on user’s disk speed this might take a while, and this function blocks the calling thread.<br/>
+ /// Consider <see cref="loadModelAsync" /> instead.</remarks>
+ /// <seealso href="https://huggingface.co/datasets/ggerganov/whisper.cpp" />
+ public static iModel loadModel( string path, eModelImplementation impl = eModelImplementation.GPU )
+ {
+ iModel model;
+ sLoadModelCallbacks callbacks = default;
+ NativeLogger.prologue();
+ int hr = loadModel( path, impl, ref callbacks, out model );
+ NativeLogger.throwForHR( hr );
+ return model;
+ }
+
+ /// <summary>Load Whisper model on a background thread, with optional progress reporting and cancellation</summary>
+ public static Task<iModel> loadModelAsync( string path, CancellationToken cancelToken, Action<double>? pfnProgress = null, eModelImplementation impl = eModelImplementation.GPU )
+ {
+ TaskCompletionSource<iModel> tcs = new TaskCompletionSource<iModel>();
+
+ WaitCallback wcb = delegate ( object? state )
+ {
+ try
+ {
+ sLoadModelCallbacks callbacks = new sLoadModelCallbacks( cancelToken, pfnProgress );
+
+ iModel model;
+ NativeLogger.prologue();
+ int hr = loadModel( path, impl, ref callbacks, out model );
+ NativeLogger.throwForHR( hr );
+
+ tcs.SetResult( model );
+ }
+ catch( Exception ex )
+ {
+ tcs.SetException( ex );
+ }
+ };
+
+ ThreadPool.QueueUserWorkItem( wcb );
+ return tcs.Task;
+ }
+
+ [DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = true )]
+ static extern int initMediaFoundation( [MarshalAs( UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof( Marshaler<iMediaFoundation> ) )] out iMediaFoundation mf );
+
+ /// <summary>Initialize Media Foundation runtime</summary>
+ public static iMediaFoundation initMediaFoundation()
+ {
+ iMediaFoundation mf;
+ NativeLogger.prologue();
+ int hr = initMediaFoundation( out mf );
+ NativeLogger.throwForHR( hr );
+ return mf;
+ }
+
+ // The .NET runtime uses UTF-16 for the strings, so we only need the Unicode version of this function.
+ // The native DLL exports both Unicode and ASCII versions.
+ [DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = true )]
+ static extern uint findLanguageKeyW( [MarshalAs( UnmanagedType.LPWStr )] string lang );
+
+ /// <summary>Try to resolve language code string like <c>"en"</c>, <c>"pl"</c> or <c>"uk"</c> into the strongly-typed enum.</summary>
+ /// <remarks>The function is case-sensitive, <c>"EN"</c> or <c>"UK"</c> gonna fail.</remarks>
+ public static eLanguage? languageFromCode( string lang )
+ {
+ uint key = findLanguageKeyW( lang );
+ if( key != uint.MaxValue )
+ return (eLanguage)key;
+ return null;
+ }
+
+ /// <summary>Set up delegate to receive log messages from the C++ library</summary>
+ public static void setLogSink( eLogLevel lvl, eLoggerFlags flags = eLoggerFlags.SkipFormatMessage, pfnLogMessage? pfn = null )
+ {
+ NativeLogger.setup( lvl, flags, pfn );
+ }
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/Readme.txt b/WhisperNet/Readme.txt
new file mode 100644
index 0000000..50b3bb3
--- /dev/null
+++ b/WhisperNet/Readme.txt
@@ -0,0 +1 @@
+This project builds .NET DLL which wraps Whisper.dll into idiomatic C# API. \ No newline at end of file
diff --git a/WhisperNet/WhisperNet.csproj b/WhisperNet/WhisperNet.csproj
new file mode 100644
index 0000000..105aa44
--- /dev/null
+++ b/WhisperNet/WhisperNet.csproj
@@ -0,0 +1,24 @@
+<Project Sdk="Microsoft.NET.Sdk">
+ <PropertyGroup>
+ <TargetFramework>net6.0-windows</TargetFramework>
+ <ImplicitUsings>enable</ImplicitUsings>
+ <Nullable>enable</Nullable>
+ <CheckForOverflowUnderflow>true</CheckForOverflowUnderflow>
+ <AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath>
+ <GenerateDocumentationFile>True</GenerateDocumentationFile>
+ <AllowUnsafeBlocks>True</AllowUnsafeBlocks>
+ <RootNamespace>Whisper</RootNamespace>
+ <GenerateAssemblyInfo>false</GenerateAssemblyInfo>
+ <PlatformTarget>x64</PlatformTarget>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)'=='Release'">
+ <GeneratePackageOnBuild>True</GeneratePackageOnBuild>
+ <NuspecFile>WhisperNet.nuspec</NuspecFile>
+ </PropertyGroup>
+ <ItemGroup>
+ <Content Include="..\x64\Release\Whisper.dll" Link="Whisper.dll" />
+ </ItemGroup>
+ <ItemGroup>
+ <PackageReference Include="ComLightInterop" Version="1.3.7" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/WhisperNet/WhisperNet.nuspec b/WhisperNet/WhisperNet.nuspec
new file mode 100644
index 0000000..d0a61f7
--- /dev/null
+++ b/WhisperNet/WhisperNet.nuspec
@@ -0,0 +1,27 @@
+<?xml version="1.0" encoding="utf-8"?>
+<package xmlns="http://schemas.microsoft.com/packaging/2013/05/nuspec.xsd">
+ <metadata>
+ <id>WhisperNet</id>
+ <version>1.0</version>
+ <authors>Konstantin, const.me</authors>
+ <license type="expression">MPL-2.0</license>
+ <projectUrl>https://github.com/Const-me/Whisper</projectUrl>
+ <description>High-performance GPGPU inference of OpenAI's Whisper automatic speech recognition (ASR) model</description>
+ <releaseNotes>Initial public version</releaseNotes>
+ <copyright>Copyright © const.me, 2022-2023</copyright>
+ <tags>whisper, gpgpu, speech recognition</tags>
+ <repository type="git" url="https://github.com/Const-me/Whisper.git" />
+ <dependencies>
+ <group targetFramework="net6.0">
+ <dependency id="ComLightInterop" version="1.3.7" />
+ </group>
+ </dependencies>
+ </metadata>
+ <files>
+ <!-- Managed DLL with XML documentation -->
+ <file src="bin/Release/WhisperNet.dll" target="lib/net6.0/" />
+ <file src="bin/Release/WhisperNet.xml" target="lib/net6.0/" />
+ <!-- The C++ DLL -->
+ <file src="../x64/Release/Whisper.dll" target="runtimes/win-x64/native/" />
+ </files>
+</package> \ No newline at end of file