위 문제를 풀고 나서 되돌아보며 부족했던 부분을 채우고자 합니다. 저의 풀이와 다른 사람들의 풀이를 보고 느낀 점을 서술합니다.
나의 풀이
처음에 문제를 대충 읽었다가 주어지는 게 그냥 그래프인 줄 알았는데, 간선 입력이 N-1개 주어진다는 걸 보고 문제를 다시 읽어 보니 트리인 걸 알아차렸다.
문제를 보면 일단 어디가 루트 노드인지 알 수 없기 때문에 기본적으로 1번 노드를 루트 노드로 하여금 트리를 형성하고, 트리를 순회하면서 값을 계산하도록 한다. 이때, 문제를 작은 문제로 줄여서 subtree에 대한 최댓값을 구했을 때, 해당 값을 이용해서 더 큰 문제(parent 노드를 포함하는 문제)를 풀 수 있게 된다. 이를 통해서, 노드 u를 루트 노드로 하는 subtree의 문제의 최댓값을 DP 배열에 저장해서 계산 속도를 줄이면 된다.
최댓값을 구하는 과정에서 2번 조건인 우수 마을끼리는 인접할 수 없는 것과 3번 조건인 일반 마을은 적어도 하나의 마을이 우수 마을과 인접해야 한다는 조건을 생각해 보았을 때, 아래와 같이 경우를 나눠서 생각해 볼 수 있다.
- 현재 방문한 마을이 우수 마을일 때
- 자식 마을은 우수 마을이 되면 안된다.
- 현재 방문한 마을이 일반 마을일 때
- 부모 마을이 우수 마을이라면, 자식 마을의 우수 마을 여부는 상관 없다.
- 부모 마을이 일반 마을이라면, 적어도 하나의 자식 마을은 우수 마을이여야 한다.
이렇게 총 3개의 케이스가 나오고, 각각의 경우는 마을을 순회하면서 해당 마을까지 방문했을 때 해당 마을을 포함한 최신 방문들 중 연속한 일반 마을의 개수로 생각할 수 있다.
1-1은 현재 방문한 마을이 우수 마을이기 때문에 0이고, 2-1은 현재 방문한 마을이 일반 마을이면서 부모 마을이 우수 마을이기 때문에 1, 2-2는 현재 방문한 마을이 일반 마을이고 부모 마을도 우수 마을이기 때문에 2의 값을 갖는다. 즉, DP 배열의 크기는 10000*3이다.
위 아이디어를 통해서 작성한 코드를 살펴 보자.
입력을 트리로 변환하는 함수
int N;
vector<int> popul;
vector<vector<int>> adj, children;
int cache[10001][3];
void generateChildren(int here, vector<bool> &visited)
{
for (auto there : adj[here])
if (!visited[there])
{
visited[there] = true;
children[here].push_back(there);
generateChildren(there, visited);
}
}
해당 함수는 DFS를 통해 입력(adj)을 트리(children)로 변환한다.
최대 인구 수를 구하는 함수
int getMaxPopulation(int here, int passed)
{
int &ref = cache[here][passed];
if (ref != -1) return ref;
ref = passed ? 0 : popul[here];
if (passed == 2)
{
int minDiff = 987654321;
for (auto c : children[here])
{
int s = getMaxPopulation(c, 0), ns = getMaxPopulation(c, 2);
if (s < ns)
{
ref += ns;
minDiff = min(minDiff, ns - s);
}
else
{
ref += s;
minDiff = -1;
}
}
if (minDiff != -1)
ref -= minDiff;
}
else if (passed == 1)
for (auto c : children[here])
ref += max(getMaxPopulation(c, 0), getMaxPopulation(c, 2));
else
for (auto c : children[here])
ref += getMaxPopulation(c, 1);
return ref;
}
passed가 위에서 말한 현재 방문한 마을까지의 가장 최근의 연속한 일반 마을의 개수로, passed == 0(마지막 else 문)일 때에는 자식 마을들은 무조건 일반 마을이어야 하기 때문에 ref에 getMaxPopulation(c, 1)을 더해 준다.
passed == 1일 때는 0일 때와 비슷하게 자식 마을들이 어떤 마을이든 상관 없기 때문에 ref에 getMaxPopulation(c, 0)과 getMaxPopulation(c, 2) 중 큰 값을 더해 준다.
passed == 2일 때에는 자식 마을들이 어떤 마을이든 상관 없으나 최소 하나의 자식 마을은 우수 마을이어야 하기 때문에 모든 자식 마을이 일반 마을이 되지 않도록 해준다. 모든 자식 마을이 일반 마을일 경우에는 하나를 우수 마을로 바꾸어 주는데, 그 값이 최소가 되게 하기 위해서 각각의 경우에 자식 마을이 우수 마을일 때와 일반 마을일 때의 차이의 최소를 따로 저장한다.
모든 자식 마을이 우수 마을일 경우(minDiff != -1), ref 값은 minDiff만큼 감소하도록 한다.
위 함수들을 실행하는 함수
int solve()
{
vector<bool> visited(N, false);
visited[0] = true;
generateChildren(0, visited);
memset(cache, -1, sizeof(cache));
return max(getMaxPopulation(0, 0), getMaxPopulation(0, 1));
}
적절히 입력을 받고(나의 경우 각각의 노드를 입력받을 때 노드 번호를 1씩 빼줘서 0부터 시작하도록 했다), 트리를 생성하고, 최댓값을 찾아 반환하도록 한다. 루트 노드의 경우 passed == 2인 경우는 존재할 수 없기 때문에 0과 1만 구하면 된다.
다른 사람들의 풀이
그런데 이 문제를 풀면서 굳이 경우의 수를 3가지까지 고려해야 하는지에 대한 의문이 생겼고, 다른 사람들의 코드를 열어 본 결과 다들 2가지의 경우만 고려하는 것을 보았다. 즉, 현재 방문한 마을이 우수 마을인지 아닌지의 경우만 따지는 것이다.
이게 어떻게 가능한지에 대해 생각해 본 결과 문제에서 주어진 3번 조건이 필요 없다는 점을 알게 되었다. 1번 조건(최댓값을 구해야 한다)과 2번 조건(우수 마을은 서로 인접하지 않는다)라는 조건으로 3번 조건은 자명해지기 때문이다.
3번 조건이 없을 때 3번 조건이 자명하지 않다고 가정해 보자. 즉, 3번 조건이 없다면 우수-일반-일반-일반-우수를 가지는 최댓값이 존재하는 것이다. 그러나 당연하게도 이것보다 우수-일반-우수-일반-우수로 일반들 사이에 끼어 있는 일반 마을을 우수 마을로 선택하는 경우가 더 큰 값이 되므로, 3번 조건이 없어도 3번 조건이 자명함을 알 수 있다.
즉, 1번 조건인 최댓값을 구하는 함수(getMaxPopulation())를 2번 조건인 우수 마을은 서로 인접하지 않는다는 조건만으로 해결할 수 있게 된다.
또한, 더 개선할 수 있는 점이 존재한다. 하나는 입력 값을 트리로 미리 변경하지 않아도 된다는 점이다. 최댓값 계산 함수를 재귀적으로 호출하는 도중 visited 배열만 가지고 있다면 한 마을을 두 번 이상 방문하는 것을 방지할 수 있으면서도 모든 마을을 순회할 수 있기 때문이다. 또 하나의 개선점은, 한 번의 함수 호출만으로 현재 방문한 마을이 우수 마을인 경우와 일반 마을인 경우를 모두 계산하도록 할 수 있다는 점이다.
그러므로, 우리는 generateChildren() 함수를 제거하고 getMaxPopulation() 함수를 아래와 같이 개선해서 코드를 작성할 수 있다.
int N;
vector<int> popul;
vector<vector<int>> adj;
int cache[10001][2];
void generateMaxPopulation(int here, vector<bool> &visited)
{
int &nor = cache[here][0], &sup = cache[here][1];
nor = 0, sup = popul[here];
for (auto next : adj[here])
{
if (!visited[next])
{
visited[next] = true;
generateMaxPopulation(next, visited);
nor += max(cache[next][0], cache[next][1]);
sup += cache[next][0];
}
}
}
int solve()
{
vector<bool> visited(N, false);
visited[0] = true;
memset(cache, -1, sizeof(cache));
generateMaxPopulation(0, visited);
return max(cache[0][0], cache[0][1]);
}
마치며
일반적으로 백준 문제를 풀면서 헷갈렸던 문제는 AC를 띄워도 다른 사람들의 코드를 보면서 무엇이 다른지 비교해 보는데, 이번 문제의 경우 다른 사람들의 코드를 보면서 지긋이 생각할 수 있는 기회를 가질 수 있어서 좋았습니다.
특히, 이번 문제의 경우 트리와 DP를 결합한 좋은 문제이기 때문에 추천드리고 싶습니다.
'공부 > PS' 카테고리의 다른 글
백준 64일 스트릭 회고 (0) | 2022.11.03 |
---|---|
PS에 편한 C++ 문법 (1) | 2022.09.29 |
최단 거리 알고리즘의 이해 (0) | 2022.09.22 |
KMP 알고리즘 (0) | 2022.02.13 |