diff --git a/src/ArrayConstantPropagation/propagation.cpp b/src/ArrayConstantPropagation/propagation.cpp index 95011b3..bf56a3f 100644 --- a/src/ArrayConstantPropagation/propagation.cpp +++ b/src/ArrayConstantPropagation/propagation.cpp @@ -34,7 +34,7 @@ static bool CheckConstIndexes(SgExpression* exp) return true; } -static SgExpression* CreateVar(int& variableNumber, SgType* type) +static SgExpression* CreateVar(int& variableNumber, SgType* type) { string varName = "__tmp_prop_var"; string name = varName + std::to_string(variableNumber) + "__"; @@ -42,17 +42,134 @@ static SgExpression* CreateVar(int& variableNumber, SgType* type) SgSymbol* varSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), *type, *declPlace->controlParent()); - SgVarDeclStmt* decl = varSymbol->makeVarDeclStmt(); - decl->setVariant(VAR_DECL); - SgStatement* insertPoint = declPlace; - SgStatement* scope = declPlace->controlParent(); - decl->setFileName(insertPoint->fileName()); - decl->setFileId(insertPoint->getFileId()); - decl->setProject(insertPoint->getProject()); - decl->setlineNumber(getNextNegativeLineNumber()); + const string commonBlockName = "__propagation_common__"; + + SgStatement* funcStart = declPlace->controlParent(); + SgStatement* commonStat = NULL; + SgExpression* commonList = NULL; + + SgStatement* funcEnd = funcStart->lastNodeOfStmt(); + SgStatement* current = funcStart->lexNext(); + + while (current != funcEnd && current) + { + if (current->variant() == COMM_STAT) + { + for (SgExpression* exp = current->expr(0); exp; exp = exp->rhs()) + { + if (exp->variant() == COMM_LIST) + { + string existingName = exp->symbol() ? + string(exp->symbol()->identifier()) : + string("spf_unnamed"); + if (existingName == commonBlockName) + { + commonStat = current; + commonList = exp; + break; + } + } + } + if (commonStat) + break; + } + current = current->lexNext(); + } + + vector varRefs; + if (commonList) + { + SgExpression* varList = commonList->lhs(); + if (varList) + { + auto extractSymbol = [](SgExpression* exp) -> SgSymbol* { + if (!exp) + return NULL; + if (exp->symbol()) + return exp->symbol(); + if (exp->lhs() && exp->lhs()->symbol()) + return exp->lhs()->symbol(); + return NULL; + }; + if (varList->variant() == EXPR_LIST) + { + for (SgExpression* exp = varList; exp; exp = exp->rhs()) + { + SgExpression* varExp = exp->lhs(); + SgSymbol* sym = extractSymbol(varExp); + if (sym) + { + varRefs.push_back(new SgVarRefExp(sym)); + } + } + } + else + { + for (SgExpression* varExp = varList; varExp; varExp = varExp->rhs()) + { + SgSymbol* sym = extractSymbol(varExp); + if (sym) + { + varRefs.push_back(new SgVarRefExp(sym)); + } + } + } + } + } + + if (!commonList) + { + current = funcStart->lexNext(); + while (current != funcEnd && current) + { + if (current->variant() == COMM_STAT) + { + commonStat = current; + break; + } + current = current->lexNext(); + } + + SgSymbol* commonSymbol = new SgSymbol(COMMON_NAME, commonBlockName.c_str()); + commonList = new SgExpression(COMM_LIST, NULL, NULL, commonSymbol); + + if (commonStat) + { + SgExpression* lastCommList = commonStat->expr(0); + if (lastCommList) + { + while (lastCommList->rhs()) + lastCommList = lastCommList->rhs(); + lastCommList->setRhs(commonList); + } + else + { + commonStat->setExpression(0, commonList); + } + } + else + { + commonStat = new SgStatement(COMM_STAT); + commonStat->setFileName(declPlace->fileName()); + commonStat->setFileId(declPlace->getFileId()); + commonStat->setProject(declPlace->getProject()); + commonStat->setlineNumber(getNextNegativeLineNumber()); + commonStat->setExpression(0, commonList); + + declPlace->insertStmtBefore(*commonStat, *declPlace->controlParent()); + } + + } + varRefs.push_back(new SgVarRefExp(varSymbol)); + + if (varRefs.size() > 0) + { + std::reverse(varRefs.begin(), varRefs.end()); + SgExpression* varList = makeExprList(varRefs, false); + + commonList->setLhs(varList); + } - insertPoint->insertStmtBefore(*decl, *scope); - return new SgExpression(VAR_REF, NULL, NULL, varSymbol, type->copyPtr()); } @@ -113,7 +230,7 @@ static void TransformLeftPart(SgStatement* st, SgExpression* exp, unordered_map< string expUnparsed = exp->unparse(); if (arrayToVariable.find(expUnparsed) == arrayToVariable.end() && exp->symbol()->type()->baseType()) { - arrayToVariable[expUnparsed] = CreateVar(variableNumber, exp->symbol()->type()); + arrayToVariable[expUnparsed] = CreateVar(variableNumber, exp->symbol()->type()->baseType()); } SgStatement* newStatement = new SgStatement(ASSIGN_STAT, NULL, NULL, arrayToVariable[expUnparsed]->copyPtr(), st->expr(1)->copyPtr(), NULL); @@ -145,7 +262,6 @@ void ArrayConstantPropagation(SgProject& project) for (; st != lastNode; st = st->lexNext()) { - cout << st->unparse() << endl; if (st->variant() == ASSIGN_STAT) { if (st->expr(1)) @@ -180,13 +296,6 @@ void ArrayConstantPropagation(SgProject& project) } } } - /*st = file->functions(i); - for (; st != lastNode; st = st->lexNext()) - { - cout << st->unparse() << endl; - }*/ } - //FILE* unp = fopen(file->filename(), "w"); - //file->unparse(unp); } } \ No newline at end of file diff --git a/src/Utils/PassManager.h b/src/Utils/PassManager.h index c2a53f2..1a4f0e8 100644 --- a/src/Utils/PassManager.h +++ b/src/Utils/PassManager.h @@ -319,7 +319,6 @@ void InitPassesDependencies(map> &passDepsIn, set list({ CALL_GRAPH2, CALL_GRAPH, BUILD_IR, LOOP_GRAPH, LOOP_ANALYZER_DATA_DIST_S2 }) <= Pass(FIND_PRIVATE_ARRAYS_ANALYSIS); list({ FIND_PRIVATE_ARRAYS_ANALYSIS, CONVERT_LOOP_TO_ASSIGN, RESTORE_LOOP_FROM_ASSIGN, REVERT_SUBST_EXPR_RD }) <= Pass(FIND_PRIVATE_ARRAYS); - //Pass( CALL_GRAPH2 ) <= Pass(ARRAY_PROPAGATION); passesIgnoreStateDone.insert({ CREATE_PARALLEL_DIRS, INSERT_PARALLEL_DIRS, INSERT_SHADOW_DIRS, EXTRACT_PARALLEL_DIRS, EXTRACT_SHADOW_DIRS, CREATE_REMOTES, UNPARSE_FILE, REMOVE_AND_CALC_SHADOW,