From 8752f4a13988392aa6baf115756d0c3bfbeda12d Mon Sep 17 00:00:00 2001 From: xnpster Date: Wed, 1 Oct 2025 18:54:55 +0300 Subject: [PATCH] REMOVE_DIST_ARRAYS_FROM_IO: consider labels and goto statements while inserting copy statements --- .../replace_dist_arrays_in_io.cpp | 38 ++++++++++++++----- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/src/Transformations/ReplaceArraysInIO/replace_dist_arrays_in_io.cpp b/src/Transformations/ReplaceArraysInIO/replace_dist_arrays_in_io.cpp index c4da117..021e797 100644 --- a/src/Transformations/ReplaceArraysInIO/replace_dist_arrays_in_io.cpp +++ b/src/Transformations/ReplaceArraysInIO/replace_dist_arrays_in_io.cpp @@ -172,7 +172,7 @@ static void findArrays(SgExpression* exp, set& arrays) } } -static void populateDistributedIoArrays(map>& arrays, +static bool populateDistributedIoArrays(map>& arrays, SgStatement* stat, const string& current_file_name, FuncInfo *current_func) @@ -180,7 +180,7 @@ static void populateDistributedIoArrays(map>& array auto var = stat->variant(); if (var != READ_STAT && var != PRINT_STAT && var != WRITE_STAT) - return; + return false; // check if such IO allowed in dvm: // list should consist only of single array and format string should be * @@ -190,19 +190,19 @@ static void populateDistributedIoArrays(map>& array SgExpression* ioList = stat->expr(0); if (!ioList) - return; + return false; if (ioList->variant() != EXPR_LIST) - return; + return false; if (ioList->rhs() == NULL) { SgExpression* arg = ioList->lhs(); if (!arg) - return; + return false; if (!isArrayRef(arg)) - return; + return false; if (arg->lhs()) need_replace = true; @@ -225,7 +225,6 @@ static void populateDistributedIoArrays(map>& array if (fmt->rhs()->variant() != KEYWORD_VAL || fmt->rhs()->sunparse() != "*") need_replace = true; - break; } case READ_STAT: @@ -266,7 +265,9 @@ static void populateDistributedIoArrays(map>& array } if (!need_replace) - return; + return false; + + bool ret = false; set found_arrays; @@ -285,10 +286,13 @@ static void populateDistributedIoArrays(map>& array if (inserted) __spf_print(DEBUG_TRACE, "[%d]: add array %s %p\n", stat->lineNumber(), array_p->GetName().c_str(), by_symb); + + ret = true; } } __spf_print(DEBUG_TRACE, "[replace]\n"); + return ret; } static void replaceArrayRec(SgSymbol* arr, SgSymbol* replace_by, SgExpression* exp, bool& has_read, bool& has_write, bool from_read, bool from_write) @@ -506,12 +510,16 @@ static bool ioReginBorder(SgStatement* stat, SgStatement* last_io_bound) STOP_STAT, STOP_NODE, EXIT_STMT, - EXIT_NODE + EXIT_NODE, + GOTO_NODE }; if (border_stats.find(var) != border_stats.end()) return true; + if (stat->hasLabel()) + return true; + if (last_io_bound && last_io_bound->lastNodeOfStmt() && last_io_bound->lastNodeOfStmt() == stat) return true; @@ -837,7 +845,17 @@ void replaceDistributedArraysInIO(vector& regions, } } - populateDistributedIoArrays(need_replace, curr_stmt, current_file_name, current_func_info); + auto need_fix_io = populateDistributedIoArrays(need_replace, curr_stmt, current_file_name, current_func_info); + + // incorrect IO statement with label + // move label to dummy statement and insert copy statements between dummy statement and IO + if (need_fix_io && curr_stmt->hasLabel()) + { + moveLabelBefore(curr_stmt); + if (last_io_bound == curr_stmt) // always true + last_io_bound = curr_stmt->lexPrev(); + } + curr_stmt = curr_stmt->lexNext(); } }