// Source file: SAPFOR/Sapfor/_src/Transformations/loops_combiner.cpp
// (repository web-view export: 1837 lines, 60 KiB, C++; history entry 2023-09-14 19:43:13 +03:00)
#include "loops_combiner.h"
#include "../LoopAnalyzer/loop_analyzer.h"
#include "../ExpressionTransform/expr_transform.h"
#include "../Utils/errors.h"
#include "../Utils/SgUtils.h"
#include <string>
#include <vector>
#include <queue>
using std::string;
using std::vector;
using std::map;
using std::set;
using std::pair;
using std::make_pair;
using std::queue;
using std::wstring;
/**
 * Greatest common divisor of two integers (Euclidean algorithm).
 *
 * The sign of the arguments is ignored; gcd(0, n) == |n| and gcd(0, 0) == 0.
 * For strictly positive arguments the result matches the previous
 * subtraction-based implementation, which never terminated when an
 * argument was zero or negative.
 *
 * NOTE: negating INT_MIN overflows; callers are expected to pass ordinary
 * loop-step magnitudes.
 */
static int gcd(int a, int b)
{
    // Work with magnitudes so that the remainder loop is well defined.
    if (a < 0)
        a = -a;
    if (b < 0)
        b = -b;

    while (b != 0)
    {
        const int remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
/**
 * Returns the DO-variable symbol of the FOR statement behind `loop`.
 * Yields NULL when no loop is given or the loop is not flagged as a FOR loop.
 */
static SgSymbol* getLoopSymbol(const LoopGraph* loop)
{
    const bool isForLoop = (loop != NULL) && loop->isFor;
    if (!isForLoop)
        return NULL;

    SgForStmt* forStmt = (SgForStmt*)loop->loop->GetOriginal();
    return forStmt->doName();
}
static void fillIterationVariables(const LoopGraph* loop, set<SgSymbol*>& vars, int dimensions = -1)
{
if (dimensions == -1)
{
auto s = getLoopSymbol(loop);
if (s)
vars.insert(s);
for (LoopGraph* child : loop->children)
fillIterationVariables(child, vars);
}
else
{
for (int i = 0; i < dimensions; ++i)
{
auto s = getLoopSymbol(loop);
if (s)
vars.insert(s);
if (i != dimensions - 1)
loop = loop->children[0];
}
}
}
/**
 * Removes from `symbols` the first element that isEqSymbols() reports as
 * equal to `symbol`. At most one element is removed.
 */
static void eraseSymbolFromSet(set<SgSymbol*>& symbols, SgSymbol* symbol)
{
    for (auto it = symbols.begin(); it != symbols.end(); ++it)
    {
        if (!isEqSymbols(*it, symbol))
            continue;

        // a matching NULL entry is left in place, as before
        if (*it)
            symbols.erase(it);
        return;
    }
}
static bool isSymbolInSet(const set<SgSymbol*>& symbols, SgSymbol* symbol)
{
for (SgSymbol* elem : symbols)
if (isEqSymbols(elem, symbol))
return true;
return false;
}
static void getIntersection(const set<SgSymbol*>& firstSet, const set<SgSymbol*>& secondSet, set<SgSymbol*>& intersection)
{
for (SgSymbol* var1 : firstSet)
{
for (SgSymbol* var2 : secondSet)
{
if (isEqSymbols(var1, var2))
{
intersection.insert(var1);
break;
}
}
}
}
/**
 * Tells whether the function containing `stmt` has a GOTO whose target
 * label is the label of `stmt`. Unlabelled statements yield false.
 */
static bool hasGotoToStatement(SgStatement* stmt)
{
    if (!stmt->hasLabel())
        return false;

    SgStatement* funcStart = getFuncStat(stmt);
    SgStatement* current = funcStart->lexNext();
    while (current != funcStart->lastNodeOfStmt())
    {
        if (current->variant() == GOTO_NODE)
        {
            SgLabel* target = ((SgGotoStmt*)current)->branchLabel();
            if (target->id() == stmt->label()->id())
                return true;
        }
        current = current->lexNext();
    }
    return false;
}
// Checks exp1 and exp2 for equality.
// Each non-NULL expression is copied, passed through CalculateInteger() and
// unparsed; the resulting texts are compared. A NULL expression is treated
// as an empty text, so two NULL expressions compare equal.
static bool isEqExpressions(SgExpression* exp1, SgExpression* exp2)
{
    string str1, str2;
    if (exp1 != NULL)
    {
        // operate on a copy so the caller's expression is not modified
        exp1 = CalculateInteger(exp1->copyPtr());
        str1 = exp1->unparse();
    }
    if (exp2 != NULL)
    {
        exp2 = CalculateInteger(exp2->copyPtr());
        str2 = exp2->unparse();
    }
    return str1 == str2;
}
// Checks that exp1 and exp2 are opposite in value.
// Recognised cases (after CalculateInteger() on copies of both):
//   * one side is a unary minus whose operand equals the other side;
//   * both sides are integer values and one is the negation of the other.
// Returns false if either expression is NULL or no case applies.
static bool isOppositeExpressions(SgExpression* exp1, SgExpression* exp2)
{
    if (exp1 == NULL || exp2 == NULL)
        return false;

    // operate on copies so the caller's expressions are not modified
    exp1 = CalculateInteger(exp1->copyPtr());
    exp2 = CalculateInteger(exp2->copyPtr());

    if (exp1->variant() == MINUS_OP)
        return isEqExpressions(exp1->lhs(), exp2);
    if (exp2->variant() == MINUS_OP)
        return isEqExpressions(exp1, exp2->lhs());
    if (exp1->variant() == INT_VAL && exp2->variant() == INT_VAL)
        return exp1->valueInteger() == -1 * exp2->valueInteger();
    return false;
}
/**
 * Tells whether the iteration order of `loop` may be reversed.
 *
 * Returns false when there is no dependency graph for the loop in
 * `depInfoForLoopGraph`. Otherwise every dependency node must be either a
 * PRIVATEDEP, or an ARRAYDEP of kind ddoutput / ddreduce; any other node
 * makes the loop non-reversible.
 */
static bool ifLoopCanBeReversed(LoopGraph* loop, const map<LoopGraph*, depGraph*>& depInfoForLoopGraph)
{
    auto dependency = depInfoForLoopGraph.find(loop);
    if (dependency == depInfoForLoopGraph.end())
        return false;

    const vector<depNode*>& nodes = dependency->second->getNodes();
    for (depNode* node : nodes)
    {
        const int type = node->typedep;
        const ddnature nature = (ddnature)node->kinddep;

        if (type == ARRAYDEP && (nature == ddoutput || nature == ddreduce))
            continue;
        if (type == PRIVATEDEP)
            continue;

        // any other dependency forbids reversing
        return false;
    }
    return true;
}
// Reverses the iteration direction of the first `dimensions` levels of the
// nest rooted at `loop` (following children[0] at each level).
// For every level:
//   * if the iteration count is known (calculatedCountOfIters != 0) the cached
//     startVal/endVal are swapped and stepVal is negated;
//   * in the FOR statement the start and end expressions are swapped and the
//     step is negated (a missing step is treated as 1);
//   * loop->startEndStepVals is rebuilt from the updated statement.
// Does nothing when `loop` is NULL.
static void reverseLoop(LoopGraph* loop, int dimensions)
{
if (loop == NULL)
return;
for (int i = 0; i < dimensions; ++i)
{
// update the cached numeric bounds only when they were calculated
if (loop->calculatedCountOfIters != 0) {
std::swap(loop->startVal, loop->endVal);
loop->stepVal *= -1;
}
SgForStmt* loopStmt = isSgForStmt(loop->loop->GetOriginal());
checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);
// swap start and end; the start is copied first because setStart overwrites it
SgExpression& start = loopStmt->start()->copy();
loopStmt->setStart(loopStmt->end()->copy());
loopStmt->setEnd(start);
SgExpression* tmpEx = loopStmt->step();
// no explicit step: use 1
if (tmpEx == NULL)
tmpEx = new SgValueExp(1);
// negate the step: strip an existing unary minus, otherwise add one
if (tmpEx->variant() == MINUS_OP)
{
SgExpression* lhs = tmpEx->lhs();
loopStmt->setStep(*lhs);
}
else
loopStmt->setStep(*(new SgExpression(MINUS_OP, tmpEx, NULL)));
// refresh the expression tuple stored in the loop graph
Expression* startExpr = loopStmt->start() ? new Expression(loopStmt->start()) : NULL;
Expression* endExpr = loopStmt->end() ? new Expression(loopStmt->end()) : NULL;
Expression* stepExpr = loopStmt->step() ? new Expression(loopStmt->step()) : NULL;
loop->startEndStepVals = std::make_tuple(startExpr, endExpr, stepExpr);
// descend to the next level unless this was the last one
if (i != dimensions - 1)
loop = loop->children[0];
}
}
/**
 * Tells whether `expr` is a "simple" expression, i.e. one of:
 *   CONST_REF (with an integer constant value) / INT_VAL / VAR_REF /
 *   MINUS_OP applied to a VAR_REF or an integer /
 *   (VAR_REF (ADD_OP | SUBT_OP) integer) / (integer (ADD_OP | SUBT_OP) VAR_REF)
 */
static bool isSimpleExpression(SgExpression* expr)
{
    SgExpression* left = expr->lhs();
    SgExpression* right = expr->rhs();
    const int kind = expr->variant();

    if (kind == VAR_REF || kind == INT_VAL)
        return true;

    if (kind == CONST_REF)
    {
        SgConstantSymb* constant = isSgConstantSymb(expr->symbol());
        if (constant && constant->constantValue()->isInteger())
            return true;
        return false;
    }

    if (kind == MINUS_OP)
    {
        if (left->variant() == VAR_REF || left->isInteger())
            return true;
        return false;
    }

    if (kind == ADD_OP || kind == SUBT_OP)
    {
        if (!left || !right)
            printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

        const bool leftIsVar = left->variant() == VAR_REF;
        if (!leftIsVar && !left->isInteger())
            return false;

        const bool rightIsVar = right->variant() == VAR_REF;
        if (!rightIsVar && !right->isInteger())
            return false;

        // exactly one operand may be a variable
        return leftIsVar != rightIsVar;
    }

    return false;
}
/**
 * Tells whether both expressions reference exactly the same set of
 * variable names (VAR_REF nodes collected by getVariables()).
 */
static bool hasEqualVars(SgExpression* firstExpr, SgExpression* secondExpr)
{
    const set<int> wantedVariants { VAR_REF };

    set<string> namesInFirst;
    getVariables(firstExpr, namesInFirst, wantedVariants);

    set<string> namesInSecond;
    getVariables(secondExpr, namesInSecond, wantedVariants);

    // same size and every name of the first found in the second
    return namesInFirst == namesInSecond;
}
// Decomposes a simple expression of the form (+/-)var + varAdd.
//   *varMinus -- true when the variable is preceded by a minus
//   *varAdd   -- the integer added to the variable
// Any other expression shape is reported as an internal error.
static void getSimpleExprVarParams(SgExpression* expr, bool* varMinus, int* varAdd)
{
    switch (expr->variant())
    {
    case VAR_REF:
        *varMinus = false;
        *varAdd = 0;
        return;

    case MINUS_OP:
        if (expr->lhs()->variant() == VAR_REF)
        {
            *varMinus = true;
            *varAdd = 0;
            return;
        }
        break;

    case ADD_OP:
        *varMinus = false;
        // the integer operand may stand on either side
        *varAdd = expr->lhs()->isInteger() ? expr->lhs()->valueInteger()
                                           : expr->rhs()->valueInteger();
        return;

    case SUBT_OP:
        if (expr->lhs()->isInteger())
        {
            // integer - var
            *varMinus = true;
            *varAdd = expr->lhs()->valueInteger();
        }
        else
        {
            // var - integer
            *varMinus = false;
            *varAdd = (-1) * expr->rhs()->valueInteger();
        }
        return;

    default:
        break;
    }

    printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
// Compares two simple expressions.
// Result: 1 -- equal, 0 -- first < second, 2 -- first > second,
//        -1 -- the expressions cannot be compared.
static int compareSimpleExpressions(SgExpression* firstExpr, SgExpression* secondExpr)
{
    // encodes an integer comparison with the 1 / 0 / 2 convention above
    auto encode = [](int left, int right) -> int
    {
        if (left == right)
            return 1;
        return left < right ? 0 : 2;
    };

    if (!hasEqualVars(firstExpr, secondExpr))
        return -1;

    // named constants are replaced by their values
    if (SgConstantSymb* firstConst = isSgConstantSymb(firstExpr->symbol()))
        firstExpr = firstConst->constantValue();
    if (SgConstantSymb* secondConst = isSgConstantSymb(secondExpr->symbol()))
        secondExpr = secondConst->constantValue();

    if (firstExpr->isInteger() && secondExpr->isInteger())
    {
        const int firstVal = firstExpr->valueInteger();
        const int secondVal = secondExpr->valueInteger();
        return encode(firstVal, secondVal);
    }

    bool firstNegated = false, secondNegated = false;
    int firstOffset = 0, secondOffset = 0;
    getSimpleExprVarParams(firstExpr, &firstNegated, &firstOffset);
    getSimpleExprVarParams(secondExpr, &secondNegated, &secondOffset);

    // variables with different signs cannot be ordered
    if (firstNegated != secondNegated)
        return -1;

    return encode(firstOffset, secondOffset);
}
// Intended to decide whether two loops with different bounds can be combined.
// Currently DISABLED: the function returns false unconditionally and all code
// after the first return is unreachable. The retained code checks, in order:
// combine limits, GOTOs into either loop, simplicity of the bounds, a
// non-empty intersection of known ranges / equal step signs, and
// comparability of symbolic bounds.
static bool canBeCombinedWithDiffBounds(const LoopGraph* firstLoop, const LoopGraph* loop)
{
// TODO: remove after array dependency analysis is added:
return false;
//
if (firstLoop->hasLimitsToCombine() || loop->hasLimitsToCombine())
return false;
SgForStmt* firstLoopStmt = isSgForStmt(firstLoop->loop->GetOriginal());
checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
SgForStmt* loopStmt = isSgForStmt(loop->loop->GetOriginal());
checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);
if (hasGotoToStatement(firstLoopStmt) || hasGotoToStatement(loopStmt))
return false;
if (!isSimpleExpression(firstLoopStmt->start()) || !isSimpleExpression(firstLoopStmt->end()))
return false;
if (!isSimpleExpression(loopStmt->start()) || !isSimpleExpression(loopStmt->end()))
return false;
// both iteration counts are known: decide on the numeric bounds
if (firstLoop->calculatedCountOfIters != 0 && loop->calculatedCountOfIters != 0)
{
// intersection of ranges is not empty:
if (firstLoop->stepVal > 0)
{
if (firstLoop->startVal > loop->endVal || firstLoop->endVal < loop->startVal)
return false;
}
else
{
if (firstLoop->startVal < loop->endVal || firstLoop->endVal > loop->startVal)
return false;
}
if (firstLoop->stepVal * loop->stepVal < 0) // steps have different sign
return false;
return true;
}
// otherwise fall back to the symbolic bounds; a missing step counts as 1
SgExpression* step1 = firstLoopStmt->step();
SgExpression* step2 = loopStmt->step();
int step1Val = 1, step2Val = 1;
if (step1 != NULL)
{
if (step1->isInteger())
step1Val = step1->valueInteger();
else
return false;
}
if (step2 != NULL)
{
if (step2->isInteger())
step2Val = step2->valueInteger();
else
return false;
}
if (step1Val * step2Val < 0) // steps have different sign
return false;
int compStart = compareSimpleExpressions(firstLoopStmt->start(), loopStmt->start());
int compEnd = compareSimpleExpressions(firstLoopStmt->end(), loopStmt->end());
if (compStart == -1 || compEnd == -1) // impossible to compare
return false;
return true;
}
// Counts how many leading nest levels of `firstLoop` and `loop` have mirrored
// bounds (start of one equals end of the other and the steps are opposite),
// i.e. how many levels would match if one of the loops were reversed.
// Returns 0 immediately when neither loop can be reversed according to
// ifLoopCanBeReversed(). When the result is positive, *toReverse is set to
// `firstLoop` if it is reversible, otherwise to `loop`; it is left untouched
// when the result is 0.
static int getDeepestDimToReverse(LoopGraph* firstLoop, LoopGraph* loop, int perfectLoop,
LoopGraph** toReverse, const map<LoopGraph*, depGraph*>& depInfoForLoopGraph)
{
LoopGraph* curFirstLoop = firstLoop;
LoopGraph* curLoop = loop;
int i = 0;
bool canBeReversed1 = ifLoopCanBeReversed(firstLoop, depInfoForLoopGraph);
bool canBeReversed2 = ifLoopCanBeReversed(loop, depInfoForLoopGraph);
if (!canBeReversed1 && !canBeReversed2)
return 0;
for (i = 0; i < perfectLoop; ++i)
{
SgForStmt* firstLoopStmt = isSgForStmt(curFirstLoop->loop->GetOriginal());
checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
SgForStmt* loopStmt = isSgForStmt(curLoop->loop->GetOriginal());
checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);
if (curLoop->hasLimitsToCombine() || hasGotoToStatement(loopStmt))
break;
// known iteration counts: compare the numeric bounds crosswise
if (curFirstLoop->calculatedCountOfIters != 0 && curLoop->calculatedCountOfIters != 0) {
if (curFirstLoop->startVal != curLoop->endVal)
break;
if (curFirstLoop->endVal != curLoop->startVal)
break;
if (curFirstLoop->stepVal != -1 * curLoop->stepVal)
break;
}
// otherwise compare the bound expressions crosswise
else {
if (!isEqExpressions(std::get<0>(curFirstLoop->startEndStepVals), std::get<1>(curLoop->startEndStepVals)))
break;
if (!isEqExpressions(std::get<1>(curFirstLoop->startEndStepVals), std::get<0>(curLoop->startEndStepVals)))
break;
SgExpression* step1 = std::get<2>(curFirstLoop->startEndStepVals);
SgExpression* step2 = std::get<2>(curLoop->startEndStepVals);
// a missing step is treated as 1
SgValueExp defaultStep(1);
if (step1 == NULL)
step1 = &defaultStep;
if (step2 == NULL)
step2 = &defaultStep;
if (!isOppositeExpressions(step1, step2))
break;
}
if (i != perfectLoop - 1)
{
// stop before descending when the second loop does not have exactly one child;
// note that `i` is not incremented for the current level in that case
if (curLoop->children.size() != 1)
break;
curFirstLoop = curFirstLoop->children[0];
curLoop = curLoop->children[0];
}
}
if (i > 0)
{
if (canBeReversed1)
*toReverse = firstLoop;
else
*toReverse = loop;
}
return i;
}
/**
* Finds the number of dimensions (leading nest levels) along which the two
* loops can be combined: levels are counted while start, end and step of
* both loops are equal, the second loop has no combine limits and no GOTO
* targets it. At most `perfectLoop` levels are examined.
*/
static int getDeepestDimForCombine(const LoopGraph* firstLoop, const LoopGraph* loop, int perfectLoop)
{
const LoopGraph* curFirstLoop = firstLoop;
const LoopGraph* curLoop = loop;
int i = 0;
for (i = 0; i < perfectLoop; ++i)
{
SgForStmt* firstLoopStmt = isSgForStmt(curFirstLoop->loop->GetOriginal());
checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
SgForStmt* loopStmt = isSgForStmt(curLoop->loop->GetOriginal());
checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);
if (curLoop->hasLimitsToCombine() || hasGotoToStatement(loopStmt))
return i;
// known iteration counts: compare the numeric bounds
if (curFirstLoop->calculatedCountOfIters != 0 && curLoop->calculatedCountOfIters != 0) {
if (curFirstLoop->startVal != curLoop->startVal)
return i;
if (curFirstLoop->endVal != curLoop->endVal)
return i;
if (curFirstLoop->stepVal != curLoop->stepVal)
return i;
}
// otherwise compare the bound expressions
else {
// startVal:
if (!isEqExpressions(std::get<0>(curFirstLoop->startEndStepVals), std::get<0>(curLoop->startEndStepVals)))
return i;
// endVal:
if (!isEqExpressions(std::get<1>(curFirstLoop->startEndStepVals), std::get<1>(curLoop->startEndStepVals)))
return i;
SgExpression* step1 = std::get<2>(curFirstLoop->startEndStepVals);
SgExpression* step2 = std::get<2>(curLoop->startEndStepVals);
if (!isEqExpressions(step1, step2))
{
// exactly one step is missing: compare the other one against the default step 1
if ((step1 == NULL) ^ (step2 == NULL))
{
SgValueExp defaultStep(1);
if (step1 == NULL)
step1 = &defaultStep;
else
step2 = &defaultStep;
if (!isEqExpressions(step1, step2))
return i;
}
else
return i;
}
}
// descend to the next level unless this was the last one
if (i != perfectLoop - 1)
{
curFirstLoop = curFirstLoop->children[0];
curLoop = curLoop->children[0];
}
}
return i;
}
static void compareIterationVars(const LoopGraph* firstLoop, const LoopGraph* loop, int dimensions, map<SgSymbol*, SgSymbol*>& symbols)
{
for (int i = 0; i < dimensions; ++i)
{
SgForStmt* firstLoopStmt = isSgForStmt(firstLoop->loop->GetOriginal());
checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
SgForStmt* loopStmt = isSgForStmt(loop->loop->GetOriginal());
checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);
if (!isEqSymbols(firstLoopStmt->doName(), loopStmt->doName()))
symbols.insert(make_pair(loopStmt->doName(), firstLoopStmt->doName()));
if (i != dimensions - 1)
{
firstLoop = firstLoop->children[0];
loop = loop->children[0];
}
}
}
/**
 * Creates a copy of `symbol` under a new name of the form "<base>_<number>".
 *
 * If the current name already ends in "_<digits>", that suffix is stripped
 * and the search starts from <digits> + 1; otherwise it starts from 1.
 * The final number is the one returned by checkSymbNameAndCorrect().
 */
static SgSymbol* copySymbolAndRename(SgSymbol* symbol)
{
    string stem = symbol->identifier();
    int firstCandidate = 1;

    const size_t underscore = stem.rfind('_');
    if (underscore != string::npos)
    {
        const string tail = stem.substr(underscore + 1);

        bool onlyDigits = !tail.empty();
        for (size_t k = 0; k < tail.length() && onlyDigits; ++k)
            onlyDigits = (tail[k] >= '0' && tail[k] <= '9');

        if (onlyDigits)
        {
            firstCandidate = atoi(tail.c_str()) + 1;
            stem.erase(underscore);      // drop "_<digits>"
        }
    }

    const int freeNumber = checkSymbNameAndCorrect(stem + '_', firstCandidate);
    const string freshName = stem + '_' + std::to_string(freeNumber);

    SgSymbol* duplicate = &symbol->copy();
    duplicate->changeName(freshName.c_str());
    return duplicate;
}
/**
 * Recursively walks expression `ex` and, for every variable or array
 * reference whose symbol matches a key of `symbols`, replaces the symbol
 * with the mapped value (first matching key wins).
 */
static void renameVariables(const map<SgSymbol*, SgSymbol*>& symbols, SgExpression* ex)
{
    if (!ex)
        return;

    const bool isReference = (ex->variant() == VAR_REF || isArrayRef(ex));
    if (isReference && ex->symbol())
    {
        auto rule = symbols.begin();
        while (rule != symbols.end() && !isEqSymbols(rule->first, ex->symbol()))
            ++rule;

        if (rule != symbols.end())
            ex->setSymbol(rule->second);
    }

    renameVariables(symbols, ex->lhs());
    renameVariables(symbols, ex->rhs());
}
/**
 * Updates LoopGraph::loopSymbol for `loop` and all loops nested in it:
 * if the stored name equals the identifier of a key in `symbols`, it is
 * replaced by the identifier of the mapped symbol (first match wins).
 */
static void renameIterationVariables(LoopGraph* loop, const map<SgSymbol*, SgSymbol*>& symbols)
{
    if (!loop)
        return;

    for (auto rule = symbols.begin(); rule != symbols.end(); ++rule)
    {
        if (rule->first->identifier() == loop->loopSymbol)
        {
            loop->loopSymbol = (string)rule->second->identifier();
            break;
        }
    }

    for (LoopGraph* nested : loop->children)
        renameIterationVariables(nested, symbols);
}
/**
 * Applies the renaming `symbols` (old -> new) to a whole loop:
 * the loop-graph names, the DO variables of FOR statements and every
 * expression of every statement from the loop header up to (not including)
 * its last node.
 */
static void renameVariablesInLoop(LoopGraph* loop, const map<SgSymbol*, SgSymbol*>& symbols)
{
    renameIterationVariables(loop, symbols);

    SgStatement* st = loop->loop;
    while (st != loop->loop->lastNodeOfStmt())
    {
        if (st->variant() == FOR_NODE)
        {
            SgForStmt* forStmt = (SgForStmt*)st;
            // every rule is tried, no early exit
            for (auto rule = symbols.begin(); rule != symbols.end(); ++rule)
            {
                if (isEqSymbols(rule->first, forStmt->symbol()))
                    forStmt->setDoName(*rule->second);
            }
        }

        renameVariables(symbols, st->expr(0));
        renameVariables(symbols, st->expr(1));
        renameVariables(symbols, st->expr(2));

        st = st->lexNext();
    }
}
/**
 * Rewrites the private-variable set stored for `loop` in `mapPrivates`
 * according to `symbols` (old -> new), then does the same for all nested
 * loops. Loops that are NULL or absent from the map are skipped together
 * with their children.
 */
static void renamePrivatesInMap(LoopGraph* loop, const map<SgSymbol*, SgSymbol*>& symbols, map<LoopGraph*, set<SgSymbol*>>& mapPrivates)
{
    auto entry = mapPrivates.find(loop);
    if (!loop || entry == mapPrivates.end())
        return;

    set<SgSymbol*> renamed;
    for (SgSymbol* priv : entry->second)
    {
        // keep the symbol unless a rule replaces it
        SgSymbol* replacement = priv;
        for (auto rule = symbols.begin(); rule != symbols.end(); ++rule)
        {
            if (isEqSymbols(priv, rule->first))
            {
                replacement = rule->second;
                break;
            }
        }
        renamed.insert(replacement);
    }
    entry->second = renamed;

    for (LoopGraph* nested : loop->children)
        renamePrivatesInMap(nested, symbols, mapPrivates);
}
/**
 * Adds the iteration variables of `loop` (and of all loops nested in it) to
 * the private set stored for `loop` in `mapPrivates`, then repeats this for
 * every child. Loops that are NULL or absent from the map are skipped
 * together with their children.
 */
static void addIterationVarsToMap(LoopGraph* loop, map<LoopGraph*, set<SgSymbol*>>& mapPrivates)
{
    auto entry = mapPrivates.find(loop);
    if (!loop || entry == mapPrivates.end())
        return;

    set<SgSymbol*> iterationVars;
    fillIterationVariables(loop, iterationVars);

    set<SgSymbol*>& privates = entry->second;
    for (auto it = iterationVars.begin(); it != iterationVars.end(); ++it)
    {
        if (!isSymbolInSet(privates, *it))
            privates.insert(*it);
    }

    for (LoopGraph* nested : loop->children)
        addIterationVarsToMap(nested, mapPrivates);
}
static void fillMapPrivateVars(const vector<LoopGraph*>& loopGraphs, map<LoopGraph*, set<SgSymbol*>>& mapPrivates)
{
if (loopGraphs.size() == 0)
return;
for (int i = 0; i < loopGraphs.size(); ++i)
{
LoopGraph* loop = loopGraphs[i];
set<Symbol*> symbols;
for (auto& data : getAttributes<SgStatement*, SgStatement*>(loop->loop, set<int>{ SPF_ANALYSIS_DIR }))
fillPrivatesFromComment(new Statement(data), symbols);
set<SgSymbol*> loopPrivates;
for (Symbol* symbol : symbols)
loopPrivates.insert(OriginalSymbol((SgSymbol*)symbol));
mapPrivates.insert(make_pair(loop, loopPrivates));
if (!loop->children.empty())
fillMapPrivateVars(loop->children, mapPrivates);
}
}
/**
 * Returns the FOR statement located `deep` levels into the nest rooted at
 * `loop` (deep == 1 is the loop itself), following the first child at each
 * level. Returns NULL when `deep` is not positive.
 *
 * NOTE(review): the descent is skipped only at level perfectLoop - 1, not at
 * deep - 1 -- kept exactly as in the original.
 */
static SgForStmt* getInnerLoop(const LoopGraph* loop, int deep)
{
    const int nestDepth = loop->perfectLoop;
    const LoopGraph* cursor = loop;
    SgForStmt* found = NULL;

    int level = 0;
    while (level < deep)
    {
        found = isSgForStmt(cursor->loop->GetOriginal());
        checkNull(found, convertFileName(__FILE__).c_str(), __LINE__);

        if (level != nestDepth - 1)
            cursor = cursor->children[0];
        ++level;
    }
    return found;
}
/**
 * Renames variables in the body of `from` according to `symbols`, then
 * extracts that body and inserts it after the last executable statement
 * of `to`.
 */
static void moveBody(SgStatement* from, SgStatement* to, const map<SgSymbol*, SgSymbol*>& symbols)
{
    SgStatement* st = from->lexNext();
    while (st != from->lastNodeOfStmt())
    {
        renameVariables(symbols, st->expr(0));
        renameVariables(symbols, st->expr(1));
        renameVariables(symbols, st->expr(2));
        st = st->lexNext();
    }

    auto extractedBody = from->extractStmtBody();
    to->lastExecutable()->insertStmtAfter(*extractedBody, *to);
}
/**
 * For a loop `do i = a, b, c` builds the expression
 *     a + ((b - a + c) / c - 1) * c
 * A missing step `c` is taken as 1. The expression is assembled from the
 * loop's own start/end/step nodes.
 */
static SgExpression* createIterationCountExpr(const LoopGraph* loop)
{
    SgForStmt* forStmt = isSgForStmt(loop->loop->GetOriginal());
    checkNull(forStmt, convertFileName(__FILE__).c_str(), __LINE__);

    SgExpression* lower = forStmt->start();
    SgExpression* upper = forStmt->end();
    SgExpression* stride = forStmt->step();
    if (stride == NULL)
        stride = new SgValueExp(1);

    SgExpression* result = &(*lower + ((*upper - *lower + *stride) / *stride - *new SgValueExp(1)) * *stride);
    return result;
}
/**
 * Recursively replaces, inside `expression`, every direct operand whose
 * symbol equals `var` with `changeExpr`. The root node itself is never
 * replaced. The recursion continues into the operands as they are after
 * the replacement.
 */
static void changeVarToExpr(SgExpression* expression, SgSymbol* var, SgExpression* changeExpr)
{
    if (expression == NULL || var == NULL)
        return;

    auto refersToVar = [var](SgExpression* operand) -> bool
    {
        return operand && operand->symbol() && isEqSymbols(operand->symbol(), var);
    };

    // both operands are read before anything is replaced
    SgExpression* left = expression->lhs();
    SgExpression* right = expression->rhs();

    if (refersToVar(left))
        expression->setLhs(changeExpr);
    if (refersToVar(right))
        expression->setRhs(changeExpr);

    changeVarToExpr(expression->lhs(), var, changeExpr);
    changeVarToExpr(expression->rhs(), var, changeExpr);
}
/**
 * Replaces occurrences of `var` with `expr` in the expressions of
 * `statement`, starting from expression index `startExpr` (0..2).
 * A top-level expression that is itself a reference to `var` is replaced
 * as a whole; otherwise the replacement is done inside it.
 */
static void changeVarToExpr(SgStatement* statement, SgSymbol* var, SgExpression* expr, int startExpr = 0)
{
    if (statement == NULL || var == NULL)
        return;

    int slot = startExpr;
    while (slot < 3)
    {
        SgExpression* current = statement->expr(slot);
        const bool isTheVar = current && current->symbol() && isEqSymbols(current->symbol(), var);

        if (isTheVar)
            statement->setExpression(slot, expr);
        else
            changeVarToExpr(current, var, expr);

        ++slot;
    }
}
/**
 * Finds, within the first `dimensions` levels of `firstLoop`, the level whose
 * iteration variable is `var`, builds the expression produced by
 * createIterationCountExpr() for that level and substitutes it for `var` in
 * every statement of `loop`.
 */
static void changeIterationVarToCountExpr(LoopGraph* firstLoop, LoopGraph* loop, int dimensions, SgSymbol* var)
{
    int level = 0;
    while (level < dimensions && !isEqSymbols(getLoopSymbol(firstLoop), var))
    {
        firstLoop = firstLoop->children[0];
        ++level;
    }

    SgExpression* replacement = createIterationCountExpr(firstLoop);

    SgStatement* st = loop->loop;
    while (st != loop->loop->lastNodeOfStmt())
    {
        changeVarToExpr(st, var, replacement);
        st = st->lexNext();
    }
}
/**
 * Tells whether `var` is referenced (as a variable or array reference)
 * anywhere inside expression `ex`. A NULL expression yields false.
 */
static bool isVarInExpression(SgSymbol* var, SgExpression* ex)
{
    if (!ex)
        return false;

    const bool isReference = (ex->variant() == VAR_REF || isArrayRef(ex));
    if (isReference && ex->symbol() && isEqSymbols(ex->symbol(), var))
        return true;

    // both subtrees are always examined
    const bool inLeft = isVarInExpression(var, ex->lhs());
    const bool inRight = isVarInExpression(var, ex->rhs());
    return inLeft || inRight;
}
/**
 * Tells whether `var` is modified inside `loop`: either it is the target of
 * an assignment or it is the DO variable of a FOR statement (the loop header
 * itself included).
 */
static bool varIsChanged(SgSymbol* var, LoopGraph* loop)
{
    SgStatement* st = loop->loop;
    while (st != loop->loop->lastNodeOfStmt())
    {
        const int kind = st->variant();
        if (kind == ASSIGN_STAT && isEqSymbols(st->expr(0)->symbol(), var))
            return true;
        if (kind == FOR_NODE && isEqSymbols(((SgForStmt*)st)->doName(), var))
            return true;

        st = st->lexNext();
    }
    return false;
}
/**
 * Tells whether `var` is read inside `loop`. For an assignment whose target
 * symbol is `var` the left-hand side (expression 0) is not counted as a read.
 */
static bool varIsRead(SgSymbol* var, LoopGraph* loop)
{
    SgStatement* st = loop->loop;
    while (st != loop->loop->lastNodeOfStmt())
    {
        const bool assignsVar = st->variant() == ASSIGN_STAT && isEqSymbols(st->expr(0)->symbol(), var);

        // skip the assignment target when it is the variable itself
        for (int slot = assignsVar ? 1 : 0; slot < 3; ++slot)
        {
            if (st->expr(slot) && isVarInExpression(var, st->expr(slot)))
                return true;
        }
        st = st->lexNext();
    }
    return false;
}
/**
 * Tells whether some assignment statement in the range [begin, end) has
 * `var` as the symbol of its left-hand side.
 */
static bool varIsChangedBetween(SgSymbol* var, SgStatement* begin, SgStatement* end)
{
    SgStatement* st = begin;
    while (st != end)
    {
        if (st->variant() == ASSIGN_STAT && isEqSymbols(st->expr(0)->symbol(), var))
            return true;
        st = st->lexNext();
    }
    return false;
}
/**
 * Tells whether `var` occurs in any expression of any statement in the range
 * [begin, end). Returns false when either boundary is NULL.
 */
static bool varIsUsedBetween(SgSymbol* var, SgStatement* begin, SgStatement* end)
{
    if (!begin || !end)
        return false;

    SgStatement* st = begin;
    while (st != end)
    {
        for (int slot = 0; slot < 3; ++slot)
        {
            SgExpression* current = st->expr(slot);
            if (current && isVarInExpression(var, st->expr(slot)))
                return true;
        }
        st = st->lexNext();
    }
    return false;
}
/**
 * Scans `loop` up to the first assignment to `var` and tells whether `var`
 * was used before being overwritten: either in an earlier statement or on
 * the right-hand side of that very assignment. Returns false when the loop
 * contains no assignment to `var`.
 */
static bool isAntiVarDependency(SgSymbol* var, SgForStmt* loop)
{
    bool usedEarlier = false;

    SgStatement* st = loop;
    while (st != loop->lastNodeOfStmt())
    {
        const bool assignsVar = st->variant() == ASSIGN_STAT && isEqSymbols(st->expr(0)->symbol(), var);
        if (assignsVar)
        {
            // the first assignment decides the answer
            for (int slot = 1; slot < 3; ++slot)
            {
                if (st->expr(slot) && isVarInExpression(var, st->expr(slot)))
                    return true;
            }
            return usedEarlier;
        }

        for (int slot = 0; slot < 3; ++slot)
        {
            if (st->expr(slot) && isVarInExpression(var, st->expr(slot)))
                usedEarlier = true;
        }
        st = st->lexNext();
    }
    return false;
}
/**
 * Introduces `newSymbol` as a stand-in for `var` inside `loop`:
 *   1. locates, within the first `dimensions` levels of `firstLoop`, the
 *      level iterating over `var`;
 *   2. inserts `newSymbol = <createIterationCountExpr(level)>` before
 *      `firstLoop`;
 *   3. renames `var` to `newSymbol` throughout `loop`.
 */
static void replaceIterationVar(LoopGraph* firstLoop, LoopGraph* loop, int dimensions, SgSymbol* var, SgSymbol* newSymbol)
{
    LoopGraph* owner = firstLoop;
    int level = 0;
    while (level < dimensions && !isEqSymbols(getLoopSymbol(owner), var))
    {
        owner = owner->children[0];
        ++level;
    }

    SgExpression* valueExpr = createIterationCountExpr(owner);
    SgStatement* assignment = new SgAssignStmt(*new SgVarRefExp(newSymbol), *valueExpr);
    firstLoop->loop->insertStmtBefore(*assignment, *firstLoop->loop->controlParent());

    map<SgSymbol*, SgSymbol*> renaming;
    renaming.insert(make_pair(var, newSymbol));
    renameVariablesInLoop(loop, renaming);
}
// Tells whether `var` is used in `loop` in a place where it is not covered by
// a nested loop's private list.
//   * Any non-FOR statement of the loop that mentions `var` in one of its
//     expressions yields true immediately.
//   * A FOR statement that is a direct child of `loop` is handled as a whole:
//     if `var` is in the child's set in `mapPrivates` the running result
//     becomes false, otherwise it becomes the result of the recursive call
//     for the child; the scan then continues after the child's last node.
// The running result of the last processed child is returned when no
// non-FOR statement mentions `var`.
// NOTE: mapPrivates[child] creates an empty entry when the child is absent.
static bool varIsReallyNotPrivate(SgSymbol* var, const LoopGraph* loop, map<LoopGraph*, set<SgSymbol*>>& mapPrivates)
{
bool res = false;
for (SgStatement* st = loop->loop; st != loop->loop->lastNodeOfStmt(); st = st->lexNext())
{
if (st->variant() == FOR_NODE)
{
for (LoopGraph* child : loop->children)
{
// match the statement with a direct child by statement id
if (child->loop->id() == st->id())
{
if (isSymbolInSet(mapPrivates[child], var))
res = false;
else
res = varIsReallyNotPrivate(var, child, mapPrivates);
// jump to the end of the nested loop; its body was handled above
st = st->lastNodeOfStmt();
}
}
}
else
{
for (int i = 0; i < 3; ++i)
if (st->expr(i) && isVarInExpression(var, st->expr(i)))
return true;
}
}
return res;
}
/**
 * Inserts `st` before the outermost FOR statement enclosing `loop`
 * (climbing while the control parent is a FOR statement).
 */
static void insertStmtBeforeOuterLoop(SgStatement* st, SgStatement* loop)
{
    SgStatement* outermost = loop;
    for (;;)
    {
        SgStatement* parent = outermost->controlParent();
        if (parent->variant() != FOR_NODE)
            break;
        outermost = parent;
    }
    outermost->insertStmtBefore(*st, *outermost->controlParent());
}
// Fixes uses, inside `loop`, of iteration variables of `firstLoop` that are
// not iteration variables of `loop` itself (checked for the first
// `dimensions` levels of `firstLoop`). For each such variable:
//   * if `loop` modifies it and isAntiVarDependency() holds, a renamed copy
//     is declared, initialised before `firstLoop` and substituted in `loop`;
//   * if `loop` does not modify it, its occurrences in `loop` are replaced by
//     the expression built by createIterationCountExpr().
// `loopVars` is updated accordingly; `firstLoopVars` is only referenced by
// the commented-out section at the end.
static void correctInheritedUsage(LoopGraph* firstLoop, LoopGraph* loop, int dimensions, set<SgSymbol*>& firstLoopVars, set<SgSymbol*>& loopVars)
{
set<SgSymbol*> firstLoopIterationVars;
fillIterationVariables(firstLoop, firstLoopIterationVars, dimensions);
set<SgSymbol*> loopIterationVars;
fillIterationVariables(loop, loopIterationVars, dimensions);
LoopGraph* first = firstLoop;
for (int i = 0; i < dimensions; ++i)
{
SgSymbol* var = getLoopSymbol(first);
// the variable is used in the second loop but is not one of its own iteration variables
if (isSymbolInSet(loopVars, var) && !isSymbolInSet(loopIterationVars, var))
{
if (varIsChanged(var, loop))
{
checkNull(isSgForStmt(loop->loop), convertFileName(__FILE__).c_str(), __LINE__);
if (isAntiVarDependency(var, (SgForStmt*)loop->loop))
{
// read-before-write: give the second loop its own copy of the variable
SgSymbol* newSymbol = copySymbolAndRename(var);
eraseSymbolFromSet(loopVars, var);
loopVars.insert(newSymbol);
makeDeclaration(loop->loop, vector<SgSymbol*> { newSymbol });
replaceIterationVar(first, loop, dimensions, var, newSymbol);
}
}
else
{
// read-only use: substitute the expression directly
eraseSymbolFromSet(loopVars, var);
changeIterationVarToCountExpr(first, loop, dimensions, var);
}
}
if (i != dimensions - 1)
first = first->children[0];
}
// TODO:
// assigning values to the iteration variables that are replaced by other variables as a result of combining
// temporarily removed from the pass
/*first = firstLoop;
for (int i = 0; i < dimensions; ++i)
{
SgSymbol* loopVar = getLoopSymbol(loop);
SgSymbol* firstLoopVar = getLoopSymbol(first);
if (!isEqSymbols(loopVar, firstLoopVar))
{
SgExpression* countExpr = createIterationCountExpr(loop);
if (isSymbolInSet(firstLoopVars, loopVar))
{
SgStatement* st = new SgAssignStmt(*new SgVarRefExp(loopVar), *countExpr);
firstLoop->loop->insertStmtAfter(*st, *firstLoop->loop->controlParent());
}
else
{
SgStatement* st = new SgAssignStmt(*new SgVarRefExp(loopVar), *countExpr);
firstLoop->loop->insertStmtBefore(*st, *firstLoop->loop->controlParent());
}
}
if (i != dimensions - 1)
{
first = first->children[0];
loop = loop->children[0];
}
}*/
}
// TODO: improve the array dependency analysis
//
// Conservatively decides whether combining `firstLoop` and `loop` over their
// first `dimensions` nest levels may introduce a dependency through arrays.
// Only arrays present in usedArraysAll of both loops are examined. For each
// such array and each nest level, the coefficient sets of read and write
// accesses (per array dimension) of both loops are compared; any mismatch
// between a write in one loop and a read/write in the other is reported as
// a dependency (returns true).
static bool hasDependenciesBetweenArrays(LoopGraph* firstLoop, LoopGraph* loop, int dimensions)
{
    set<DIST::Array*> readWriteFrist, readWriteSecond;
    vector<pair<LoopGraph*, set<DIST::Array*>*>> loops = { make_pair(firstLoop, &readWriteFrist), make_pair(loop, &readWriteSecond) };

    for (auto& entry : loops)
    {
        const LoopGraph* currLoop = entry.first;
        for (int d = 0; d < dimensions; ++d)
        {
            // every requested nest level must exist
            checkNull(currLoop, convertFileName(__FILE__).c_str(), __LINE__);
            *(entry.second) = entry.first->usedArraysAll;
            if (currLoop->children.size())
                currLoop = currLoop->children[0];
            else
                currLoop = NULL;
        }
    }

    // are there any common arrays at all that are read or written in the loops
    // being combined and are mapped onto them
    set<DIST::Array*> intersect;
    std::set_intersection(readWriteFrist.begin(), readWriteFrist.end(), readWriteSecond.begin(), readWriteSecond.end(), inserter(intersect, intersect.begin()));
    if (intersect.size() == 0)
        return false;

    for (auto& array : intersect)
    {
        const LoopGraph* currLoop[2] = { firstLoop, loop };
        for (int d = 0; d < dimensions; ++d)
        {
            // per array dimension: the mapping onto the loop at nest level d
            vector<set<pair<int, int>>> coefsRead[2], coefsWrite[2];
            checkNull(currLoop[0], convertFileName(__FILE__).c_str(), __LINE__);
            checkNull(currLoop[1], convertFileName(__FILE__).c_str(), __LINE__);

            for (int k = 0; k < 2; ++k)
            {
                auto it = currLoop[k]->readOpsForLoop.find(array);
                if (it != currLoop[k]->readOpsForLoop.end())
                {
                    if (coefsRead[k].size() == 0)
                        coefsRead[k].resize(it->second.size());
                    for (int z = 0; z < it->second.size(); ++z)
                        if (it->second[z].coefficients.size())
                            for (auto& coef : it->second[z].coefficients)
                                coefsRead[k][z].insert(coef.first);
                }

                auto itW = currLoop[k]->writeOpsForLoop.find(array);
                if (itW != currLoop[k]->writeOpsForLoop.end())
                {
                    if (coefsWrite[k].size() == 0)
                        coefsWrite[k].resize(itW->second.size());
                    for (int z = 0; z < itW->second.size(); ++z)
                        if (itW->second[z].coefficients.size())
                            for (auto& coef : itW->second[z].coefficients)
                                coefsWrite[k][z].insert(coef.first);
                }
            }

            // no writes means no dependency
            bool nulWrite = true;
            for (auto& wr : coefsWrite)
                for (auto& elem : wr)
                    if (elem.size() != 0)
                        nulWrite = false;
            if (nulWrite)
                continue;

            // if a read in one loop and a write in the other (or vice versa) follow
            // different rules, this is treated as a dependency for now.
            // this can be refined.
            const int len = std::max(coefsWrite[0].size(), coefsRead[0].size());
            int countW[2] = { 0, 0 };
            int countR[2] = { 0, 0 };
            for (int L = 0; L < 2; ++L)
                for (int z = 0; z < coefsWrite[L].size(); ++z)
                    countW[L] += (coefsWrite[L][z].size() ? 1 : 0);
            for (int L = 0; L < 2; ++L)
                for (int z = 0; z < coefsRead[L].size(); ++z)
                    countR[L] += (coefsRead[L][z].size() ? 1 : 0);

            for (int p = 0; p < len; ++p)
            {
                if (coefsWrite[1].size() && coefsWrite[0].size())
                    if (coefsWrite[0][p].size() != 0 && coefsWrite[1][p].size() != 0)
                        if (coefsWrite[0][p] != coefsWrite[1][p])
                            return true;

                if (coefsRead[1].size() && coefsWrite[0].size())
                    if (coefsWrite[0][p].size() != 0 && coefsRead[1][p].size() != 0)
                        if (coefsWrite[0][p] != coefsRead[1][p])
                            return true;

                if (coefsWrite[1].size() && coefsRead[0].size())
                    if (coefsWrite[1][p].size() != 0 && coefsRead[0][p].size() != 0)
                        if (coefsWrite[1][p] != coefsRead[0][p])
                            return true;

                // mapping onto different dimensions
                if (coefsWrite[1].size() && coefsWrite[0].size())
                {
                    if ((coefsWrite[0][p].size() != 0 && coefsWrite[1][p].size() == 0 && countW[1]) ||
                        (coefsWrite[0][p].size() == 0 && coefsWrite[1][p].size() != 0 && countW[0]))
                        return true;
                }

                if (coefsRead[1].size() && coefsWrite[0].size())
                {
                    if ((coefsWrite[0][p].size() != 0 && coefsRead[1][p].size() == 0 && countR[1]) ||
                        (coefsWrite[0][p].size() == 0 && coefsRead[1][p].size() != 0 && countW[0]))
                        return true;
                }

                // the guard must test the vectors that are indexed below
                // (coefsWrite[1] and coefsRead[0]); testing coefsRead[1] here
                // allowed an out-of-range access into an empty coefsRead[0]
                if (coefsWrite[1].size() && coefsRead[0].size())
                {
                    if ((coefsWrite[1][p].size() != 0 && coefsRead[0][p].size() == 0 && countR[0]) ||
                        (coefsWrite[1][p].size() == 0 && coefsRead[0][p].size() != 0 && countW[1]))
                        return true;
                }

                // somewhere there are no mapping rules at all, although the mapping itself exists.
                if ( ((coefsWrite[0].size() == 0 && coefsRead[0].size() == 0) && (countW[0] == 0 && countR[0] == 0))
                     ||
                     ((coefsWrite[1].size() == 0 && coefsRead[1].size() == 0) && (countW[1] == 0 && countR[1] == 0)) )
                    return true;
            }

            currLoop[0] = (currLoop[0]->children.size()) ? currLoop[0]->children[0] : NULL;
            currLoop[1] = (currLoop[1]->children.size()) ? currLoop[1]->children[0] : NULL;
        }
    }
    return false;
}
/**
 * Decides whether `loop` may be merged into `firstLoop` with respect to the variables they use
 * and, if so, reconciles the private-variable bookkeeping kept in mapPrivates.
 *
 * Steps, in order:
 *   1. reject the merge (-1) when a non-array, non-private variable is changed in one loop and read in the other;
 *   2. reject the merge (-1) when hasDependenciesBetweenArrays() reports a dependency;
 *   3. call correctInheritedUsage() and drop the iteration variables of `loop` from the private sets
 *      of `loop` and of all its ancestors in the loop graph;
 *   4. rename privates of one loop that are also used (as non-privates) by the other loop and declare the copies;
 *   5. add the privates of `loop` to the first `dimensions` nested levels of `firstLoop`
 *      and erase the entries of the first `dimensions` levels of `loop` from mapPrivates.
 *
 * Both loops (and every ancestor of `loop`) must already have an entry in mapPrivates,
 * otherwise printInternalError is invoked.
 *
 * Returns 0 on success, -1 when the loops must not be combined.
 */
static int solveVarsCollisions(LoopGraph* firstLoop, LoopGraph* loop, int dimensions, map<LoopGraph*, set<SgSymbol*>>& mapPrivates)
{
    set<SgSymbol*> firstLoopAllVars = getAllVariables<SgSymbol*>(firstLoop->loop, firstLoop->loop->lastNodeOfStmt(), set<int> { VAR_REF, ARRAY_REF });
    set<SgSymbol*> loopAllVars = getAllVariables<SgSymbol*>(loop->loop, loop->loop->lastNodeOfStmt(), set<int> { VAR_REF, ARRAY_REF });

    // local working copies of the private sets; the map itself is updated further below
    if (mapPrivates.count(loop) == 0)
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
    set<SgSymbol*> loopPrivates = mapPrivates[loop];

    if (mapPrivates.count(firstLoop) == 0)
        printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
    set<SgSymbol*> firstLoopPrivates = mapPrivates[firstLoop];

    set<SgSymbol*> loopIterationVars;
    fillIterationVariables(loop, loopIterationVars, dimensions);

    // 1. scalar conflicts: a shared variable written in one loop and read in the other
    for (SgSymbol* var : firstLoopAllVars)
    {
        if (var->type()->variant() == T_ARRAY || !varIsReallyNotPrivate(var, firstLoop, mapPrivates))
            continue;
        if (isSymbolInSet(loopPrivates, var) || isSymbolInSet(firstLoopPrivates, var))
            continue;
        if (!isSymbolInSet(loopAllVars, var) || !varIsReallyNotPrivate(var, loop, mapPrivates))
            continue;

        const bool changedInFirst = varIsChanged(var, firstLoop);
        const bool changedInSecond = varIsChanged(var, loop);
        const bool readInFirst = varIsRead(var, firstLoop);
        const bool readInSecond = varIsRead(var, loop);

        if ((changedInFirst && readInSecond) || (changedInSecond && readInFirst))
            return -1;
    }

    // 2. array dependencies
    if (hasDependenciesBetweenArrays(firstLoop, loop, dimensions))
        return -1;

    // 3. inherited usage and iteration variables
    correctInheritedUsage(firstLoop, loop, dimensions, firstLoopAllVars, loopAllVars);

    for (SgSymbol* var : loopPrivates)
        eraseSymbolFromSet(firstLoopPrivates, var);

    for (SgSymbol* var : loopIterationVars)
    {
        eraseSymbolFromSet(firstLoopPrivates, var);
        eraseSymbolFromSet(loopPrivates, var);

        // the iteration variable stops being private for `loop` and every enclosing loop
        for (LoopGraph* ancestor = loop; ancestor != NULL; ancestor = ancestor->parent)
        {
            auto entry = mapPrivates.find(ancestor);
            if (entry == mapPrivates.end())
                printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
            eraseSymbolFromSet(entry->second, var);
        }
    }

    // 4. renaming of privates that clash with variables of the other loop
    vector<SgSymbol*> symbolsToDeclare;

    set<SgSymbol*> varsFromLoopToRename;
    getIntersection(firstLoopAllVars, loopPrivates, varsFromLoopToRename);

    set<SgSymbol*> varsFromFirstLoopToRename;
    getIntersection(loopAllVars, firstLoopPrivates, varsFromFirstLoopToRename);

    // fills `renames` with (symbol -> fresh copy) for every candidate that is not private in `otherLoop`
    auto collectRenames = [&symbolsToDeclare, &mapPrivates](const set<SgSymbol*>& candidates, LoopGraph* otherLoop,
                                                            map<SgSymbol*, SgSymbol*>& renames)
    {
        for (SgSymbol* symbol : candidates)
        {
            if (!varIsReallyNotPrivate(symbol, otherLoop, mapPrivates))
                continue;

            SgSymbol* renamed = copySymbolAndRename(symbol);
            symbolsToDeclare.push_back(renamed);
            renames.insert(make_pair(symbol, renamed));
        }
    };

    map<SgSymbol*, SgSymbol*> renamesInLoop;
    collectRenames(varsFromLoopToRename, firstLoop, renamesInLoop);
    renamePrivatesInMap(loop, renamesInLoop, mapPrivates);
    renameVariablesInLoop(loop, renamesInLoop);

    // evaluated only after the second loop has been renamed, as the check reads mapPrivates
    map<SgSymbol*, SgSymbol*> renamesInFirstLoop;
    collectRenames(varsFromFirstLoopToRename, loop, renamesInFirstLoop);
    renamePrivatesInMap(firstLoop, renamesInFirstLoop, mapPrivates);
    renameVariablesInLoop(firstLoop, renamesInFirstLoop);

    makeDeclaration(symbolsToDeclare, loop->loop->GetOriginal());

    // 5. privates of `loop` become privates of each combined level of `firstLoop`
    LoopGraph* receiver = firstLoop;
    for (int level = 0; level < dimensions; ++level)
    {
        auto entry = mapPrivates.find(receiver);
        if (entry == mapPrivates.end())
            printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

        for (SgSymbol* privateVar : mapPrivates[loop])
            entry->second.insert(privateVar);

        if (level != dimensions - 1)
            receiver = receiver->children[0];
    }

    // the combined levels of `loop` no longer need their own entries
    LoopGraph* obsolete = loop;
    for (int level = 0; level < dimensions; ++level)
    {
        mapPrivates.erase(obsolete);
        if (level != dimensions - 1)
            obsolete = obsolete->children[0];
    }

    return 0;
}
/**
 * Computes the step of the loop obtained by merging two loops with different bounds.
 *
 * The result is gcd(|step1|, |step2|); when the distance between the loop starts is non-zero
 * it is reduced further to gcd(distance, result). A missing step expression counts as 1.
 * The sign of the result follows the sign of the first loop's step.
 *
 * The distance between starts is taken from the integer values when both starts are integer
 * expressions, otherwise from the additive parts reported by getSimpleExprVarParams().
 */
static int getNewStep(SgForStmt* firstLoopStmt, SgForStmt* loopStmt)
{
    SgExpression* firstStepExpr = firstLoopStmt->step();
    SgExpression* secondStepExpr = loopStmt->step();

    const int firstStep = (firstStepExpr != NULL) ? firstStepExpr->valueInteger() : 1;
    const int secondStep = (secondStepExpr != NULL) ? secondStepExpr->valueInteger() : 1;

    int result = gcd(std::abs(firstStep), std::abs(secondStep));

    int startsDistance = 0;
    const bool bothStartsAreConstants = firstLoopStmt->start()->isInteger() && loopStmt->start()->isInteger();
    if (bothStartsAreConstants)
    {
        const int firstStart = firstLoopStmt->start()->valueInteger();
        const int secondStart = loopStmt->start()->valueInteger();
        startsDistance = std::abs(firstStart - secondStart);
    }
    else
    {
        // only the additive parts are compared; the sign flags are fetched but not used here
        bool firstNegated = false, secondNegated = false;
        int firstShift = 0, secondShift = 0;
        getSimpleExprVarParams(firstLoopStmt->start(), &firstNegated, &firstShift);
        getSimpleExprVarParams(loopStmt->start(), &secondNegated, &secondShift);
        startsDistance = std::abs(firstShift - secondShift);
    }

    if (startsDistance != 0)
        result = gcd(startsDistance, result);

    return (firstStep < 0) ? -result : result;
}
/**
 * Builds the bounds of the loop that covers the iteration ranges of both given loops and
 * stores copies of the chosen start/end expressions into globalBounds (first = start, second = end).
 *
 * The function returns nothing. The result of compareSimpleExpressions() is interpreted here as
 * 0 -> "first < second" and 2 -> "first > second"; for any other result the bound of loopStmt is taken.
 * The direction is defined by the step of the first loop (a missing step counts as 1):
 *   step > 0  : the smaller start and the greater end are selected;
 *   otherwise : the greater start and the smaller end are selected.
 */
static void getGlobalBounds(SgForStmt* firstLoopStmt, SgForStmt* loopStmt, pair<SgExpression*, SgExpression*>& globalBounds)
{
    const int startOrder = compareSimpleExpressions(firstLoopStmt->start(), loopStmt->start());
    const int endOrder = compareSimpleExpressions(firstLoopStmt->end(), loopStmt->end());

    const int step = firstLoopStmt->step() ? firstLoopStmt->step()->valueInteger() : 1;
    const bool ascending = step > 0;

    // comparison results for which the bound of the FIRST loop wins
    const int firstWinsStart = ascending ? 0 : 2;
    const int firstWinsEnd = ascending ? 2 : 0;

    SgForStmt* startOwner = (startOrder == firstWinsStart) ? firstLoopStmt : loopStmt;
    SgForStmt* endOwner = (endOrder == firstWinsEnd) ? firstLoopStmt : loopStmt;

    SgExpression* start = &startOwner->start()->copy();
    SgExpression* end = &endOwner->end()->copy();

    globalBounds = make_pair(start, end);
}
/**
 * Creates the IF statement that restricts the body of loopStmt to its own iterations
 * after the loop header has been widened to globalBounds with step newStep.
 *
 * The condition is the conjunction (in this order) of the parts that are needed:
 *   - start check, when the loop start differs from globalBounds.first
 *     (var >= start for a positive own step, var <= start otherwise);
 *   - end check, when the loop end differs from globalBounds.second
 *     (var <= end for a positive own step, var >= end otherwise);
 *   - step check "mod(...) .eq. 0", when the own step differs from newStep.
 * loopSymbol is the iteration variable used in the generated references.
 *
 * Returns NULL when no part is needed, i.e. the body does not have to be guarded.
 */
static SgStatement* makeIfStatementForBounds(SgForStmt* loopStmt, const pair<SgExpression*, SgExpression*>& globalBounds,
                                             SgSymbol* loopSymbol, int newStep)
{
    // own step of the loop: expression copy and its integer value (1 when omitted)
    int stepVal = 1;
    SgExpression* step = NULL;
    if (loopStmt->step() == NULL)
        step = new SgValueExp(1);
    else
    {
        step = &loopStmt->step()->copy();
        stepVal = loopStmt->step()->valueInteger();
    }
    const bool ascending = stepVal > 0;

    // fresh reference to the iteration variable
    auto makeVarRef = [loopSymbol]() -> SgExpression*
    {
        return new SgExpression(VAR_REF, NULL, NULL, &loopSymbol->copy());
    };

    // joins two optional conditions with AND
    auto conjoin = [](SgExpression* left, SgExpression* right) -> SgExpression*
    {
        if (right == NULL)
            return left;
        if (left == NULL)
            return right;
        return new SgExpression(AND_OP, left, right);
    };

    SgExpression* stepCond = NULL;
    if (stepVal != newStep)
    {
        // original note: MOD(var - start, step) .eq. 0
        // the arguments are handed to makeExprList as { step, var - start }
        SgExpression* varRef = makeVarRef();
        SgExpression* offset = new SgExpression(SUBT_OP, varRef, &loopStmt->start()->copy());
        vector<SgExpression*> args = { step, offset };
        SgExpression* argList = makeExprList(args, false);
        SgSymbol* modSymbol = new SgSymbol(FUNCTION_NAME, "mod");
        SgExpression* modCall = new SgExpression(FUNC_CALL, argList, NULL, modSymbol);
        stepCond = new SgExpression(EQ_OP, modCall, new SgValueExp(0));
    }

    SgExpression* startCond = NULL;
    if (!isEqExpressions(loopStmt->start(), globalBounds.first))
    {
        SgExpression* varRef = makeVarRef();
        startCond = new SgExpression(ascending ? GTEQL_OP : LTEQL_OP, varRef, &loopStmt->start()->copy());
    }

    SgExpression* endCond = NULL;
    if (!isEqExpressions(loopStmt->end(), globalBounds.second))
    {
        SgExpression* varRef = makeVarRef();
        endCond = new SgExpression(ascending ? LTEQL_OP : GTEQL_OP, varRef, &loopStmt->end()->copy());
    }

    SgExpression* loopCond = conjoin(conjoin(startCond, endCond), stepCond);
    if (loopCond == NULL)
        return NULL;

    return new SgIfStmt(*loopCond);
}
/**
 * Moves the body of loop `from` into loop `to`, replacing the iteration variable of `from`
 * with the one of `to`.
 *
 * When makeIfStatementForBounds() produces a guard for `from`, the guard is inserted after the
 * last executable statement of `to` and the body is moved into the guard; otherwise the body is
 * moved directly into `to`.
 */
static void moveBodyWithDiffBounds(SgForStmt* from, SgForStmt* to, const pair<SgExpression*, SgExpression*>& globalBounds, int newStep)
{
    map<SgSymbol*, SgSymbol*> iterVarRenaming;
    iterVarRenaming.insert(make_pair(from->doName(), to->doName()));

    SgStatement* guard = makeIfStatementForBounds(from, globalBounds, to->doName(), newStep);
    if (guard == NULL)
    {
        moveBody(from, to, iterVarRenaming);
        return;
    }

    to->lastExecutable()->insertStmtAfter(*guard, *to);
    moveBody(from, guard, iterVarRenaming);
}
/**
 * Transfers the comments of loopFrom and its SPF_ANALYSIS_DIR attributes to loopTo.
 * The data is added to loopTo; nothing is removed from loopFrom here.
 */
static void moveCommentsAndAttributes(SgStatement* loopFrom, SgStatement* loopTo)
{
    if (loopFrom->comments())
    {
        const string commentText(loopFrom->comments());
        loopTo->addComment(commentText.c_str());
    }

    if (!loopFrom->numberOfAttributes())
        return;

    auto analysisDirs = getAttributes<SgStatement*, SgStatement*>(loopFrom, set<int>{ SPF_ANALYSIS_DIR });
    for (auto& dir : analysisDirs)
        loopTo->addAttribute(SPF_ANALYSIS_DIR, dir, sizeof(SgStatement*));
}
/**
 * Merges `loop` into `firstLoop` when their headers differ (bounds and/or step).
 *
 * The header of the first loop is widened to the common bounds from getGlobalBounds() and, if it
 * differs from the current one, to the step from getNewStep(). The original bodies are wrapped
 * into IF guards produced by makeIfStatementForBounds() where needed. Afterwards the comments and
 * analysis attributes of `loop` are transferred and its statement is extracted from the program.
 *
 * Both graph nodes must refer to DO loops; this is verified with checkNull on the cast results.
 */
static void combineWithDifferentBounds(const LoopGraph* firstLoop, const LoopGraph* loop)
{
    SgForStmt* firstLoopStmt = isSgForStmt(firstLoop->loop->GetOriginal());
    // check the result of the cast (not the graph node, which has already been dereferenced)
    checkNull(firstLoopStmt, convertFileName(__FILE__).c_str(), __LINE__);
    SgForStmt* loopStmt = isSgForStmt(loop->loop->GetOriginal());
    checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);

    // a missing step expression counts as 1
    SgExpression* step1 = firstLoopStmt->step();
    int step1Val = 1;
    if (step1 != NULL)
        step1Val = step1->valueInteger();

    pair<SgExpression*, SgExpression*> globalBounds;
    getGlobalBounds(firstLoopStmt, loopStmt, globalBounds);
    int newStep = getNewStep(firstLoopStmt, loopStmt);

    // the guard must be built BEFORE the header is changed: it compares the old bounds with the new ones
    SgStatement* firstLoopIfStmt = makeIfStatementForBounds(firstLoopStmt, globalBounds, firstLoopStmt->doName(), newStep);

    firstLoopStmt->setStart(*globalBounds.first);
    firstLoopStmt->setEnd(*globalBounds.second);
    if (newStep != step1Val)
        firstLoopStmt->setStep(*new SgValueExp(newStep));

    if (firstLoopIfStmt)
    {
        // no renaming is needed: the body stays under its own iteration variable
        map<SgSymbol*, SgSymbol*> symbols;
        moveBody(firstLoopStmt, firstLoopIfStmt, symbols);
        firstLoopStmt->lastExecutable()->insertStmtAfter(*firstLoopIfStmt, *firstLoopStmt);
    }

    moveBodyWithDiffBounds(loopStmt, firstLoopStmt, globalBounds, newStep);
    moveCommentsAndAttributes(loopStmt, firstLoopStmt);
    loopStmt->extractStmt();
}
/**
 * The combining itself: tries to merge the loops from nextLoops into firstLoop one by one, in order.
 *
 * Processing stops at the first loop that cannot be merged (not a DO loop, limits/goto restrictions
 * on the first loop nest, no suitable dimension and bounds, or a variable collision).
 * Every merged loop is added to combinedLoops, its children are re-attached to the corresponding
 * level of firstLoop, a NOTE message is appended to messages and countOfTransform is incremented.
 *
 * Returns true when at least one loop has been merged into firstLoop. This also holds when the
 * processing is interrupted by a restriction after some loops have already been merged.
 */
static bool combine(LoopGraph* firstLoop, const vector<LoopGraph*>& nextLoops, set<LoopGraph*>& combinedLoops,
                    map<LoopGraph*, set<SgSymbol*>>& mapPrivates, vector<Messages>& messages, const map<LoopGraph*, depGraph*>& depInfoForLoopGraph,
                    int& countOfTransform)
{
    bool wasCombine = false;
    for (LoopGraph* loop : nextLoops)
    {
        if (!loop->isFor)
            return wasCombine;

        int perfectLoop = std::min(firstLoop->perfectLoop, loop->perfectLoop);

        // restrictions are checked on the perfect nest of the first loop
        const LoopGraph* curLoop = firstLoop;
        for (int i = 0; i < perfectLoop; ++i)
        {
            SgForStmt* loopStmt = isSgForStmt(curLoop->loop->GetOriginal());
            checkNull(loopStmt, convertFileName(__FILE__).c_str(), __LINE__);

            // report what has been done so far: earlier loops of nextLoops may already be merged
            if (curLoop->hasLimitsToCombine() || hasGotoToStatement(loopStmt))
                return wasCombine;

            if (i != perfectLoop - 1)
                curLoop = curLoop->children[0];
        }

        map<SgSymbol*, SgSymbol*> symbolsFromLoopToRename;
        int dimensionsForCombine = getDeepestDimForCombine(firstLoop, loop, perfectLoop);
        LoopGraph* loopToReverse = NULL;
        if (dimensionsForCombine == 0)
            dimensionsForCombine = getDeepestDimToReverse(firstLoop, loop, perfectLoop, &loopToReverse, depInfoForLoopGraph);

        if (dimensionsForCombine || canBeCombinedWithDiffBounds(firstLoop, loop))
        {
            // NOTE(review): in the different-bounds case dimensionsForCombine is still 0 at this call
            if (solveVarsCollisions(firstLoop, loop, dimensionsForCombine, mapPrivates) == -1)
                break;

            if (dimensionsForCombine)
            {
                // headers match (possibly after reversing): move the innermost body
                reverseLoop(loopToReverse, dimensionsForCombine);
                compareIterationVars(firstLoop, loop, dimensionsForCombine, symbolsFromLoopToRename);
                SgForStmt* innerMainLoop = getInnerLoop(firstLoop, dimensionsForCombine);
                moveBody(getInnerLoop(loop, dimensionsForCombine), innerMainLoop, symbolsFromLoopToRename);
                moveCommentsAndAttributes(loop->loop, firstLoop->loop);
                loop->loop->extractStmt();
            }
            else
            {
                // headers differ: only the outer level is merged
                dimensionsForCombine = 1;
                combineWithDifferentBounds(firstLoop, loop);
            }

            combinedLoops.insert(loop);
            wasCombine = true;

            // move in structure: children below the merged levels go to the first loop
            LoopGraph* deep = loop, *parent = firstLoop;
            for (int p = 0; p < dimensionsForCombine - 1; ++p)
            {
                deep = deep->children[0];
                parent = parent->children[0];
            }

            for (auto& toMove : deep->children)
            {
                parent->children.push_back(toMove);
                toMove->parent = parent;
            }
            deep->children.clear();
            firstLoop->recalculatePerfect();

            wstring strR, strE;
            __spf_printToLongBuf(strE, L"Loops on line %d and on line %d were combined", firstLoop->lineNum, loop->lineNum);
            __spf_printToLongBuf(strR, R100, firstLoop->lineNum, loop->lineNum);
            messages.push_back(Messages(NOTE, firstLoop->lineNum, strR, strE, 2005));
            __spf_print(1, "Loops on lines %d and %d were combined\n", firstLoop->lineNum, loop->lineNum);
            countOfTransform++;
        }
        else
            break;
    }
    return wasCombine;
}
/**
 * Returns the next loopsAmount loops that follow nextAfterThis in `loops`.
 * If loopsAmount < 0, returns all following loops up to the first non-loop statement.
 *
 * A loop is accepted only while its statement is the lexically next one after the end of the
 * previously accepted loop (initially, after the end of nextAfterThis). The result is empty when
 * nextAfterThis is not present in `loops`.
 */
static vector<LoopGraph*> getNextLoops(LoopGraph* nextAfterThis, vector<LoopGraph*>& loops, int loopsAmount)
{
    vector<LoopGraph*> followers;
    SgStatement* previousEnd = nextAfterThis->loop->lastNodeOfStmt();

    // locate the starting loop
    size_t idx = 0;
    while (idx < loops.size() && loops[idx] != nextAfterThis)
        ++idx;
    if (idx == loops.size())
        return followers;

    // collect adjacent loops until the requested amount is reached or the chain is broken
    for (++idx; idx < loops.size() && loopsAmount != 0; ++idx, --loopsAmount)
    {
        SgStatement* candidate = loops[idx]->loop->GetOriginal();
        if (previousEnd->lexNext() != candidate)
            break;

        previousEnd = candidate->lastNodeOfStmt();
        followers.push_back(loops[idx]);
    }
    return followers;
}
/**
 * Performs one combining pass over the loops of a single nesting level.
 *
 * For every DO loop of loopGraphs all directly following loops are offered to combine().
 * Loops that were merged are removed from loopGraphs and deleted. After a successful merge the
 * iteration variables of the outermost enclosing loop are re-added to mapPrivates and its
 * perfect-nest information is recalculated.
 * Only when nothing was merged on this level, the pass descends into the children of each loop.
 *
 * Returns true when at least one merge happened on this level or, when descending, on a deeper one.
 */
static bool tryToCombine(vector<LoopGraph*>& loopGraphs, map<LoopGraph*, set<SgSymbol*>>& mapPrivates,
                         vector<Messages>& messages, const map<LoopGraph*, depGraph*>& depInfoForLoopGraph,
                         int& countOfTransform)
{
    if (loopGraphs.size() == 0)
        return false;

    bool change = false;
    set<LoopGraph*> loopsToDelete;
    vector<LoopGraph*> newloopGraphs;

    // iterate over a snapshot: loopGraphs itself shrinks while loops are being merged
    vector<LoopGraph*> loops = loopGraphs;
    for (size_t z = 0; z < loops.size(); ++z)
    {
        LoopGraph* loop = loops[z];
        newloopGraphs.push_back(loop);
        if (!loop->isFor)
            continue;

        vector<LoopGraph*> nextLoops = getNextLoops(loop, loopGraphs, -1);
        set<LoopGraph*> combinedLoops;

        // outcome of THIS loop only; the accumulated flag must not be reset by later loops
        bool combinedHere = false;
        if (nextLoops.size())
            combinedHere = combine(loop, nextLoops, combinedLoops, mapPrivates, messages, depInfoForLoopGraph, countOfTransform);
        // a non-empty set of merged loops is a change regardless of the returned flag
        combinedHere = combinedHere || !combinedLoops.empty();

        for (LoopGraph* combined : combinedLoops)
        {
            loopsToDelete.insert(combined);
            auto pos = find(loopGraphs.begin(), loopGraphs.end(), combined);
            if (pos != loopGraphs.end())
                loopGraphs.erase(pos);
        }

        if (combinedHere)
        {
            LoopGraph* outerParent = loop;
            while (outerParent->parent)
                outerParent = outerParent->parent;

            addIterationVarsToMap(outerParent, mapPrivates);
            outerParent->recalculatePerfect();
            change = true;
        }

        // merged loops directly follow the current one in the snapshot: skip them
        z += combinedLoops.size();
    }

    loopGraphs = newloopGraphs;
    for (LoopGraph* elem : loopsToDelete)
        delete elem;

    if (change == false)
    {
        for (LoopGraph* ch : loopGraphs)
        {
            bool res = tryToCombine(ch->children, mapPrivates, messages, depInfoForLoopGraph, countOfTransform);
            change |= res;
        }
    }
    return change;
}
/**
 * Entry point of the loop combining transformation for one file.
 *
 * Two modes are supported:
 *   - onPlace.second > 0: only the loop found in the loop-graph map under the key onPlace.second
 *     is processed, and only when onPlace.first equals the name of `file`; a single following
 *     loop is offered for merging;
 *   - otherwise: tryToCombine() is repeated over the whole graph until it reports no change.
 *
 * Produced notes are appended to messages, countOfTransform is incremented per merged loop.
 * Always returns 0.
 */
int combineLoops(SgFile* file, vector<LoopGraph*>& loopGraphs, vector<Messages>& messages,
                 const pair<string, int>& onPlace, const map<LoopGraph*, depGraph*>& depInfoForLoopGraph,
                 int& countOfTransform)
{
    map<int, LoopGraph*> loopsByKey;
    createMapLoopGraph(loopGraphs, loopsByKey);

    map<LoopGraph*, set<SgSymbol*>> mapPrivates;
    fillMapPrivateVars(loopGraphs, mapPrivates);

    const int onLine = onPlace.second;
    if (onLine > 0)
    {
        // single-loop mode
        if (onPlace.first != file->filename())
            return 0;

        auto found = loopsByKey.find(onLine);
        if (found == loopsByKey.end())
            printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

        LoopGraph* target = found->second;
        // neighbours are searched on the same level: among the parent's children or the top-level loops
        vector<LoopGraph*>& sameLevel = target->parent ? target->parent->children : loopGraphs;
        vector<LoopGraph*> nextLoops = getNextLoops(target, sameLevel, 1);

        set<LoopGraph*> combinedLoops;
        if (nextLoops.size())
            combine(target, nextLoops, combinedLoops, mapPrivates, messages, depInfoForLoopGraph, countOfTransform);
        return 0;
    }

    // whole-graph mode: repeat passes while something is still being merged
    while (tryToCombine(loopGraphs, mapPrivates, messages, depInfoForLoopGraph, countOfTransform))
    {
    }

    /*printf(" === \n");
    for (auto& elem : mapPrivates)
    {
        printf("for loop %d\n", elem.first->lineNum);
        for (auto& priv : elem.second)
            printf(" %s\n", priv->identifier());
    }*/
    return 0;
}