-
Notifications
You must be signed in to change notification settings - Fork 547
Expand file tree
/
Copy pathWeightsContextMemoryMap.cpp
More file actions
166 lines (136 loc) · 3.75 KB
/
Copy pathWeightsContextMemoryMap.cpp
File metadata and controls
166 lines (136 loc) · 3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "WeightsContext.hpp"
#include <fstream>
#ifdef _WIN32
#include <windows.h>
#else
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#endif
namespace onnx2trt
{
namespace
{
//! Measures a file on disk by opening it as a binary stream and seeking to its end.
//! \param file Path of the file to measure.
//! \return The size of the file in bytes, or -1 if the file cannot be opened (for example because it
//!         does not exist) or its end position cannot be determined.
[[nodiscard]] int64_t getFileSize(std::string const& file)
{
    std::ifstream input(file, std::ios::binary);
    if (input.is_open())
    {
        input.seekg(0, std::ios::end);
        // tellg() yields -1 if the seek above failed, which doubles as the error sentinel.
        std::streamoff const endOffset = input.tellg();
        return static_cast<int64_t>(endOffset);
    }
    return -1L;
}
} // namespace
#ifdef _WIN32
//! Maps \p file into memory read-only and returns its base address and size in bytes.
//!
//! Repeated calls for the same path return the cached mapping. On success the file handle, the
//! file-mapping object and the view are recorded so that clearMemoryMappings() can release them.
//! \param file Path of the file to map.
//! \return {address, size} on success, or {nullptr, -1} after logging an error on failure.
//! \note CreateFileMapping rejects zero-length files, so an empty file is reported as a mapping failure.
WeightsContext::MemoryMapping_t WeightsContext::mmap(std::string const& file)
{
    auto* ctx = this; // For logging macros.
    if (auto it = mMemoryMappings.find(file); it != mMemoryMappings.end())
    {
        return it->second;
    }

    int64_t const fileSize = getFileSize(file);
    if (fileSize < 0L)
    {
        LOG_ERROR("Failed to open file: " << file);
        return {nullptr, -1L};
    }

    FileHandle fd = CreateFileA(
        file.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr);
    if (fd == INVALID_HANDLE_VALUE)
    {
        LOG_ERROR("Failed to open file: " << file);
        return {nullptr, -1L};
    }

    // Unlike CreateFile, CreateFileMapping signals failure by returning NULL, not INVALID_HANDLE_VALUE.
    FileHandle mappingHandle = CreateFileMapping(fd, nullptr, PAGE_READONLY, 0U, 0U, nullptr);
    if (mappingHandle == nullptr)
    {
        LOG_ERROR("Failed to map file to memory: " << file);
        CloseHandle(fd);
        return {nullptr, -1L};
    }

    // A zero offset and zero length map the view over the entire file-mapping object.
    void* const mappedAddr = MapViewOfFile(mappingHandle, FILE_MAP_READ, 0U, 0U, 0U);
    if (mappedAddr == nullptr)
    {
        LOG_ERROR("Failed to map file to memory: " << file);
        CloseHandle(mappingHandle);
        CloseHandle(fd);
        return {nullptr, -1L};
    }

    mMappedFiles[file] = fd;
    mFileMappingHandles[file] = mappingHandle;
    auto const inserted = mMemoryMappings.insert_or_assign(file, MemoryMapping_t{mappedAddr, fileSize}).first;
    return inserted->second;
}
//! Releases every Win32 resource recorded by mmap() and forgets all bookkeeping entries.
//! Views are unmapped first, then the file-mapping objects are closed, and finally the file handles.
void WeightsContext::clearMemoryMappings()
{
    for (auto const& entry : mMemoryMappings)
    {
        UnmapViewOfFile(entry.second.first);
    }
    for (auto const& entry : mFileMappingHandles)
    {
        CloseHandle(entry.second);
    }
    for (auto const& entry : mMappedFiles)
    {
        CloseHandle(entry.second);
    }
    mMemoryMappings.clear();
    mFileMappingHandles.clear();
    mMappedFiles.clear();
}
#else
//! Maps \p file into memory as a private, read-only mapping and returns its base address and size in bytes.
//!
//! Repeated calls for the same path return the cached mapping. On success the file descriptor and the
//! mapping are recorded so that clearMemoryMappings() can release them.
//! \param file Path of the file to map.
//! \return {address, size} on success, or {nullptr, -1} after logging an error on failure.
//! \note ::mmap rejects a zero length, so an empty file is reported as a mapping failure.
WeightsContext::MemoryMapping_t WeightsContext::mmap(std::string const& file)
{
    auto* const ctx = this; // For logging macros.
    if (auto it = mMemoryMappings.find(file); it != mMemoryMappings.end())
    {
        return it->second;
    }

    int64_t const fileSize = getFileSize(file);
    if (fileSize < 0L)
    {
        LOG_ERROR("Failed to open file: " << file);
        return {nullptr, -1L};
    }

    // ::mmap takes a size_t length. Where size_t is narrower than int64_t (32-bit targets), a large file
    // would otherwise be silently truncated while the full size is still reported to callers.
    auto const mapLength = static_cast<size_t>(fileSize);
    if (static_cast<int64_t>(mapLength) != fileSize)
    {
        LOG_ERROR("File is too large to map into memory: " << file);
        return {nullptr, -1L};
    }

    FileHandle fd = open(file.c_str(), O_RDONLY);
    if (fd == -1)
    {
        LOG_ERROR("Failed to open file: " << file);
        return {nullptr, -1L};
    }

    void* const mappedAddr = ::mmap(nullptr, mapLength, PROT_READ, MAP_PRIVATE, fd, 0);
    if (mappedAddr == MAP_FAILED)
    {
        LOG_ERROR("Failed to map file to memory: " << file);
        close(fd);
        return {nullptr, -1L};
    }

    mMappedFiles[file] = fd;
    auto const inserted = mMemoryMappings.insert_or_assign(file, MemoryMapping_t{mappedAddr, fileSize}).first;
    return inserted->second;
}
//! Unmaps every region created by mmap(), closes the file descriptors it opened, and forgets all entries.
void WeightsContext::clearMemoryMappings()
{
    for (auto const& entry : mMemoryMappings)
    {
        // Each cached value holds the base address and the mapped length in bytes.
        ::munmap(entry.second.first, static_cast<size_t>(entry.second.second));
    }
    for (auto const& entry : mMappedFiles)
    {
        close(entry.second);
    }
    mMemoryMappings.clear();
    mMappedFiles.clear();
}
#endif
} // namespace onnx2trt