일단 문제를 이해하면 전체 트리의 부분 트리 중 원소의 갯수가 K인 것의 갯수를 구하라는 것과 동치라는 사실을 쉽게 이해할 수 있다. 또한 트리라는 것을 감안할 때 DFS를 이용하여 루트가 밑쪽에 위치한 트리의 원소의 갯수를 활용하여 루트가 위쪽에 위치한 트리의 원소의 갯수를 도출시킬 수 있다는 사실을 파악하였다.
따라서 이러한 속성을 활용하여 dp[i][j]를 i를 Root node로 하고 해당 트리에 속한 원소의 갯수가 j를 만족하는 갯수라고 설정하였다.
이를 바탕으로 DFS를 타고 위로 올라가면서 점화식을 도출하려고 시도하였다.
다만 이 방법의 문제는 점화식을 세우는 것에 무리가 많다는 점이다.
dp[i][j]를 구하기 위해서는 i 노드의 자식을 얼마나 선택할지를 결정해야 한다.
예를 들어 i의 자식 노드의 갯수가 5개면, 자식을 선택하지 않을수도 있고 최대 5개 모두 선택할 수 있다.
따라서 이러한 측면에서 계산이 상당히 복잡하다. (다른 블로그들을 참고해보니 2개로 짠 사람도 존재하긴 하는데 좀 복잡하다.)
따라서 이러한 측면에서 수정된 접근 방법을 사용하였다.
#수정된 접근 방법
첫 번째로 접근한 방법의 문제점은 몇개를 선택해야하는지 순간순간 판단해주어야 한다는 것이다.
그래서 dp에 차원을 하나 더 올려주면 된다. (항상 DP를 잡는 기준은 판단해야 하는 기준을 설정하고 그 갯수만큼 차원을 잡아주면 된다.)
dp[i][j][k]를 i를 Root node로 하고 i의 자식 중 j번째 노드까지만 고려한 상태에서 트리의 원소의 갯수가 k를 만족하는 갯수라 설정하였다.
이러한 측면에서 만약 j가 1인 경우에는 해당 노드만 고려하면 된다. 따라서 점화식은 다음과 같다. (단, x1~xn을 노드 i 아래에 위치한 노드라 하자) dp[i][j][구하고자 하는 노드 갯수] = dp[x1][x1 밑의 원소갯수][구하고자 하는 노드 갯수 -1]이다.
다른 케이스의 경우는 dp[i][j][구하고자 하는 노드 갯수] = dp[i][j -1][구하고자 하는 노드 갯수 - alpha] + dp[xj][xj 밑의 원소 갯수][alpha] (단, 0 <= alpha < 구하고자 하는 노드 갯수) 이다. alpha가 구하고자 하는 노드 갯수를 포함하지 않는 이유는 중복되는 케이스를 일부 지우기 위해서이다.
이러한 방식을 활용해서 문제를 풀게되면 쉽게 처리할 수 있게 된다.
이 문제에서 가장 중요하게 보아야할 지점은 총 2가지이다.
1. 차원을 늘리는 느낌을 정확하게 이해해야 한다.
2개로 필요한 차원을 잡은 것까지는 잘했으나, 더 필요한 부분이 있고 그것을 기준으로 분할해서 구할 수 있으면 차원을 하나 늘리면 된다.
추가적으로 이 문제는 n이 50까지밖에 안되므로 Time complexity도 차원을 쌓는다고 하여 엄청나게 커지지는 않는다.
2. index를 활용해 어디까지 고려했다고 푸는 풀이에 대해 익숙해져야 한다.
최근에 푼 DP문제들이 이러한 느낌으로 푸는 문제들이 많다. DP를 활용하여 특정 index까지 고려했을 떄의 결과값을 저장해두는 양상에 대해서 기억해두도록 하자.
코드는 다음과 같다.
#include <iostream>
#include <vector>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
int city_num, passenger_num;
vector<int> child_num_store[51];
int road[51][51];
// 단, k = 0일때는 1로 초기화 할 것(곱 계산 할 때 필요함)
int dp[51][51][51]; // i : i를 루트로 가지는, j : 자식 중 j번째 자식까지 고려, k : 총 노드의 갯수 (만족하는 경우의 수)
int visited[51];
ll result = 0;
void refresh(){
result = result % 1000000007;
return;
}
int refresh(int value){
return value % 1000000007;
}
void dfs(int index){
if(visited[index] == 1) return; // Already visit (Early exit)
visited[index] = 1; // 방문처리
for(int i = 1; i < child_num_store[index].size();){
if(visited[child_num_store[index][i]] == 1){
child_num_store[index].erase(child_num_store[index].begin() + i); // 이미 방문한 경우 지워버림
}
else{
dfs(child_num_store[index][i]);
i++;
}
} // 자식들 업데이트
for(int i = 1; i < child_num_store[index].size(); i++){
if(i == 1){
for(int j = 0; j <= passenger_num; j++){
if(j == 0) dp[index][i][j] = 1; // Base case 설정
else{
dp[index][i][j] += refresh(dp[child_num_store[index][i]][child_num_store[child_num_store[index][i]].size() - 1][j - 1]);
}
}
}
else{
for(int j = 0; j <= passenger_num; j++){
if(j == 0) dp[index][i][j] = 1; // Base case 설정
else{
for(int k = 0; k < j; k++){
dp[index][i][j] += refresh(dp[index][i - 1][j - k] * dp[child_num_store[index][i]][child_num_store[child_num_store[index][i]].size() - 1][k]);
}
}
}
}
}
dp[index][0][1] = 1;
dp[index][0][0] = 1;
return;
}
int main(void){
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
cin >> city_num >> passenger_num;
// 초기화
memset(road, 0, sizeof(road));
memset(dp, 0, sizeof(dp));
memset(visited, 0, sizeof(visited));
memset(road, 0, sizeof(road));
for(int i = 0; i < city_num - 1; i++){
int temp1, temp2;
cin >> temp1 >> temp2;
road[temp1][temp2] = 1;
road[temp2][temp1] = 1;
}
for(int i = 1; i <= city_num; i++){
child_num_store[i].push_back(0); // Trash data (index 맞추는 용도)
for(int j = 1; j <= city_num; j++){
if(road[i][j] == 1){
child_num_store[i].push_back(j);
}
}
} // 자식 미리 저장해놓음
dfs(1);
for(int i = 1; i <= city_num; i++){
result += dp[i][child_num_store[i].size() - 1][passenger_num];
refresh();
}
cout << result << "\n";
return 0;
}
대략 n^3짜리 DP를 채우는 과정이므로
O(N^3)정도라고 볼 수 있다. (아직 증명하면서 정확하게 다룰 수 있는 실력은 아니라서 대략적으로..)