@@ -1117,40 +1117,43 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
1117
1117
}
1118
1118
1119
1119
// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
1120
- namespace {
1121
-
1122
- static std::unordered_map<int , void *> g_nz_workspace_map;
1123
- static std::unordered_map<int , size_t > g_nz_workspace_allocated_map;
1124
-
1125
- void release_nz_workspace (int device) {
1126
- auto it = g_nz_workspace_map.find (device);
1127
- if (it != g_nz_workspace_map.end () && it->second ) {
1128
- aclrtFree (it->second );
1129
- g_nz_workspace_map.erase (it);
1130
- g_nz_workspace_allocated_map.erase (device);
1120
+ class NzWorkspace {
1121
+ public:
1122
+ NzWorkspace () : ptr_(nullptr ), allocated_(0 ) {}
1123
+
1124
+ // 初始化 / 重置为无效
1125
+ void init () {
1126
+ if (ptr_) {
1127
+ aclrtFree (ptr_);
1128
+ ptr_ = nullptr ;
1129
+ allocated_ = 0 ;
1131
1130
}
1132
1131
}
1133
1132
1134
- void relloc_nz_workspace (int device, size_t new_size) {
1135
- void * &workspace = g_nz_workspace_map[device];
1136
- size_t &allocated = g_nz_workspace_allocated_map[device];
1137
-
1138
- if (new_size > allocated) {
1139
- if (workspace) {
1140
- aclrtFree (workspace);
1141
- workspace = nullptr ;
1142
- }
1143
- ACL_CHECK (aclrtMalloc (&workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1144
- allocated = new_size;
1133
+ void realloc (size_t new_size) {
1134
+ if (new_size > allocated_) {
1135
+ init ();
1136
+ ACL_CHECK (aclrtMalloc (&ptr_, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1137
+ allocated_ = new_size;
1145
1138
}
1146
1139
}
1147
1140
1148
- void * get_nz_workspace (int device) {
1149
- auto it = g_nz_workspace_map.find (device);
1150
- return (it != g_nz_workspace_map.end ()) ? it->second : nullptr ;
1141
+ void * get () const { return ptr_; }
1142
+
1143
+ private:
1144
+ void * ptr_;
1145
+ size_t allocated_;
1146
+ };
1147
+
1148
+ static std::array<NzWorkspace, GGML_CANN_MAX_DEVICES> g_nz_workspaces;
1149
+
1150
+ inline NzWorkspace& get_workspace (int device) {
1151
+ if (device < 0 || device >= static_cast <int >(g_nz_workspaces.size ())) {
1152
+ throw std::out_of_range (" device id out of range" );
1151
1153
}
1154
+ return g_nz_workspaces[device];
1155
+ }
1152
1156
1153
- } // namespace
1154
1157
1155
1158
/* *
1156
1159
* @brief Convert tensor weights to NZ format using Ascend CANN API.
@@ -1176,9 +1179,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device)
1176
1179
ACL_CHECK (aclnnTransMatmulWeightGetWorkspaceSize (weightTransposed,
1177
1180
&workspaceSize, &executor));
1178
1181
// Avoid frequent malloc/free of the workspace.
1179
- relloc_nz_workspace (device, workspaceSize);
1182
+ get_workspace (device). realloc ( workspaceSize);
1180
1183
1181
- void * g_nz_workspace = get_nz_workspace (device);
1184
+ void * g_nz_workspace = get_workspace (device). get ( );
1182
1185
1183
1186
ACL_CHECK (aclnnTransMatmulWeight (g_nz_workspace, workspaceSize, executor, nullptr ));
1184
1187
ACL_CHECK (aclDestroyTensor (weightTransposed));
@@ -2259,7 +2262,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
2259
2262
ggml_backend_cann_context* cann_ctx =
2260
2263
(ggml_backend_cann_context*)backend->context ;
2261
2264
ggml_cann_set_device (cann_ctx->device );
2262
- release_nz_workspace (cann_ctx->device );
2265
+ get_workspace (cann_ctx->device ). init ( );
2263
2266
2264
2267
#ifdef USE_ACL_GRAPH
2265
2268
bool use_cann_graph = true ;
0 commit comments