From 5dbe2b08ec6918ef4d6a561eb9c39abf4a96808f Mon Sep 17 00:00:00 2001 From: xnpster Date: Sat, 31 Jan 2026 16:57:47 +0300 Subject: [PATCH] region merging: derive array types --- .../parse_merge_dirs.cpp | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/ParallelizationRegions/parse_merge_dirs.cpp b/src/ParallelizationRegions/parse_merge_dirs.cpp index 1701427..512a8a2 100644 --- a/src/ParallelizationRegions/parse_merge_dirs.cpp +++ b/src/ParallelizationRegions/parse_merge_dirs.cpp @@ -190,12 +190,61 @@ static pair, SgSymbol *> generateDeclaration(const string return {{decl, comm}, array_symbol}; } +static SgExpression* findExprWithVariant(SgExpression* exp, int variant) +{ + if (exp) + { + if (exp->variant() == variant) + return exp; + + auto *l = findExprWithVariant(exp->lhs(), variant); + if (l) + return l; + + auto *r = findExprWithVariant(exp->rhs(), variant); + if (r) + return r; + } + + return NULL; +} + +SgType* GetArrayType(DIST::Array *array) +{ + if (!array) + return NULL; + + for (const auto& decl_place : array->GetDeclInfo()) + { + if (SgFile::switchToFile(decl_place.first) != -1) + { + auto* decl = SgStatement::getStatementByFileAndLine(decl_place.first, decl_place.second); + if (decl) + { + for (int i = 0; i < 3; i++) + { + auto* found_type = isSgTypeExp(findExprWithVariant(decl->expr(i), TYPE_OP)); + if (found_type) + return found_type->type(); + } + } + } + } + + return NULL; +} + SgSymbol *insertDeclIfNeeded(const string &array_name, const string &common_block_name, DIST::Array *example_array, FuncInfo *dest, unordered_map> &inserted_arrays) { + auto *type = GetArrayType(example_array); + + if (!type) + printInternalError(convertFileName(__FILE__).c_str(), __LINE__); + if (SgFile::switchToFile(dest->fileName) == -1) printInternalError(convertFileName(__FILE__).c_str(), __LINE__); @@ -218,7 +267,7 @@ SgSymbol *insertDeclIfNeeded(const string &array_name, auto generated = generateDeclaration(array_name, common_block_name, example_array->GetSizes(), - SgTypeInt(), dest->funcPointer); + type, dest->funcPointer); for (auto *new_stmt : generated.first) st->insertStmtBefore(*new_stmt, *dest->funcPointer);