improved module analysis

This commit is contained in:
ALEXks
2025-02-18 13:45:20 +03:00
committed by Dudarenko
parent a0c8f78868
commit 09401376c7
10 changed files with 113 additions and 153 deletions

View File

@@ -90,7 +90,6 @@ pair<vector<Directive*>, vector<Directive*>>
const pair<int, int> linesBeforeAfter)
{
vector<vector<pair<string, vector<Expression*>>>> optimizedRules(2);
auto byUse = moduleRefsByUseInFunction(st->GetOriginal());
for (int num = 0; num < 2; ++num)
{
@@ -108,7 +107,7 @@ pair<vector<Directive*>, vector<Directive*>>
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
vector<Expression*> realign = { NULL, NULL, NULL, NULL, NULL };
SgVarRefExp *ref = new SgVarRefExp(getFromModule(byUse, findSymbolOrCreate(file, elem->GetShortName())));
SgVarRefExp *ref = new SgVarRefExp((SgSymbol*)elem->GetNameInLocationS(st));
realign[0] = new Expression(ref);
SgExprListExp *list = new SgExprListExp();

View File

@@ -839,7 +839,7 @@ static pair<string, string> getModuleRename(const set<SgStatement*>& allocatable
set<string> arrayNames;
for (auto& alloc : allocatableStmts)
if (alloc->variant() == ALLOCATE_STMT)
arrayNames.insert(getNameByUse(alloc, array->GetShortName(), array->GetLocation().second));
arrayNames.insert(array->GetNameInLocation(alloc));
if (arrayNames.size() > 1 || arrayNames.size() == 0)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
@@ -1397,14 +1397,23 @@ static set<SgStatement*> filterAllocateStats(SgFile* file, const vector<SgStatem
if (!stat->switchToFile())
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
auto byUse = moduleRefsByUseInFunction(stat);
for (auto &elem : byUse)
if (elem.first == array)
for (auto &newElem : elem.second)
arraySyns.insert(newElem->identifier());
SgExpression* list = stat->expr(0);
bool find = false;
while (list)
{
if (list->lhs() && list->lhs()->symbol())
{
if (OriginalSymbol(list->lhs()->symbol())->identifier() == array)
{
find = true;
break;
}
}
for (auto &syns : arraySyns)
if (recSymbolFind(stat->expr(0), syns, ARRAY_REF))
list = list->rhs();
}
if (find)
filtered.insert(stat);
SgFile::switchToFile(fileName);

View File

@@ -780,6 +780,8 @@ void addRemoteLink(const LoopGraph* loop, const map<string, FuncInfo*>& funcMap,
while (withDir && withDir->loop->GetOriginal()->lexPrev()->variant() != DVM_PARALLEL_ON_DIR)
withDir = withDir->parent;
checkNull(withDir, convertFileName(__FILE__).c_str(), __LINE__);
set<string> loopVars;
for (auto& elem : withDir->directive->parallel)
if (elem != "*")
@@ -871,8 +873,6 @@ ArrayRefExp* createRemoteLink(const LoopGraph* currLoop, const DIST::Array* forA
const set<string> allFiles = getAllFilesInProject();
SgStatement* realStat = (SgStatement*)currLoop->getRealStat(file->filename());
const map<string, set<SgSymbol*>> byUseInFunc = moduleRefsByUseInFunction(realStat);
SgStatement* parentFunc = getFuncStat(realStat);
const pair<int, int> lineRange = make_pair(parentFunc->lineNumber(), parentFunc->lastNodeOfStmt()->lineNumber());
SgExpression* ex = new SgExpression(EXPR_LIST);
SgExpression* p = ex;
@@ -885,21 +885,7 @@ ArrayRefExp* createRemoteLink(const LoopGraph* currLoop, const DIST::Array* forA
p = p->rhs();
}
}
SgArrayRefExp* newRem = NULL;
auto decls = forArray->GetDeclInfoWithSymb();
const string fName = current_file->filename();
/*for (auto& decl : decls)
{
if (decl.first.first == fName)
{
newRem = new SgArrayRefExp(*decl.second->GetOriginal(), *ex);
break;
}
}*/
if (!newRem)
newRem = new SgArrayRefExp(*getFromModule(byUseInFunc, forArray->GetDeclSymbol(fName, lineRange, allFiles)->GetOriginal()), *ex);
SgArrayRefExp* newRem = new SgArrayRefExp(*((SgSymbol*)forArray->GetNameInLocationS(realStat)), *ex);
return new ArrayRefExp(newRem);
}

View File

@@ -782,12 +782,7 @@ static void replacingShadowNodes(FuncInfo* currF)
const ShadowElement& currElement = currSh.second[0];
SgSymbol* s = currArray->GetDeclSymbol()->GetOriginal();
if (currArray->IsModuleSymbol())
{
const map<string, set<SgSymbol*>> byUseInFunc = moduleRefsByUseInFunction(currF->funcPointer->GetOriginal());
s = getFromModule(byUseInFunc, s);
}
SgSymbol* s = (SgSymbol*)currArray->GetNameInLocationS(currF->funcPointer);
//TODO: if moved from other file
/*auto itTmp = currElement.origNameByProc.find(currF);

View File

@@ -265,6 +265,9 @@ namespace Distribution
int GetDimSize() const { return dimSize; }
const STRING GetName() const { return name; }
const STRING GetShortName() const { return shortName; }
const STRING GetNameInLocation(void* location) const;
void* GetNameInLocationS(void* location) const;
unsigned GetId() const { return id; }
void SetSizes(VECTOR<PAIR<int, int>> &_sizes, bool notCopyToExpr = false)
{

View File

@@ -227,7 +227,7 @@ static vector<SgExpression*>
compliteTieList(const LoopGraph* currLoop, const vector<LoopGraph*>& loops,
const map<DIST::Array*, set<DIST::Array*>>& arrayLinksByFuncCalls,
const map<string, set<SgSymbol*>>& byUseInFunc,
File* file, const pair<int, int>& lineRange,
File* file, SgStatement *location,
const set<DIST::Array*>& onlyFor,
const set<string>& privates)
{
@@ -258,8 +258,7 @@ static vector<SgExpression*>
if (privates.find(pairs.second->GetShortName()) != privates.end())
continue;
auto type = pairs.second->GetDeclSymbol(currLoop->fileName, lineRange, getAllFilesInProject())->GetOriginal()->type();
SgSymbol* arrayS = getFromModule(byUseInFunc, findSymbolOrCreate(file, pairs.second->GetShortName(), type));
SgSymbol* arrayS = (SgSymbol*)pairs.second->GetNameInLocationS(location);
SgArrayRefExp* array = new SgArrayRefExp(*arrayS);
bool needToAdd = false;
@@ -509,8 +508,6 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
SgStatement* parentFunc = getFuncStat(realStat);
const map<string, set<SgSymbol*>> byUseInFunc = moduleRefsByUseInFunction(realStat);
const int nested = countPerfectLoopNest(loopG);
const pair<int, int> lineRange = make_pair(parentFunc->lineNumber(), parentFunc->lastNodeOfStmt()->lineNumber());
const string& filename = currLoop->fileName;
vector<SgSymbol*> loopSymbs;
vector<LoopGraph*> loops;
@@ -595,12 +592,14 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
if (arrayRef->IsTemplate())
{
if (mapTo->IsLoopArray())
symbForPar = getFromModule(byUseInFunc, findSymbolOrCreate(file, mapTo->GetShortName(), new SgArrayType(*SgTypeInt()), file->GetOriginal()->firstStatement()));
symbForPar = findSymbolOrCreate(file, mapTo->GetShortName(), new SgArrayType(*SgTypeInt()), file->GetOriginal()->firstStatement());
else
symbForPar = getFromModule(byUseInFunc, mapTo->GetDeclSymbol(filename, lineRange, allFiles)->GetOriginal());
{
symbForPar = (SgSymbol*)mapTo->GetNameInLocationS(parentFunc);
}
}
else
symbForPar = getFromModule(byUseInFunc, arrayRef->GetDeclSymbol(filename, lineRange, allFiles)->GetOriginal());
symbForPar = (SgSymbol*)arrayRef->GetNameInLocationS(parentFunc);
arrayExpr = new SgArrayRefExp(*symbForPar);
arrayExprS = "";
@@ -695,9 +694,9 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
}
vector<SgExpression*> tieList;
if (sharedMemoryParallelization)
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, byUseInFunc, file, lineRange, onlyFor, uniqNamesOfPrivates);
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, byUseInFunc, file, parentFunc, onlyFor, uniqNamesOfPrivates);
else if (onlyFor.size()) // not MPI regime
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, byUseInFunc, file, lineRange, onlyFor, uniqNamesOfPrivates);
tieList = compliteTieList(currLoop, loopsTie, arrayLinksByFuncCalls, byUseInFunc, file, parentFunc, onlyFor, uniqNamesOfPrivates);
if (tieList.size())
{
@@ -829,7 +828,7 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
acrossAdd += across[i1].first.first + "(" + bounds + ")";
SgArrayRefExp* newArrayRef = new SgArrayRefExp(*getFromModule(byUseInFunc, acrossArray->GetDeclSymbol(filename, lineRange, allFiles)->GetOriginal()));
SgArrayRefExp* newArrayRef = new SgArrayRefExp(*((SgSymbol*)acrossArray->GetNameInLocationS(parentFunc)));
newArrayRef->addAttribute(ARRAY_REF, acrossArray, sizeof(DIST::Array));
for (auto& elem : genSubscripts(across[i1].second, acrossShifts[i1]))
@@ -905,7 +904,7 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
}
shadowAdd += shadowRenew[i1].first.first + "(" + bounds + ")";
SgArrayRefExp* newArrayRef = new SgArrayRefExp(*getFromModule(byUseInFunc, shadowArray->GetDeclSymbol(filename, lineRange, allFiles)));
SgArrayRefExp* newArrayRef = new SgArrayRefExp(*((SgSymbol*)shadowArray->GetNameInLocationS(parentFunc)));
newArrayRef->addAttribute(ARRAY_REF, shadowArray, sizeof(DIST::Array));
for (auto& elem : genSubscripts(shadowRenew[i1].second, shadowRenewShifts[i1]))
@@ -1053,7 +1052,7 @@ ParallelDirective::genDirective(File* file, const vector<pair<DIST::Array*, cons
directive += it->first.second + ")";
DIST::Array* currArray = allArrays.GetArrayByName(it->first.first.second);
SgArrayRefExp* tmp = new SgArrayRefExp(*getFromModule(byUseInFunc, currArray->GetDeclSymbol(filename, lineRange, allFiles)->GetOriginal()), *it->second);
SgArrayRefExp* tmp = new SgArrayRefExp(*((SgSymbol*)currArray->GetNameInLocationS(parentFunc)), *it->second);
tmp->addAttribute(ARRAY_REF, currArray, sizeof(DIST::Array));
p->setLhs(tmp);

View File

@@ -654,9 +654,7 @@ void DvmhRegionInserter::insertActualDirective(SgStatement *st, const ArraySet &
vector<SgExpression*> list;
for (auto &arr : arraySet)
{
string arrayName = arr->GetShortName();
if (arr->GetLocation().first == DIST::l_MODULE)
arrayName = getNameByUse(st, arrayName, arr->GetLocation().second);
string arrayName = arr->GetNameInLocation(st);
if (exceptSymbs)
if (exceptSymbs->find(arrayName) != exceptSymbs->end())

View File

@@ -314,118 +314,89 @@ static SgStatement* findModWithName(const vector<SgStatement*>& modules, const s
return NULL;
}
string getNameByUse(SgStatement* place, const string& varName, const string& locName)
static map<SgStatement*, set<SgSymbol*>> symbolsForFunc;
static set<string> allFiles;
static const set<SgSymbol*>& getModeulSymbols(SgStatement *func)
{
if (symbolsForFunc.find(func) != symbolsForFunc.end())
return symbolsForFunc[func];
set<SgSymbol*> symbs;
SgSymbol* s = func->symbol()->next();
while (s)
{
if (s->scope() == func && IS_BY_USE(s))
symbs.insert(s);
s = s->next();
}
symbolsForFunc[func] = symbs;
return symbolsForFunc[func];
}
namespace Distribution
{
const string Array::GetNameInLocation(void* location_p) const
{
return ((SgSymbol*)GetNameInLocationS(location_p))->identifier();
}
void* Array::GetNameInLocationS(void* location_p) const
{
SgStatement* location = (SgStatement*)location_p;
int old_id = -1;
string oldFileName = "";
if (place->getFileId() != current_file_id)
if (location->getFileId() != current_file_id)
{
old_id = current_file_id;
oldFileName = current_file->filename();
if (!place->switchToFile())
if (!location->switchToFile())
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
SgStatement* func = getFuncStat(place, { MODULE_STMT });
string returnVal = varName;
if (func != NULL)
SgStatement* func = getFuncStat(location, { MODULE_STMT });
if (func == NULL)
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
if (allFiles.size() == 0)
allFiles = getAllFilesInProject();
const pair<int, int> lineRange = make_pair(func->lineNumber(), func->lastNodeOfStmt()->lineNumber());
const string& filename = func->fileName();
SgSymbol* returnVal = NULL;
if (locationPos.first == l_MODULE)
{
map<string, set<string>> graphUse;
const string& varName = shortName;
const string& locName = locationPos.second;
set<string> useMod;
map<string, vector<pair<SgSymbol*, SgSymbol*>>> modByUse;
map<string, vector<pair<SgSymbol*, SgSymbol*>>> modByUseOnly;
fillInfo(func, useMod, modByUse, modByUseOnly);
SgStatement* cp = func->controlParent();
if (isSgProgHedrStmt(cp) || cp->variant() == MODULE_STMT) // if function in contains region
fillInfo(cp, useMod, modByUse, modByUseOnly);
set<string> useModDone;
bool needRepeat = true;
vector<SgStatement*> modules;
findModulesInFile(func->getFile(), modules);
while (needRepeat)
map<string, SgSymbol*> altNames;
for (const auto& s : getModeulSymbols(func))
{
needRepeat = false;
set<string> newUseMod;
for (auto& useM : useMod)
SgSymbol* orig = OriginalSymbol(s);
if (orig->identifier() == varName && orig->scope()->symbol()->identifier() == locName)
{
if (useModDone.find(useM) == useModDone.end())
{
auto modSt = findModWithName(modules, useM);
if (modSt == NULL || useM == "dvmh_template_mod")
continue;
if (altNames.count(s->identifier()))
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
checkNull(modSt, convertFileName(__FILE__).c_str(), __LINE__);
set<string> tmpUse;
fillInfo(modSt, tmpUse, modByUse, modByUseOnly);
useModDone.insert(useM);
for (auto& use : tmpUse)
{
newUseMod.insert(use);
if (use != "dvmh_template_mod")
graphUse[use].insert(useM);
}
altNames[s->identifier()] = s;
}
}
for (auto& newU : newUseMod)
{
if (useModDone.find(newU) == useModDone.end())
{
useModDone.insert(newU);
needRepeat = true;
}
}
}
vector<string> altNames;
findByUse(modByUse, varName, { locName }, altNames);
findByUse(modByUseOnly, varName, { locName }, altNames);
if (altNames.size() == 0)
{
set<string> locations = { locName };
bool changed = true;
while (changed)
{
changed = false;
for (auto& loc : locations)
{
if (graphUse.find(loc) != graphUse.end())
{
for (auto& use : graphUse[loc])
{
if (locations.find(use) == locations.end())
{
locations.insert(use);
changed = true;
}
}
}
}
}
findByUse(modByUse, varName, locations, altNames);
findByUse(modByUseOnly, varName, locations, altNames);
}
if (altNames.size() == 0)
returnVal = varName;
else if (altNames.size() >= 1)
{
set<string> setAlt(altNames.begin(), altNames.end());
returnVal = *setAlt.begin();
}
if (altNames.size() > 0)
returnVal = altNames.begin()->second;
else
printInternalError(convertFileName(__FILE__).c_str(), __LINE__);
}
else
returnVal = GetDeclSymbol(filename, lineRange, allFiles);
checkNull(returnVal, convertFileName(__FILE__).c_str(), __LINE__);
if (old_id != -1)
{
@@ -435,6 +406,7 @@ string getNameByUse(SgStatement* place, const string& varName, const string& loc
return returnVal;
}
}
void fixUseOnlyStmt(SgFile *file, const vector<ParallelRegion*> &regs)
{

View File

@@ -7,7 +7,6 @@ std::map<std::string, std::set<std::string>> createMapOfModuleUses(SgFile* file)
void fillModuleUse(SgFile* file, std::map<std::string, std::set<std::string>>& moduleUses, std::map<std::string, std::string>& moduleDecls);
void filterModuleUse(std::map<std::string, std::set<std::string>>& moduleUses, std::map<std::string, std::string>& moduleDecls);
void fillUsedModulesInFunction(SgStatement* st, std::vector<SgStatement*>& useStats);
std::string getNameByUse(SgStatement* place, const std::string& varName, const std::string& locName);
void fillUseStatement(SgStatement* st, std::set<std::string>& useMod, std::map<std::string, std::vector<std::pair<SgSymbol*, SgSymbol*>>>& modByUse, std::map<std::string, std::vector<std::pair<SgSymbol*, SgSymbol*>>>& modByUseOnly);
void fixUseOnlyStmt(SgFile* file, const std::vector<ParallelRegion*>& regs);
std::map<std::string, std::set<SgSymbol*>> moduleRefsByUseInFunction(SgStatement* stIn);

View File

@@ -1,3 +1,3 @@
#pragma once
#define VERSION_SPF "2390"
#define VERSION_SPF "2391"