Approach
특정 구간에서 k보다 작은 개수/큰 개수는 일반적으로 머지 소트 트리를 활용해서 처리하는 경우가 많다.
머지소트가 진행되는 양상을 잘 살펴보면, 세그먼트가 구성되는 느낌이랑 거의 유사하다.
따라서 어떠한 쿼리에서 k보다 큰 개수를 구하는 과정은 다음과 같다.
1) 원하는 쿼리와 현재 탐색중인 쿼리가 전혀 교차하지 않으면 0
2) 원하는 쿼리가 현재 탐색중인 쿼리를 포함하고 있을 경우, 이진탐색하여 k보다 큰 개수를 구한다. (upper bound)
3) 겹치는 경우에는 재귀적으로 처리하면 1)과 2)의 연산의 합으로 구할 수 있다.
k보다 크거나 같은/ 작거나 같은/ 작은의 경우도 비슷한 방식으로 처리해주면 된다.
Code
#include <bits/stdc++.h>
#define fastio ios_base::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
vector<int> seg[400000];
void update(int node_index, int node_left, int node_right, int index, int value){
if(index < node_left || node_right < index) return;
seg[node_index].push_back(value);
if(node_left == node_right) return;
int mid = (node_left + node_right) / 2;
update(node_index * 2, node_left, mid, index, value);
update(node_index * 2 + 1, mid + 1, node_right, index, value);
return;
}
int query(int node_index, int node_left, int node_right, int query_left, int query_right, int value){
if(query_right < node_left || node_right < query_left) return 0;
if(query_left <= node_left && node_right <= query_right){
return seg[node_index].end() - upper_bound(seg[node_index].begin(), seg[node_index].end(), value);
}
int mid = (node_left + node_right) / 2;
return query(node_index * 2, node_left, mid, query_left, query_right, value) +
query(node_index * 2 + 1, mid + 1, node_right, query_left, query_right, value);
}
int main(void){
fastio;
int n;
cin >> n;
for(int i = 1; i <= n; i++){
int t;
cin >> t;
update(1, 1, n, i, t);
}
// Sorting Merge sort tree
for(int i = 0; i <= 400000; i++){
sort(seg[i].begin(), seg[i].end());
}
int m;
cin >> m;
for(int i = 1; i <= m; i++){
int a, b, c;
cin >> a >> b >> c;
cout << query(1, 1, n, a, b, c) << "\n";
}
return 0;
}
Another Approach
오프라인 쿼리와 스위핑 기법을 활용한 접근 방법으로 이 문제를 풀 수 있다.
접근 방법 자체는 쿼리에서 k값을 미리 다 받아놓고, k가 큰 것부터 쿼리를 처리해서 구간합 문제로 돌려풀 수 있다.
수열의 입력값 각각에 대해서, 해당 값보다 작은 k중 가장 큰 k를 구하고 이를 포함한 k가 쿼리로 주어졌을 때 1씩 update시켜주면 된다.
Code
#include <bits/stdc++.h>
#define fastio ios_base::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef pair<piii, int> piiii;
int seg[400000];
bool compare(piiii &v1, piiii & v2){
return v1.first.second > v2.first.second;
}
void update(int node_index, int node_left, int node_right, int index){
if(index < node_left || node_right < index) return;
seg[node_index]++;
if(node_left == node_right) return;
int mid = (node_left + node_right) / 2;
update(node_index * 2, node_left, mid, index);
update(node_index * 2 + 1, mid + 1, node_right, index);
return;
}
int query(int node_index, int node_left, int node_right, int query_left, int query_right){
if(query_right < node_left || node_right < query_left) return 0;
if(query_left <= node_left && node_right <= query_right) return seg[node_index];
int mid = (node_left + node_right) / 2;
return query(node_index * 2, node_left, mid, query_left, query_right) +
query(node_index * 2 + 1, mid + 1, node_right, query_left, query_right);
}
int main(void){
fastio;
memset(seg, 0, sizeof(seg));
int N;
cin >> N;
vector<int> data(N); // 수열 저장
for(int i = 0; i < N; i++){
cin >> data[i];
}
int M;
cin >> M;
vector<piiii> query_value;
vector<int> value_store; // k값 묶어둔 것
for(int i = 0; i < M; i++){
int a, b, c;
cin >> a >> b >> c;
query_value.push_back(make_pair(make_pair(make_pair(a, b), c),i));
value_store.push_back(c);
}
sort(query_value.begin(), query_value.end(), compare); // 주어진 쿼리를 k가 큰 순서대로 정렬
// k값 압축 및 정렬
sort(value_store.begin(), value_store.end());
value_store.erase(unique(value_store.begin(), value_store.end()), value_store.end());
map<int, int> index;
for(int i = 0; i < value_store.size(); i++){
index.insert(make_pair(value_store[i], i));
} // k값 자체를 index로 매핑해주는 map 자료구조
// 자기보다 작은 k중 가장 큰 것 찾기(Lower bound 사용)
vector<int> update_store[value_store.size()];
for(int i = 0; i < N; i++){
int start = 0;
int end = value_store.size();
while(start < end){
int mid = (start + end) / 2;
if(data[i] > value_store[mid]) start = mid + 1;
else end = mid;
}
if(start >= 1) update_store[start - 1].push_back(i);
}
vector<int> result(M);
for(int i = 0; i < M; i++){
int k = query_value[i].first.second;
// 업데이트
for(int j = 0; j < update_store[index[k]].size(); j++){
update(1, 0, N - 1, update_store[index[k]][j]);
}
update_store[index[k]].clear();
result[query_value[i].second] = query(1, 0, N - 1, query_value[i].first.first.first - 1, query_value[i].first.first.second - 1);
}
for(int i = 0; i < M; i++){
cout << result[i] << "\n";
}
return 0;
}