replace_dist_arrays_in_io: insert only necessary copy statements

This commit is contained in:
2024-10-31 00:03:32 +03:00
committed by ALEXks
parent b1faa5e80a
commit 413daa2aea

View File

@@ -126,16 +126,70 @@ static void populateDistributedIoArrays(map<DIST::Array*, set<SgStatement*>>& ar
__spf_print(DEBUG_TRACE, "[replace]\n"); __spf_print(DEBUG_TRACE, "[replace]\n");
} }
static void replaceArrayRec(SgSymbol* arr, SgSymbol* replace_by, SgExpression* exp) static void replaceArrayRec(SgSymbol* arr, SgSymbol* replace_by, SgExpression* exp, bool& has_read, bool& has_write, bool from_read, bool from_write)
{ {
if (!exp) if (!exp)
return; return;
if (exp->symbol() && strcmp(exp->symbol()->identifier(), arr->identifier()) == 0) if (exp->symbol() && strcmp(exp->symbol()->identifier(), arr->identifier()) == 0)
{
has_read |= from_read;
has_write |= from_write;
exp->setSymbol(replace_by); exp->setSymbol(replace_by);
}
replaceArrayRec(arr, replace_by, exp->lhs()); switch (exp->variant())
replaceArrayRec(arr, replace_by, exp->rhs()); {
case FUNC_CALL:
{
replaceArrayRec(arr, replace_by, exp->rhs(), has_read, has_write, true, false);
replaceArrayRec(arr, replace_by, exp->lhs(), has_read, has_write, true, true);
break;
}
case EXPR_LIST:
{
replaceArrayRec(arr, replace_by, exp->lhs(), has_read, has_write, from_read, from_write);
replaceArrayRec(arr, replace_by, exp->rhs(), has_read, has_write, from_read, from_write);
break;
}
default:
{
replaceArrayRec(arr, replace_by, exp->lhs(), has_read, has_write, true, false);
replaceArrayRec(arr, replace_by, exp->rhs(), has_read, has_write, true, false);
break;
}
}
}
static void replaceArrayRec(SgSymbol* arr, SgSymbol* replace_by, SgStatement* st, bool& has_read, bool& has_write)
{
if (!st)
return;
switch (st->variant())
{
case ASSIGN_STAT:
case READ_STAT:
{
replaceArrayRec(arr, replace_by, st->expr(0), has_read, has_write, false, true);
replaceArrayRec(arr, replace_by, st->expr(1), has_read, has_write, true, false);
break;
}
case PROC_STAT:
case FUNC_STAT:
{
replaceArrayRec(arr, replace_by, st->expr(0), has_read, has_write, true, false);
replaceArrayRec(arr, replace_by, st->expr(1), has_read, has_write, true, true);
break;
}
default:
{
for (int i = 0; i < 3; i++)
replaceArrayRec(arr, replace_by, st->expr(i), has_read, has_write, true, false);
break;
}
}
} }
static void copyArrayBetweenStatements(SgSymbol* replace_symb, SgSymbol* replace_by, SgStatement* start, SgStatement* last) static void copyArrayBetweenStatements(SgSymbol* replace_symb, SgSymbol* replace_by, SgStatement* start, SgStatement* last)
@@ -144,24 +198,33 @@ static void copyArrayBetweenStatements(SgSymbol* replace_symb, SgSymbol* replace
start = start->lexNext(); start = start->lexNext();
auto* stop = last->lexNext(); auto* stop = last->lexNext();
bool has_read = false, has_write = false;
for (auto* st = start; st != stop; st = st->lexNext()) for (auto* st = start; st != stop; st = st->lexNext())
for (int i = 0; i < 3; i++) replaceArrayRec(replace_symb, replace_by, st, has_read, has_write);
replaceArrayRec(replace_symb, replace_by, st->expr(i));
// A_copy = A
SgAssignStmt* assign = new SgAssignStmt(*new SgArrayRefExp(*replace_by), *new SgArrayRefExp(*replace_symb));
assign->setlineNumber(getNextNegativeLineNumber()); // before region
auto* parent = start->controlParent();
if (parent && parent->lastNodeOfStmt() == start)
parent = parent->controlParent();
start->insertStmtAfter(*assign, *parent); if (has_read)
{
// A_copy = A
SgAssignStmt* assign = new SgAssignStmt(*new SgArrayRefExp(*replace_by), *new SgArrayRefExp(*replace_symb));
assign->setlineNumber(getNextNegativeLineNumber()); // before region
auto* parent = start->controlParent();
if (parent && parent->lastNodeOfStmt() == start)
parent = parent->controlParent();
// A = A_reg start->insertStmtAfter(*assign, *parent);
assign = new SgAssignStmt(*new SgArrayRefExp(*replace_symb), *new SgArrayRefExp(*replace_by)); }
//TODO: bug with insertion
//assign->setlineNumber(getNextNegativeLineNumber()); // after region if (has_write)
last->insertStmtBefore(*assign, *(last->controlParent())); {
// A = A_reg
SgAssignStmt* assign = new SgAssignStmt(*new SgArrayRefExp(*replace_symb), *new SgArrayRefExp(*replace_by));
//TODO: bug with insertion
//assign->setlineNumber(getNextNegativeLineNumber()); // after region
last->insertStmtBefore(*assign, *(last->controlParent()));
}
} }
static void replaceArrayInFragment(DIST::Array* arr, const set<SgStatement*> usages, SgSymbol* replace_by, SgStatement* start, SgStatement* last, const string& filename) static void replaceArrayInFragment(DIST::Array* arr, const set<SgStatement*> usages, SgSymbol* replace_by, SgStatement* start, SgStatement* last, const string& filename)