Problem Solving/BOJ

[백준 12995번] [동적 계획법] 트리나라

  • -
728x90
반응형

처음에 접근한 방식은 다음과 같다.

# 처음에 접근한 방법

일단 문제를 이해하면 전체 트리의 부분 트리 중 원소의 갯수가 K인 것의 갯수를 구하라는 것과 동치라는 사실을 쉽게 이해할 수 있다.
또한 트리라는 것을 감안할 때 DFS를 이용하여 루트가 밑쪽에 위치한 트리의 원소의 갯수를 활용하여 루트가 위쪽에 위치한 트리의 원소의 갯수를 도출시킬 수 있다는 사실을 파악하였다.

따라서 이러한 속성을 활용하여
dp[i][j]를 i를 Root node로 하고 해당 트리에 속한 원소의 갯수가 j를 만족하는 갯수라고 설정하였다.

이를 바탕으로 DFS를 타고 위로 올라가면서 점화식을 도출하려고 시도하였다.

다만 이 방법의 문제는 점화식을 세우는 것에 무리가 많다는 점이다.

dp[i][j]를 구하기 위해서는 i 노드의 자식을 얼마나 선택할지를 결정해야 한다.

예를 들어 i의 자식 노드의 갯수가 5개면, 자식을 선택하지 않을수도 있고 최대 5개 모두 선택할 수 있다. 

따라서 이러한 측면에서 계산이 상당히 복잡하다. (다른 블로그들을 참고해보니 2개로 짠 사람도 존재하긴 하는데 좀 복잡하다.)

 

따라서 이러한 측면에서 수정된 접근 방법을 사용하였다.

#수정된 접근 방법

첫 번째로 접근한 방법의 문제점은 몇개를 선택해야하는지 순간순간 판단해주어야 한다는 것이다.

그래서 dp에 차원을 하나 더 올려주면 된다. 
(항상 DP를 잡는 기준은 판단해야 하는 기준을 설정하고 그 갯수만큼 차원을 잡아주면 된다.)

dp[i][j][k]를 i를 Root node로 하고 i의 자식 중 j번째 노드까지만 고려한 상태에서 트리의 원소의 갯수가 k를 만족하는 갯수라 설정하였다.

이러한 측면에서 만약 j가 1인 경우에는 해당 노드만 고려하면 된다. 따라서 점화식은 다음과 같다.
(단, x1~xn을 노드 i 아래에 위치한 노드라 하자)
dp[i][j][구하고자 하는 노드 갯수] = dp[x1][x1 밑의 원소갯수][구하고자 하는 노드 갯수 -1]이다.

다른 케이스의 경우는
dp[i][j][구하고자 하는 노드 갯수] = dp[i][j -1][구하고자 하는 노드 갯수 - alpha] + dp[xj][xj 밑의 원소 갯수][alpha]
(단, 0 <= alpha < 구하고자 하는 노드 갯수) 이다. alpha가 구하고자 하는 노드 갯수를 포함하지 않는 이유는 중복되는 케이스를 일부 지우기 위해서이다.

이러한 방식을 활용해서 문제를 풀게되면 쉽게 처리할 수 있게 된다.

이 문제에서 가장 중요하게 보아야할 지점은 총 2가지이다.

1. 차원을 늘리는 느낌을 정확하게 이해해야 한다.

2개로 필요한 차원을 잡은 것까지는 잘했으나, 더 필요한 부분이 있고 그것을 기준으로 분할해서 구할 수 있으면 차원을 하나 늘리면 된다.

 

추가적으로 이 문제는 n이 50까지밖에 안되므로 Time complexity도 차원을 쌓는다고 하여 엄청나게 커지지는 않는다.

 

2. index를 활용해 어디까지 고려했다고 푸는 풀이에 대해 익숙해져야 한다.

최근에 푼 DP문제들이 이러한 느낌으로 푸는 문제들이 많다. DP를 활용하여 특정 index까지 고려했을 떄의 결과값을 저장해두는 양상에 대해서 기억해두도록 하자.

 

코드는 다음과 같다.

#include <iostream>
#include <vector>
#include <algorithm>
#include <cstring>

using namespace std;
typedef long long ll;

int city_num, passenger_num;
vector<int> child_num_store[51];
int road[51][51];
// 단, k = 0일때는 1로 초기화 할 것(곱 계산 할 때 필요함)
int dp[51][51][51]; // i : i를 루트로 가지는, j : 자식 중 j번째 자식까지 고려, k : 총 노드의 갯수 (만족하는 경우의 수)
int visited[51];
ll result = 0;

void refresh(){
    result = result % 1000000007;
    return;
}

int refresh(int value){
    return value % 1000000007;
}

void dfs(int index){
    if(visited[index] == 1) return; // Already visit (Early exit)
    visited[index] = 1; // 방문처리

    for(int i = 1; i < child_num_store[index].size();){
        if(visited[child_num_store[index][i]] == 1){
            child_num_store[index].erase(child_num_store[index].begin() + i); // 이미 방문한 경우 지워버림
        }
        else{
            dfs(child_num_store[index][i]);
            i++;
        }
    } // 자식들 업데이트

    for(int i = 1; i < child_num_store[index].size(); i++){
        if(i == 1){
            for(int j = 0; j <= passenger_num; j++){
                if(j == 0) dp[index][i][j] = 1; // Base case 설정
                else{
                    dp[index][i][j] += refresh(dp[child_num_store[index][i]][child_num_store[child_num_store[index][i]].size() - 1][j - 1]);
                }
            }
        }
        else{
            for(int j = 0; j <= passenger_num; j++){
                if(j == 0) dp[index][i][j] = 1; // Base case 설정
                else{
                    for(int k = 0; k < j; k++){
                        dp[index][i][j] += refresh(dp[index][i - 1][j - k] * dp[child_num_store[index][i]][child_num_store[child_num_store[index][i]].size() - 1][k]);
                    }
                }
            }
        }
    }
    dp[index][0][1] = 1;
    dp[index][0][0] = 1;
    return;
      
}

int main(void){
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    
    cin >> city_num >> passenger_num;

    // 초기화
    memset(road, 0, sizeof(road));
    memset(dp, 0, sizeof(dp));
    memset(visited, 0, sizeof(visited));
    memset(road, 0, sizeof(road));

    for(int i = 0; i < city_num - 1; i++){
        int temp1, temp2;
        cin >> temp1 >> temp2;
        road[temp1][temp2] = 1;
        road[temp2][temp1] = 1;
    }

    for(int i = 1; i <= city_num; i++){
        child_num_store[i].push_back(0); // Trash data (index 맞추는 용도)
        for(int j = 1; j <= city_num; j++){
            if(road[i][j] == 1){
                child_num_store[i].push_back(j);
            }
        }
    } // 자식 미리 저장해놓음
    dfs(1);

    for(int i = 1; i <= city_num; i++){
        result += dp[i][child_num_store[i].size() - 1][passenger_num];
        refresh();     
    }

    cout << result << "\n";
    return 0;
}

대략 n^3짜리 DP를 채우는 과정이므로

O(N^3)정도라고 볼 수 있다. (아직 증명하면서 정확하게 다룰 수 있는 실력은 아니라서 대략적으로..)

 

N이 최대 50이므로 10^8까지는 여유가 넘친다.

반응형
Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.