#include "propagation.h"
#include "../Utils/SgUtils.h"

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

using namespace std;

// Array constant propagation over DO-loop bounds.
//
// For DO loops (FOR_NODE) whose bounds reference array elements with literal
// subscripts (e.g. A(1)), those elements are replaced by scalar temporaries
// named "tmp_prop_var<N>__". Assignments to such elements get a preceding
// "tmp = <rhs>" copy. The temporaries are declared and listed in the COMMON
// block /propagation_common__/. Temporaries whose hitCount is exactly one get
// an explicit integer definition before the first executable statement.
// restoreArrays() restores the saved expressions and deletes every inserted
// statement.

// Function header of the program unit currently being processed.
static SgStatement* declPlace = NULL;

// Assignments already mirrored into a temporary by TransformLeftPart.
static set<SgStatement*> changed;
// Temporary name -> symbol created by CreateVar.
static map<string, SgSymbol*> variablesToAdd;
// File name -> statements marking program units that need the temporaries declared.
static map<string, set<SgStatement*>> positionsToAdd;
// Unparsed constant-indexed array reference -> name of its temporary.
static map<string, string> arrayToName;
// Statements inserted by this pass (deleted again by restoreArrays).
static set<SgStatement*> statementsToRemove;
// File name -> (modified statement -> copy taken before its first modification).
static map<string, map<SgStatement*, SgStatement*>> expToChange;

// Returns true when every element of the subscript list `exp` is an integer
// literal (INT_VAL). Null list elements are ignored; a null list yields false.
static bool CheckConstIndexes(SgExpression* exp)
{
    if (!exp)
        return false;

    // Visit every list node. The previous do/while stopped one node early and
    // never looked at the last subscript, so A(1, I) was accepted as constant.
    for (SgExpression* node = exp; node; node = node->rhs())
    {
        SgExpression* index = node->lhs();
        if (index && index->variant() != INT_VAL)
            return false;
    }
    return true;
}

// True when `exp` has a symbol whose type exposes a base (element) type.
static bool HasSymbolBaseType(SgExpression* exp)
{
    return exp && exp->symbol() && exp->symbol()->type() && exp->symbol()->type()->baseType();
}

// True when `exp` is an array reference with literal subscripts that already
// has a temporary registered in `arrayToVariable`.
static bool IsMappedConstArrayRef(SgExpression* exp, const map<string, SgExpression*>& arrayToVariable)
{
    return exp && isArrayRef(exp) && CheckConstIndexes(exp->lhs())
        && arrayToVariable.find(exp->unparse()) != arrayToVariable.end();
}

// Creates the temporary "tmp_prop_var<variableNumber>__" (then increments
// variableNumber), registers it in variablesToAdd / positionsToAdd and returns
// a VAR_REF to it carrying a copy of `type`.
// NOTE(review): the symbol itself is always created with SgTypeInt(), whatever
// `type` is; confirm this is intended for non-integer arrays.
static SgExpression* CreateVar(int& variableNumber, SgType* type)
{
    const string name = "tmp_prop_var" + std::to_string(variableNumber) + "__";
    variableNumber++;

    SgStatement* funcStart = declPlace->controlParent();
    SgSymbol* varSymbol = new SgSymbol(VARIABLE_NAME, name.c_str(), SgTypeInt(), funcStart);

    variablesToAdd[name] = varSymbol;
    positionsToAdd[string(declPlace->fileName())].insert(declPlace);

    return new SgExpression(VAR_REF, NULL, NULL, varSymbol, type->copyPtr());
}

// Returns the last declaration-like statement of the specification part of
// `funcStart` (or `funcStart` itself when there is none). The scan stops at
// the first executable statement; INTERFACE blocks are skipped over.
static SgStatement* FindLastDeclStatement(SgStatement* funcStart)
{
    if (!funcStart)
        return NULL;

    SgStatement* endSt = funcStart->lastNodeOfStmt();
    SgStatement* cur = funcStart->lexNext();
    SgStatement* lastDecl = funcStart;
    const set<int> declVariants = { VAR_DECL, VAR_DECL_90, ALLOCATABLE_STMT, DIM_STAT,
                                    EXTERN_STAT, COMM_STAT, HPF_TEMPLATE_STAT, DVM_VAR_DECL, STRUCT_DECL };

    while (cur && cur != endSt)
    {
        if (cur->variant() == INTERFACE_STMT)
            cur = cur->lastNodeOfStmt();

        if (declVariants.find(cur->variant()) != declVariants.end())
            lastDecl = cur;
        else if (isSgExecutableStatement(cur))
            break;

        cur = cur->lexNext();
    }
    return lastDecl;
}

// Declares every exported symbol of `symbols` in `funcStart` and lists it in
// the COMMON block /propagation_common__/. A COMMON statement with that name
// that already exists in the unit is reused (its variable list is replaced);
// otherwise a new one is inserted after the last declaration. Every inserted
// statement is recorded in statementsToRemove.
static void InsertCommonAndDeclsForFunction(SgStatement* funcStart, const map<string, SgSymbol*>& symbols)
{
    if (symbols.empty() || !funcStart)
        return;

    const string commonBlockName = "propagation_common__";

    // The declarations and the COMMON list use the same filter so they always agree.
    const auto isExported = [&commonBlockName](SgSymbol* sym)
    {
        return sym && sym->variant() == VARIABLE_NAME && string(sym->identifier()) != commonBlockName;
    };

    // Look for an existing /propagation_common__/ in this unit.
    SgStatement* funcEnd = funcStart->lastNodeOfStmt();
    SgStatement* commonStat = NULL;
    SgExpression* commonList = NULL;
    for (SgStatement* cur = funcStart->lexNext(); cur && cur != funcEnd && !commonStat; cur = cur->lexNext())
    {
        if (cur->variant() != COMM_STAT)
            continue;

        for (SgExpression* exp = cur->expr(0); exp; exp = exp->rhs())
        {
            if (exp->variant() != COMM_LIST)
                continue;

            const char* id = exp->symbol() ? exp->symbol()->identifier() : NULL;
            const string existingName = id ? string(id) : string("spf_unnamed");
            if (existingName == commonBlockName)
            {
                commonStat = cur;
                commonList = exp;
                break;
            }
        }
    }

    vector<SgExpression*> varRefs;
    for (const auto& entry : symbols)
    {
        if (!isExported(entry.second))
            continue;
        SgSymbol* symToAdd = new SgSymbol(VARIABLE_NAME, entry.first.c_str(), SgTypeInt(), funcStart);
        varRefs.push_back(new SgVarRefExp(symToAdd));
    }
    // Nothing to declare; avoid emitting an empty COMMON block.
    if (varRefs.empty())
        return;

    SgExpression* varList = makeExprList(varRefs, false);

    SgStatement* insertAfter = FindLastDeclStatement(funcStart);
    for (const auto& entry : symbols)
    {
        SgSymbol* sym = entry.second;
        if (!isExported(sym))
            continue;

        SgStatement* declStmt = sym->makeVarDeclStmt();
        if (!declStmt)
            continue;

        if (SgVarDeclStmt* vds = isSgVarDeclStmt(declStmt))
            vds->setVariant(VAR_DECL_90);

        declStmt->setFileName(funcStart->fileName());
        declStmt->setFileId(funcStart->getFileId());
        declStmt->setProject(funcStart->getProject());
        declStmt->setlineNumber(getNextNegativeLineNumber());

        insertAfter->insertStmtAfter(*declStmt, *funcStart);
        insertAfter = declStmt;
        statementsToRemove.insert(declStmt);
    }

    if (!commonList)
    {
        SgSymbol* commonSymbol = new SgSymbol(COMMON_NAME, commonBlockName.c_str());
        commonList = new SgExpression(COMM_LIST, varList, NULL, commonSymbol);

        commonStat = new SgStatement(COMM_STAT);
        commonStat->setFileName(funcStart->fileName());
        commonStat->setFileId(funcStart->getFileId());
        commonStat->setProject(funcStart->getProject());
        commonStat->setlineNumber(getNextNegativeLineNumber());
        commonStat->setExpression(0, commonList);

        // Recomputed so the COMMON lands after the declarations inserted above.
        SgStatement* lastDecl = FindLastDeclStatement(funcStart);
        lastDecl->insertStmtAfter(*commonStat, *funcStart);
        statementsToRemove.insert(commonStat);
    }
    else
        commonList->setLhs(varList);
}

// Saves a deep copy of `st` (including copies of expr(0..2)) in expToChange the
// first time the statement is modified, so restoreArrays() can put it back.
static void copyStatement(SgStatement* st)
{
    if (!st)
        return;

    auto& fileCopies = expToChange[st->fileName()];
    if (fileCopies.find(st) != fileCopies.end())
        return;

    SgStatement* boundCopy = st->copyPtr();
    for (int i = 0; i < 3; i++)
    {
        SgExpression* expr = st->expr(i);
        if (expr)
            boundCopy->setExpression(i, expr->copyPtr());
        else
            boundCopy->setExpression(i, NULL);
    }
    fileCopies[st] = boundCopy;
}

// Returns a new VAR_REF to the temporary standing for the constant-indexed
// array reference `arrayRef`, creating the temporary on first use. Returns
// NULL when no temporary exists yet and none can be created because the
// reference's symbol has no base type.
static SgExpression* MakeTempRef(SgExpression* arrayRef, map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    const string key = arrayRef->unparse();
    if (arrayToVariable.find(key) == arrayToVariable.end())
    {
        if (!HasSymbolBaseType(arrayRef))
            return NULL;
        arrayToVariable[key] = CreateVar(variableNumber, arrayRef->symbol()->type()->baseType());
        arrayToName[key] = arrayToVariable[key]->unparse();
    }

    positionsToAdd[string(declPlace->fileName())].insert(declPlace);
    SgSymbol* sym = new SgSymbol(VARIABLE_NAME, arrayToName[key].c_str(), SgTypeInt(), declPlace->controlParent());
    return new SgVarRefExp(sym);
}

// Replaces every constant-indexed array reference inside `exp` (a right-hand
// side of `st`) with its temporary. When `exp` itself is such a reference, the
// whole expr(1) of `st` is replaced. Returns true if anything was rewritten.
static bool TransformRightPart(SgStatement* st, SgExpression* exp, map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (!exp)
        return false;

    if (isArrayRef(exp) && CheckConstIndexes(exp->lhs()))
    {
        SgExpression* tempRef = MakeTempRef(exp, arrayToVariable, variableNumber);
        if (!tempRef)
            return false;

        copyStatement(st);
        st->setExpression(1, tempRef);
        return true;
    }

    bool isChanged = false;
    SgExpression* children[2] = { exp->lhs(), exp->rhs() };
    for (int i = 0; i < 2; i++)
    {
        SgExpression* child = children[i];
        if (!child)
            continue;

        if (isArrayRef(child) && CheckConstIndexes(child->lhs()))
        {
            // Subscripts are all literals, so there is nothing to rewrite below
            // this node. Without type information it is left untouched instead
            // of recursing (which used to replace the statement's whole RHS).
            if (!HasSymbolBaseType(child))
                continue;

            SgExpression* tempRef = MakeTempRef(child, arrayToVariable, variableNumber);
            if (!tempRef)
                continue;

            copyStatement(st);
            if (i == 0)
                exp->setLhs(tempRef);
            else
                exp->setRhs(tempRef);
            isChanged = true;
        }
        // Recurse unconditionally. The former `isChanged || ...` skipped the
        // right subtree once the left one had been rewritten.
        else if (TransformRightPart(st, child, arrayToVariable, variableNumber))
            isChanged = true;
    }
    return isChanged;
}

// For an assignment `st` whose left-hand side is the array element `exp`,
// inserts "tmp = <copy of rhs>" right before `st` (once per statement) and
// records the new statement for removal. Skipped when the symbol has no base
// type or the symbol's type is T_STRING.
static void TransformLeftPart(SgStatement* st, SgExpression* exp, map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (!st || !st->expr(1))
        return;
    if (!HasSymbolBaseType(exp))
        return;
    if (exp->symbol()->type()->variant() == T_STRING)
        return;
    if (changed.find(st) != changed.end())
        return;

    SgExpression* tempRef = MakeTempRef(exp, arrayToVariable, variableNumber);
    if (!tempRef)
        return;

    SgStatement* newStatement = new SgStatement(ASSIGN_STAT, NULL, NULL, tempRef, st->expr(1)->copyPtr(), NULL);
    // File name is set like every other inserted statement; restoreArrays() switches files by it.
    newStatement->setFileName(st->fileName());
    newStatement->setFileId(st->getFileId());
    newStatement->setProject(st->getProject());
    st->insertStmtBefore(*newStatement, *st->controlParent());
    newStatement->setlineNumber(getNextNegativeLineNumber());
    newStatement->setLocalLineNumber(st->lineNumber());

    changed.insert(st);
    statementsToRemove.insert(newStatement);
}

// Rewrites the constant-indexed array elements in `exp` (the loop range of
// `st`). It then walks backwards from the statement before `st` until the
// statement preceding the current function header. On the way it rewrites
// every assignment's right-hand side and mirrors assignments whose left-hand
// side already has a temporary.
static void TransformBorder(SgStatement* st, SgExpression* exp, map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (!st || !exp || !declPlace)
        return;

    SgStatement* firstStatement = declPlace->lexPrev();
    positionsToAdd[string(declPlace->fileName())].insert(declPlace);
    TransformRightPart(st, exp, arrayToVariable, variableNumber);

    for (SgStatement* cur = st->lexPrev(); cur && cur != firstStatement; cur = cur->lexPrev())
    {
        if (cur->variant() != ASSIGN_STAT)
            continue;

        if (cur->expr(1))
            TransformRightPart(cur, cur->expr(1), arrayToVariable, variableNumber);

        SgExpression* lhs = cur->expr(0);
        if (IsMappedConstArrayRef(lhs, arrayToVariable))
            TransformLeftPart(cur, lhs, arrayToVariable, variableNumber);
    }
}

// For a loop bound that is the plain variable `exp`, walks backwards from `st`.
// It rewrites the right-hand sides of assignments to that variable's symbol,
// and also handles assignments whose left-hand side already has a temporary:
// their right-hand side is rewritten and the assignment is mirrored.
static void CheckVariable(SgStatement* st, SgExpression* exp, map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    if (!st || !exp || !declPlace)
        return;

    SgStatement* firstStatement = declPlace->lexPrev();
    const string fileName = declPlace->fileName();

    for (SgStatement* cur = st->lexPrev(); cur && cur != firstStatement; cur = cur->lexPrev())
    {
        if (cur->variant() != ASSIGN_STAT || !cur->expr(0))
            continue;

        SgExpression* lhs = cur->expr(0);
        if (lhs->symbol() == exp->symbol()
            && TransformRightPart(cur, cur->expr(1), arrayToVariable, variableNumber))
            positionsToAdd[fileName].insert(declPlace);

        if (arrayToVariable.find(lhs->unparse()) == arrayToVariable.end())
            continue;

        if (cur->expr(1) && TransformRightPart(cur, cur->expr(1), arrayToVariable, variableNumber))
            positionsToAdd[fileName].insert(declPlace);

        if (IsMappedConstArrayRef(lhs, arrayToVariable))
        {
            TransformLeftPart(cur, lhs, arrayToVariable, variableNumber);
            positionsToAdd[fileName].insert(declPlace);
        }
    }
}

// Scans every assignment with an integer-literal right-hand side. It counts
// them per left-hand-side text in `hitCount`. For each file whose border
// variables include that left-hand side, it records the pair
// (temporary-or-lhs name, literal text) in `result` under the statement that
// borderVars maps the name to.
static void findConstValues(
    SgProject& project,
    const map<string, map<string, SgStatement*>>& borderVars,
    const map<string, SgExpression*>& arrayToVariable,
    map<string, int>& hitCount,
    map<string, map<SgStatement*, vector<pair<string, string>>>>& result)
{
    for (int fileIdx = 0; fileIdx < project.numberOfFiles(); fileIdx++)
    {
        SgFile* file = &(project.file(fileIdx));
        if (!file)
            continue;

        SgFile::switchToFile(file->filename());
        const int funcNum = file->numberOfFunctions();
        for (int funcIdx = 0; funcIdx < funcNum; ++funcIdx)
        {
            SgStatement* st = file->functions(funcIdx);
            // Checked before st is dereferenced (previously it came after lastNodeOfStmt()).
            if (!st)
                continue;

            SgStatement* lastNode = st->lastNodeOfStmt();
            for (; st && st != lastNode; st = st->lexNext())
            {
                if (st->variant() != ASSIGN_STAT)
                    continue;

                SgExpression* lhs = st->expr(0);
                SgExpression* rhs = st->expr(1);
                if (!lhs || !rhs || rhs->variant() != INT_VAL)
                    continue;

                const string lhsText = lhs->unparse();
                hitCount[lhsText]++;

                const auto varIt = arrayToVariable.find(lhsText);
                const string varName = (varIt != arrayToVariable.end()) ? string(varIt->second->unparse()) : lhsText;

                for (const auto& [filename, names] : borderVars)
                {
                    const auto nameIt = names.find(lhsText);
                    if (nameIt != names.end())
                        result[filename][nameIt->second].push_back({ varName, rhs->unparse() });
                }
            }
        }
    }
}

// Inserts "name = value" before the first executable statement at or after
// each recorded statement. This happens only for names whose hitCount entry
// exists and is at most 1. Inserted statements are recorded for removal.
static void insertDefinition(map<string, map<SgStatement*, vector<pair<string, string>>>>& definitions, map<string, int>& hitCount)
{
    for (const auto& [filename, variables] : definitions)
    {
        if (SgFile::switchToFile(filename) == -1)
            continue;

        for (const auto& [statement, values] : variables)
        {
            if (!statement)
                continue;

            SgStatement* insertBefore = statement;
            while (insertBefore && !isSgExecutableStatement(insertBefore))
                insertBefore = insertBefore->lexNext();

            // Nowhere to insert; do not build (and leak) the assignments.
            if (!insertBefore || !insertBefore->controlParent())
                continue;

            for (const auto& [varName, value] : values)
            {
                const auto hit = hitCount.find(varName);
                if (hit == hitCount.end() || hit->second > 1)
                    continue;

                SgSymbol* sym = new SgSymbol(VARIABLE_NAME, varName.c_str(), SgTypeInt(), statement);
                SgExpression* lhs = new SgVarRefExp(sym);
                SgExpression* rhs = new SgValueExp(stoi(value));
                SgStatement* asg = new SgStatement(ASSIGN_STAT, NULL, NULL, lhs, rhs, NULL);
                asg->setFileName(statement->fileName());
                asg->setFileId(statement->getFileId());
                asg->setProject(statement->getProject());
                asg->setlineNumber(getNextNegativeLineNumber());

                insertBefore->insertStmtBefore(*asg, *insertBefore->controlParent());
                statementsToRemove.insert(asg);
            }
        }
    }
}

// Mirrors (via TransformLeftPart) every not-yet-mirrored assignment in the
// project whose left-hand side is a constant-indexed array element that
// already has a temporary.
static void applyLeftPartForUnchangedAssignments(SgProject& project, map<string, SgExpression*>& arrayToVariable, int& variableNumber)
{
    for (int fi = 0; fi < project.numberOfFiles(); ++fi)
    {
        SgFile* file = &(project.file(fi));
        if (!file)
            continue;

        const string fileName = file->filename();
        if (SgFile::switchToFile(fileName) == -1)
            continue;

        const int funcNum = file->numberOfFunctions();
        for (int fni = 0; fni < funcNum; ++fni)
        {
            SgStatement* funcStart = file->functions(fni);
            if (!funcStart)
                continue;

            declPlace = funcStart;
            positionsToAdd[string(declPlace->fileName())].insert(declPlace);

            SgStatement* endSt = funcStart->lastNodeOfStmt();
            for (SgStatement* st = funcStart; st && st != endSt; st = st->lexNext())
            {
                if (st->variant() != ASSIGN_STAT || !st->expr(0) || !st->expr(1))
                    continue;
                if (changed.find(st) != changed.end())
                    continue;

                SgExpression* lhs = st->expr(0);
                if (!HasSymbolBaseType(lhs) || !IsMappedConstArrayRef(lhs, arrayToVariable))
                    continue;

                TransformLeftPart(st, lhs, arrayToVariable, variableNumber);
            }
        }
    }
}

// True when the tree `exp` contains an array reference with literal subscripts.
static bool ContainsArrayRefRecursive(SgExpression* exp)
{
    if (!exp)
        return false;
    if (isArrayRef(exp) && CheckConstIndexes(exp->lhs()))
        return true;
    return ContainsArrayRefRecursive(exp->lhs()) || ContainsArrayRefRecursive(exp->rhs());
}

// Records every constant-indexed array reference and every VAR_REF in `exp`
// as a border variable of `filename`, mapped to the current declPlace.
static void getBorderVars(SgExpression* exp, const string& filename, map<string, map<string, SgStatement*>>& borderVars)
{
    if (!exp)
        return;

    if ((isArrayRef(exp) && CheckConstIndexes(exp->lhs())) || exp->variant() == VAR_REF)
        borderVars[filename][string(exp->unparse())] = declPlace;

    getBorderVars(exp->lhs(), filename, borderVars);
    getBorderVars(exp->rhs(), filename, borderVars);
}

// Handles one bound (upper when `isUpperBound`, else lower) of the loop range
// `bound` of `st`. It first records the bound's border variables. If the bound
// contains a constant-indexed array element, the whole range is rewritten with
// TransformBorder. If the bound is a plain variable, assignments to it are
// processed with CheckVariable.
static void processLoopBound(
    SgStatement* st,
    SgExpression* bound,
    bool isUpperBound,
    map<string, SgExpression*>& arrayToVariable,
    map<string, map<string, SgStatement*>>& borderVars,
    int& variableNumber)
{
    if (!bound || !st)
        return;

    SgExpression* exp = isUpperBound ? bound->rhs() : bound->lhs();
    getBorderVars(exp, st->fileName(), borderVars);

    // The condition used to be `(ContainsArrayRefRecursive(exp), borderVars, st->fileName())`.
    // Because of the comma operator it evaluated to the non-null file name, so
    // this branch always ran and the variable branch below was dead.
    if (ContainsArrayRefRecursive(exp))
    {
        copyStatement(st);
        TransformBorder(st, bound, arrayToVariable, variableNumber);
        positionsToAdd[string(declPlace->fileName())].insert(declPlace);
    }
    // The bound itself (not the range node, which is never a VAR_REF) is checked.
    else if (exp && exp->variant() == VAR_REF)
        CheckVariable(st, exp, arrayToVariable, variableNumber);
}

// Entry point: runs the propagation over every function of every file of `project`.
void arrayConstantPropagation(SgProject& project)
{
    map<string, SgExpression*> arrayToVariable;
    map<string, map<string, SgStatement*>> borderVars;
    int variableNumber = 0;

    for (int fileIdx = 0; fileIdx < project.numberOfFiles(); fileIdx++)
    {
        SgFile* file = &(project.file(fileIdx));
        if (!file)
            continue;

        SgFile::switchToFile(file->filename());
        const int funcNum = file->numberOfFunctions();
        for (int funcIdx = 0; funcIdx < funcNum; ++funcIdx)
        {
            SgStatement* st = file->functions(funcIdx);
            if (!st)
                continue;

            declPlace = st;
            SgStatement* lastNode = st->lastNodeOfStmt();
            for (; st && st != lastNode; st = st->lexNext())
            {
                if (st->variant() != FOR_NODE)
                    continue;

                SgExpression* range = st->expr(0);
                if (!range || !range->lhs() || !range->rhs())
                    continue;

                processLoopBound(st, range, true, arrayToVariable, borderVars, variableNumber);
                processLoopBound(st, range, false, arrayToVariable, borderVars, variableNumber);
            }
        }
    }

    applyLeftPartForUnchangedAssignments(project, arrayToVariable, variableNumber);

    // Resolve the recorded positions to their enclosing program-unit headers.
    map<string, set<SgStatement*>> funcStarts;
    for (const auto& [fileName, statements] : positionsToAdd)
    {
        if (SgFile::switchToFile(fileName) == -1)
            continue;

        for (SgStatement* st : statements)
        {
            SgStatement* scope = isSgProgHedrStmt(st) ? st : st->controlParent();
            if (scope)
                funcStarts[fileName].insert(scope);
        }
    }

    for (const auto& [fileName, statements] : funcStarts)
    {
        SgFile::switchToFile(fileName);
        for (SgStatement* st : statements)
            InsertCommonAndDeclsForFunction(st, variablesToAdd);
    }

    map<string, map<SgStatement*, vector<pair<string, string>>>> result;
    map<string, int> hitCount;
    findConstValues(project, borderVars, arrayToVariable, hitCount, result);
    insertDefinition(result, hitCount);
}

// Undoes arrayConstantPropagation. It restores the saved expressions of every
// modified statement, deletes every inserted statement, and then resets the
// pass state.
void restoreArrays()
{
    cout << "ARRAY_PROPAGATION_RESTORE" << endl;

    for (auto& [filename, statements] : expToChange)
    {
        if (SgFile::switchToFile(filename) == -1)
            continue;

        for (auto& [statement, statementCopy] : statements)
        {
            if (!statement || !statementCopy)
                continue;
            for (int i = 0; i < 3; i++)
                statement->setExpression(i, statementCopy->expr(i));
        }
    }

    for (SgStatement* st : statementsToRemove)
    {
        if (!st)
            continue;
        SgFile::switchToFile(st->fileName());
        st->deleteStmt();
    }

    // The removed statements are now dangling. Clear all pass state so that a
    // second restore cannot double-delete and a later propagation run starts fresh.
    statementsToRemove.clear();
    expToChange.clear();
    changed.clear();
    positionsToAdd.clear();
    variablesToAdd.clear();
    arrayToName.clear();
}