Skip to content

Commit c44cc66

Browse files
committed
refactory without namespace
Signed-off-by: noemotiovon <[email protected]>
1 parent 885bd4c commit c44cc66

File tree

1 file changed

+32
-29
lines changed

1 file changed

+32
-29
lines changed

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,40 +1117,43 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
11171117
}
11181118

11191119
// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
1120-
namespace {
1121-
1122-
static std::unordered_map<int, void*> g_nz_workspace_map;
1123-
static std::unordered_map<int, size_t> g_nz_workspace_allocated_map;
1124-
1125-
void release_nz_workspace(int device) {
1126-
auto it = g_nz_workspace_map.find(device);
1127-
if (it != g_nz_workspace_map.end() && it->second) {
1128-
aclrtFree(it->second);
1129-
g_nz_workspace_map.erase(it);
1130-
g_nz_workspace_allocated_map.erase(device);
1120+
class NzWorkspace {
1121+
public:
1122+
NzWorkspace() : ptr_(nullptr), allocated_(0) {}
1123+
1124+
// 初始化 / 重置为无效
1125+
void init() {
1126+
if (ptr_) {
1127+
aclrtFree(ptr_);
1128+
ptr_ = nullptr;
1129+
allocated_ = 0;
11311130
}
11321131
}
11331132

1134-
void relloc_nz_workspace(int device, size_t new_size) {
1135-
void* &workspace = g_nz_workspace_map[device];
1136-
size_t &allocated = g_nz_workspace_allocated_map[device];
1137-
1138-
if (new_size > allocated) {
1139-
if (workspace) {
1140-
aclrtFree(workspace);
1141-
workspace = nullptr;
1142-
}
1143-
ACL_CHECK(aclrtMalloc(&workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1144-
allocated = new_size;
1133+
void realloc(size_t new_size) {
1134+
if (new_size > allocated_) {
1135+
init();
1136+
ACL_CHECK(aclrtMalloc(&ptr_, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1137+
allocated_ = new_size;
11451138
}
11461139
}
11471140

1148-
void* get_nz_workspace(int device) {
1149-
auto it = g_nz_workspace_map.find(device);
1150-
return (it != g_nz_workspace_map.end()) ? it->second : nullptr;
1141+
void* get() const { return ptr_; }
1142+
1143+
private:
1144+
void* ptr_;
1145+
size_t allocated_;
1146+
};
1147+
1148+
static std::array<NzWorkspace, GGML_CANN_MAX_DEVICES> g_nz_workspaces;
1149+
1150+
inline NzWorkspace& get_workspace(int device) {
1151+
if (device < 0 || device >= static_cast<int>(g_nz_workspaces.size())) {
1152+
throw std::out_of_range("device id out of range");
11511153
}
1154+
return g_nz_workspaces[device];
1155+
}
11521156

1153-
} // namespace
11541157

11551158
/**
11561159
* @brief Convert tensor weights to NZ format using Ascend CANN API.
@@ -1176,9 +1179,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device)
11761179
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
11771180
&workspaceSize, &executor));
11781181
// Avoid frequent malloc/free of the workspace.
1179-
relloc_nz_workspace(device, workspaceSize);
1182+
get_workspace(device).realloc(workspaceSize);
11801183

1181-
void* g_nz_workspace = get_nz_workspace(device);
1184+
void* g_nz_workspace = get_workspace(device).get();
11821185

11831186
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
11841187
ACL_CHECK(aclDestroyTensor(weightTransposed));
@@ -2259,7 +2262,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
22592262
ggml_backend_cann_context* cann_ctx =
22602263
(ggml_backend_cann_context*)backend->context;
22612264
ggml_cann_set_device(cann_ctx->device);
2262-
release_nz_workspace(cann_ctx->device);
2265+
get_workspace(cann_ctx->device).init();
22632266

22642267
#ifdef USE_ACL_GRAPH
22652268
bool use_cann_graph = true;

0 commit comments

Comments
 (0)