#include "../Utils/leak_detector.h"

// NOTE(review): the system header list and the template argument lists in this
// file were reconstructed from how the names are used below -- confirm them
// against the project headers (directive_creator.h and friends).
#include <algorithm>
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "dvm.h"
#include "../ParallelizationRegions/ParRegions_func.h"
#include "../Distribution/GraphCSR.h"
#include "../Distribution/Arrays.h"
#include "../Distribution/Distribution.h"
#include "../Distribution/DvmhDirective_func.h"
#include "../Utils/errors.h"
#include "../LoopAnalyzer/loop_analyzer.h"
#include "directive_parser.h"
#include "directive_creator.h"
#include "../Utils/SgUtils.h"
#include "../Sapfor.h"
#include "../GraphLoop/graph_loops_func.h"
#include "../Transformations/loop_transform.h"
#include "../ExpressionTransform/expr_transform.h"
#include "../GraphCall/graph_calls_func.h"
#include "../Utils/AstWrapper.h"

#define PRINT_DIR_RESULT 0

#define FIRST(x)  get<0>(x)
#define SECOND(x) get<1>(x)
#define THIRD(x)  get<2>(x)

using std::vector;
using std::pair;
using std::tuple;
using std::map;
using std::set;
using std::make_pair;
using std::make_tuple;
using std::get;
using std::string;
using std::wstring;

extern int sharedMemoryParallelization;

/**
 * Merges realign rules that have the same template reference and the same index list.
 *
 * Every rule is a vector of expressions where element [0] is the array reference,
 * [1] the index list and [2] the template reference (the layout produced by
 * createRealignRules). Rules whose unparsed [2] and [1] texts coincide are collapsed
 * into one rule: a copy of the first rule of the group whose [0] is replaced by an
 * expression list holding the [0] of every rule in the group.
 *
 * @param toRealign rules to group; the string part of each pair is not read.
 * @return one rule per distinct (template text, index list text) key, each paired
 *         with an empty string, ordered by that key.
 */
static vector<pair<string, vector<Expression*>>>
groupRealignsDirs(const vector<pair<string, vector<Expression*>>>& toRealign)
{
    map<pair<string, string>, vector<vector<Expression*>>> groupedRules;
    for (auto& rule : toRealign)
    {
        // bind by reference: the only copy needed is the one stored in the group
        const vector<Expression*>& currRule = rule.second;
        const string tRule = string(currRule[2]->unparse());
        const string arrRule = string(currRule[1]->unparse());
        groupedRules[make_pair(tRule, arrRule)].push_back(currRule);
    }

    map<pair<string, string>, vector<Expression*>> mergedGroupedRules;
    for (auto& rule : groupedRules)
    {
        SgExprListExp* mergedList = new SgExprListExp();
        for (size_t z = 0; z < rule.second.size(); ++z)
        {
            // the first element fills the list head, the others are appended
            if (z == 0)
                mergedList->setLhs(rule.second[z][0]->GetOriginal());
            else
                mergedList->append(*rule.second[z][0]->GetOriginal());
        }

        vector<Expression*> medged = rule.second[0];
        medged[0] = new Expression(mergedList);
        mergedGroupedRules[rule.first] = medged;
    }

    vector<pair<string, vector<Expression*>>> retVal;
    for (auto& elem : mergedGroupedRules)
        retVal.push_back(make_pair("", elem.second));

    return retVal;
}

/**
 * Creates realign rules instead of a full template redistribution.
 *
 * Two rule sets are built for every array of usedArrays that is not marked
 * IsNotDistribute(): set 0 aligns the array with the template named templClone,
 * set 1 aligns it with the short name of the array's own template. A rule is
 * { array ref, index list iEX0..iEX<n-1>, template ref with subscripts, NULL, NULL }.
 * The subscript placed in template dimension links[z] is built from
 * rules[z] = (a, b) as a * iEXz + b (with the trivial factors dropped); template
 * dimensions no array dimension links to get the symbol "*".
 * Both sets are then grouped with groupRealignsDirs.
 *
 * @param st                    statement used to resolve the array symbol (GetNameInLocationS).
 * @param regId                 region id used for the align rule / link / template queries.
 * @param file                  file in which helper symbols are found or created.
 * @param templClone            name of the template used by rule set 0.
 * @param arrayLinksByFuncCalls passed to getRealArrayRef to resolve the real array.
 * @param usedArrays            arrays to create rules for.
 * @param linesBeforeAfter      .first is passed to the directives of set 0,
 *                              .second to the directives of set 1.
 * @return (directives of set 0, directives of set 1).
 */
pair<vector<Directive*>, vector<Directive*>>
createRealignRules(Statement* st, const uint64_t regId, File *file, const string &templClone,
                   const map<DIST::Array*, set<DIST::Array*>> &arrayLinksByFuncCalls,
                   const set<DIST::Array*>& usedArrays,
                   const pair<int, int> linesBeforeAfter)
{
    vector<vector<pair<string, vector<Expression*>>>> optimizedRules(2);

    for (int num = 0; num < 2; ++num)
    {
        for (auto &elemPair : sortArraysByName(usedArrays))
        {
            DIST::Array* elem = elemPair.second;
            if (elem->IsNotDistribute())
                continue;

            auto realRef = getRealArrayRef(elem, regId, arrayLinksByFuncCalls);
            auto rules = realRef->GetAlignRulesWithTemplate(regId);
            auto links = realRef->GetLinksWithTemplate(regId);
            const auto &templ = realRef->GetTemplateArray(regId);
            if (templ == NULL)
                printInternalError(convertFileName(__FILE__).c_str(), __LINE__);

            vector<Expression*> realign = { NULL, NULL, NULL, NULL, NULL };

            // [0]: the array being realigned
            SgVarRefExp *ref = new SgVarRefExp((SgSymbol*)elem->GetNameInLocationS(st));
            realign[0] = new Expression(ref);

            // [1]: one dummy index iEX<z> per array dimension
            SgExprListExp *list = new SgExprListExp();
            string base = "iEX";
            for (int z = 0; z < elem->GetDimSize(); ++z)
            {
                if (z == 0)
                    list->setLhs(*new SgVarRefExp(findSymbolOrCreate(file, base + std::to_string(z))));
                else
                    list->append(*new SgVarRefExp(findSymbolOrCreate(file, base + std::to_string(z))));
            }
            realign[1] = new Expression(list);

            // [2]: target template -- the clone for set 0, the array's own template for set 1
            if (num == 0)
                realign[2] = new Expression(new SgArrayRefExp(*findSymbolOrCreate(file, templClone, new SgArrayType(*SgTypeInt()))));
            else
                realign[2] = new Expression(new SgArrayRefExp(*findSymbolOrCreate(file, templ->GetShortName(), new SgArrayType(*SgTypeInt()))));

            vector<SgExpression*> templateRuleEx(templ->GetDimSize());
            std::fill(templateRuleEx.begin(), templateRuleEx.end(), (SgExpression*)NULL);

            for (int z = 0; z < elem->GetDimSize(); ++z)
            {
                if (links[z] != -1)
                {
                    SgExpression *toSet = NULL;
                    auto symb = new SgVarRefExp(*findSymbolOrCreate(file, base + std::to_string(z)));

                    // a * i + b, omitting a == 1 and b == 0
                    if (rules[z] == make_pair(1, 0))
                        toSet = symb;
                    else if (rules[z].second == 0)
                        toSet = &(*new SgValueExp(rules[z].first) * *symb);
                    else if (rules[z].first == 1)
                        toSet = &(*symb + *new SgValueExp(rules[z].second));
                    else
                        toSet = &(*new SgValueExp(rules[z].first) * *symb + *new SgValueExp(rules[z].second));

                    templateRuleEx[links[z]] = toSet;
                }
            }

            for (size_t z = 0; z < templateRuleEx.size(); ++z)
            {
                SgExpression *toSet = NULL;
                if (templateRuleEx[z] == NULL)
                    toSet = new SgVarRefExp(*findSymbolOrCreate(file, "*"));
                else
                    toSet = templateRuleEx[z];
                ((SgArrayRefExp*)realign[2]->GetOriginal())->addSubscript(*toSet);
            }

            optimizedRules[num].push_back(make_pair("", realign));
        }
    }

    vector<vector<pair<string, vector<Expression*>>>> groupedOptRules(2);
    groupedOptRules[0] = groupRealignsDirs(optimizedRules[0]);
    groupedOptRules[1] = groupRealignsDirs(optimizedRules[1]);

    pair<vector<Directive*>, vector<Directive*>> retVal;
    for (auto& elem : groupedOptRules[0])
        retVal.first.push_back(new CreatedDirective(elem.first, elem.second, linesBeforeAfter.first));

    for (auto& elem : groupedOptRules[1])
        retVal.second.push_back(new CreatedDirective(elem.first, elem.second, linesBeforeAfter.second));

    return retVal;
}

/**
 * Tells whether the expression tree rooted at ex contains a FUNC_CALL node.
 * A NULL expression contains none.
 */
static bool hasFunctionCall(SgExpression* ex)
{
    if (ex == NULL)
        return false;
    if (ex->variant() == FUNC_CALL)
        return true;

    return hasFunctionCall(ex->lhs()) || hasFunctionCall(ex->rhs());
}

/**
 * Splits an index expression into "base expression + integer offset".
 *
 * Accepted shapes: a VAR_REF / ARRAY_REF / MULT_OP node (offset 0), or a chain
 * of ADD_OP / SUBT_OP nodes whose right operands evaluate to integers with
 * CalculateInteger and whose innermost left operand is one of the base kinds.
 * Expressions containing a function call are rejected.
 *
 * @param ex      expression to split; NULL is rejected.
 * @param splited receives (base expression, accumulated offset) on success and
 *                is left untouched on failure.
 * @return true when the expression has one of the accepted shapes.
 */
static bool splitToBase(SgExpression *ex, pair<SgExpression*, int> &splited)
{
    // hasFunctionCall tolerates NULL, the variant() calls below do not
    if (ex == NULL || hasFunctionCall(ex))
        return false;

    const int var = ex->variant();
    if (var == VAR_REF || var == ARRAY_REF || var == MULT_OP)
    {
        splited = make_pair(ex, 0);
        return true;
    }

    if (var != SUBT_OP && var != ADD_OP)
        return false;
    if (ex->rhs() == NULL || ex->lhs() == NULL)
        return false;

    const int sign = (var == ADD_OP) ? 1 : -1;

    int err, val;
    err = CalculateInteger(ex->rhs(), val);
    if (err != 0)
        return false;

    const int nextEx = ex->lhs()->variant();
    if (nextEx == VAR_REF || nextEx == ARRAY_REF || nextEx == MULT_OP)
    {
        splited = make_pair(ex->lhs(), sign * val);
        return true;
    }

    if (nextEx == SUBT_OP || nextEx == ADD_OP)
    {
        pair<SgExpression*, int> splitedNext;
        if (!splitToBase(ex->lhs(), splitedNext))
            return false;

        splited = make_pair(splitedNext.first, sign * val + splitedNext.second);
        return true;
    }

    return false;
}

/**
 * Records one read offset for a dimension: the first offset seen for a key is
 * stored as (value, value), later ones widen the stored (min, max) pair.
 * slot.first is the "dimension has been seen" flag.
 */
static void recordReadOffset(pair<bool, map<string, pair<int, int>>>& slot, const string& key, const int value)
{
    if (!slot.first)
    {
        slot.first = true;
        slot.second[key] = make_pair(value, value);
        return;
    }

    auto it = slot.second.find(key);
    if (it == slot.second.end())
        slot.second[key] = make_pair(value, value);
    else
    {
        it->second.first = std::min(it->second.first, value);
        it->second.second = std::max(it->second.second, value);
    }
}

/**
 * Collects, for every array reference found in the expression tree ex, the
 * offsets used in the flagged dimensions of the arrays listed in dimsNotMatch.
 *
 * Arrays are matched by comparing GetShortName() with the symbol identifier of
 * the reference. For a flagged dimension a subscript that evaluates to an integer
 * is recorded under the empty key; otherwise it is split with splitToBase and
 * recorded under the unparsed base expression. Subscripts that cannot be split
 * are ignored. The whole tree, including subscripts, is visited recursively.
 *
 * NOTE(review): rightValues[array] and dimsNotMatch[array] are indexed by the
 * subscript position without any resize -- presumably the caller sizes both to
 * the array rank; confirm at the call site.
 *
 * @param ex           expression to scan; NULL is a no-op.
 * @param rightValues  per array and dimension: (seen flag, base text -> (min, max) offset).
 * @param dimsNotMatch per array: flags selecting the dimensions to analyze.
 */
static void analyzeRightPart(SgExpression *ex,
                             map<DIST::Array*, vector<pair<bool, map<string, pair<int, int>>>>> &rightValues,
                             const map<DIST::Array*, vector<bool>> &dimsNotMatch)
{
    if (ex == NULL)
        return;

    if (ex->variant() == ARRAY_REF)
    {
        const std::string name = ex->symbol()->identifier();
        for (auto &elem : dimsNotMatch)
        {
            if (elem.first->GetShortName() != name)
                continue;

            int idx = 0;
            for (auto expr = ex->lhs(); expr; expr = expr->rhs(), ++idx)
            {
                if (!elem.second[idx])
                    continue;

                int err, val;
                err = CalculateInteger(expr->lhs(), val);
                if (err == 0)
                    recordReadOffset(rightValues[elem.first][idx], "", val);
                else
                {
                    pair<SgExpression*, int> splited;
                    // rightValues is touched only after a successful split
                    if (splitToBase(expr->lhs(), splited))
                        recordReadOffset(rightValues[elem.first][idx], string(splited.first->unparse()), splited.second);
                }
            }
            // only the first array with a matching name is considered
            break;
        }
    }

    analyzeRightPart(ex->lhs(), rightValues, dimsNotMatch);
    analyzeRightPart(ex->rhs(), rightValues, dimsNotMatch);
}

/**
 * Records the subscript written in each flagged dimension of the array reference
 * `left` and checks that it agrees with what was recorded before.
 *
 * NOTE(review): leftValues[array] and dimsNotMatch[array] are indexed by the
 * subscript position without any resize -- presumably sized by the caller.
 *
 * @param left         written array reference.
 * @param dimsNotMatch per array: flags selecting the dimensions to analyze.
 * @param leftValues   per array and dimension: (seen flag, (base text, offset));
 *                     a constant subscript is stored with an empty base text.
 * @param base         set to the unparsed base expression when a non-constant
 *                     subscript is recorded for the first time.
 * @return false when a subscript cannot be split by splitToBase or when it
 *         differs (base or offset) from the one already recorded for the
 *         dimension; true otherwise.
 */
static bool analyzeLeftPart(SgExpression *left,
                            const map<DIST::Array*, vector<bool>>& dimsNotMatch,
                            map<DIST::Array*, vector<pair<bool, pair<string, int>>>> &leftValues,
                            string &base)
{
    const std::string name = left->symbol()->identifier();
    for (auto& elem : dimsNotMatch)
    {
        if (elem.first->GetShortName() != name)
            continue;

        int idx = 0;
        for (auto ex = left->lhs(); ex; ex = ex->rhs(), ++idx)
        {
            if (!elem.second[idx])
                continue;

            int err, val;
            err = CalculateInteger(ex->lhs(), val);
            if (err == 0)
            {
                auto& slot = leftValues[elem.first][idx];
                if (slot.first)
                {
                    if (slot.second.first != "") // has non zero base expression
                        return false;
                    if (slot.second.second != val) // has conflict writes
                        return false;
                }
                else
                    slot = make_pair(true, make_pair("", val));
            }
            else
            {
                pair<SgExpression*, int> splited;
                if (!splitToBase(ex->lhs(), splited)) // WRITE OP can not be recognized
                    return false;

                auto& slot = leftValues[elem.first][idx];
                if (slot.first)
                {
                    // has conflict writes
                    if (slot.second.first != string(splited.first->unparse()) || slot.second.second != splited.second)
                        return false;
                }
                else
                {
                    base = string(splited.first->unparse());
                    slot = make_pair(true, make_pair(base, splited.second));
                }
            }
        }
        // only the first array with a matching name is considered
        break;
    }
    return true;
}

/**
 * Scans the statements of a loop and gathers the written / read subscripts of
 * the arrays listed in dimsNotMatch.
 *
 * Statements are visited from the loop header up to, but not including,
 * lastNodeOfStmt() of the loop:
 *  - assignment: an array on the left goes to analyzeLeftPart, the right side
 *    to analyzeRightPart;
 *  - call of a procedure for which isIntrinsicFunctionName(...) returns 0:
 *    an array argument is treated as a read only when the callee is found in
 *    mapFuncInfo and the argument is "in" and not "out"; otherwise it is treated
 *    as a write. Non-array arguments are scanned as reads;
 *  - any other statement: all three expression slots are scanned as reads.
 *
 * @return false as soon as analyzeLeftPart reports a conflict, true otherwise.
 */
bool analyzeLoopBody(LoopGraph* loopV,
                     map<DIST::Array*, vector<pair<bool, pair<string, int>>>>& leftValues,
                     map<DIST::Array*, vector<pair<bool, map<string, pair<int, int>>>>>& rightValues,
                     string& base,
                     const map<DIST::Array*, vector<bool>> &dimsNotMatch,
                     const map<string, FuncInfo*>& mapFuncInfo)
{
    SgStatement* loop = loopV->loop->GetOriginal();
    for (auto st = loop; st != loop->lastNodeOfStmt(); st = st->lexNext())
    {
        if (st->variant() == ASSIGN_STAT)
        {
            auto left = st->expr(0);
            if (left->variant() == ARRAY_REF)
            {
                if (!analyzeLeftPart(left, dimsNotMatch, leftValues, base))
                    return false;
            }
            analyzeRightPart(st->expr(1), rightValues, dimsNotMatch);
        }
        else if (st->variant() == PROC_STAT)
        {
            string name = st->symbol()->identifier();
            if (isIntrinsicFunctionName(name.c_str()) == 0)
            {
                //TODO: contains and modules
                auto it = mapFuncInfo.find(name);
                int z = 0;
                for (SgExpression* ex = st->expr(0); ex; ex = ex->rhs(), ++z)
                {
                    if (ex->lhs()->variant() == ARRAY_REF)
                    {
                        bool ok = true;
                        if (it == mapFuncInfo.end())
                            ok = analyzeLeftPart(ex->lhs(), dimsNotMatch, leftValues, base);
                        else
                        {
                            if (it->second->funcParams.isArgIn(z) && !it->second->funcParams.isArgOut(z))
                                analyzeRightPart(ex->lhs(), rightValues, dimsNotMatch);
                            else
                                ok = analyzeLeftPart(ex->lhs(), dimsNotMatch, leftValues, base);
                        }

                        if (ok == false)
                            return false;
                    }
                    else
                        analyzeRightPart(ex->lhs(), rightValues, dimsNotMatch);
                }
            }
        }
        else
        {
            // scan every expression slot of the statement; indexing with a fixed 1
            // here would look at the same slot three times and miss slots 0 and 2
            for (int i = 0; i < 3; ++i)
                analyzeRightPart(st->expr(i), rightValues, dimsNotMatch);
        }
    }

    //is OK ?
    return true;
}

/**
 * Creates the parallel directives of one file for every parallel region and
 * appends them to createdDirectives[file name].
 *
 * For each region the list of (array, distribution variant) pairs is built
 * either from the region's data directive and its current variant
 * (sharedMemoryParallelization == 0) or from the first variant of every rule of
 * every loop of the file (otherwise); selectParallelDirectiveForVariant then
 * fills the directives to insert.
 *
 * @param file                  file being processed; its filename() is the map key.
 * @param createdDirectives     output: directives per file name (appended to).
 * @param messages              passed through to selectParallelDirectiveForVariant.
 * @param loopsInFile           loops of the file.
 * @param allFuncInfo           function information, flattened with createMapOfFunc.
 * @param parallelRegions       regions to process.
 * @param depInfoForLoopGraph   passed through to selectParallelDirectiveForVariant.
 * @param arrayLinksByFuncCalls passed through to selectParallelDirectiveForVariant.
 */
void createParallelDirs(File *file,
                        map<string, vector<Directive*>>& createdDirectives,
                        vector<Messages>& messages,
                        const vector<LoopGraph*>& loopsInFile,
                        const map<string, vector<FuncInfo*>>& allFuncInfo,
                        const vector<ParallelRegion*>& parallelRegions,
                        const map<LoopGraph*, void*>& depInfoForLoopGraph,
                        const map<DIST::Array*, set<DIST::Array*>>& arrayLinksByFuncCalls)
{
    const string file_name = file->filename();

    map<int, LoopGraph*> mapLoopsInFile;
    createMapLoopGraph(loopsInFile, mapLoopsInFile);

    map<string, FuncInfo*> mapFuncInfo;
    createMapOfFunc(allFuncInfo, mapFuncInfo);

    for (ParallelRegion* region : parallelRegions)
    {
        vector<Directive*> toInsert;

        const DataDirective& dataDirectives = region->GetDataDir();
        const vector<int>& currentVariant = region->GetCurrentVariant();

        DIST::GraphCSR<int, double, attrType>& reducedG = region->GetReducedGraphToModify();
        DIST::Arrays<int>& allArrays = region->GetAllArraysToModify();

        auto& tmp = dataDirectives.distrRules;
        vector<pair<DIST::Array*, const DistrVariant*>> currentVar;

        if (sharedMemoryParallelization == 0)
        {
            for (size_t z1 = 0; z1 < currentVariant.size(); ++z1)
                currentVar.push_back(make_pair(tmp[z1].first, &tmp[z1].second[currentVariant[z1]]));
        }
        else
        {
            for (auto& loop : mapLoopsInFile)
            {
                auto& rules = loop.second->getDataDir().distrRules;
                for (auto& rule : rules)
                    currentVar.push_back(make_pair(rule.first, &rule.second[0]));
            }
        }

        selectParallelDirectiveForVariant(file, region, reducedG, allArrays, loopsInFile,
                                          mapLoopsInFile, mapFuncInfo, currentVar, toInsert,
                                          region->GetId(), arrayLinksByFuncCalls,
                                          depInfoForLoopGraph, messages);

        if (toInsert.size() > 0)
        {
            auto it = createdDirectives.find(file_name);
            if (it == createdDirectives.end())
                createdDirectives.insert(it, make_pair(file_name, toInsert));
            else
                it->second.insert(it->second.end(), toInsert.begin(), toInsert.end());
        }
    }
}

#undef PRINT_DIR_RESULT
#undef FIRST
#undef SECOND
#undef THIRD