잘 생각해보면 bitwise or 계산 결과는 단조증가할 수 밖에 없다. 구간의 시작점을 1 ~ N까지 이동시키면서, lower_bound를 해서 k가 나오면 해당 구간을 출력해주면 된다. 구간 별 xor 결과는 세그먼트 트리를 통해서 관리를 하고, 이를 이분탐색으로 만족하는 구간이 있는지를 파악해주면 된다.
대회가 끝나고 공식 해설을 보고, 다음과 같은 방법으로 풀면 좀 더 쉽게 처리할 수 있음을 파악하였다.
잘 생각해보면, k와 xor한 결과가 k이면 구간 안에 들어갈 수 있는 수이고, 그렇지 않으면 무조건 들어갈 수 없는 수이다.
따라서, 들어갈 수 있는 숫자들을 전부 다 더했을 때 k가 나오면 문제조건을 만족할 수 있는 구간으로 생각할 수 있는 것이다.
Code 1 (세그먼트 트리 + 이분탐색)
#include <bits/stdc++.h>
#define fastio ios_base::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
typedef long long ll;
ll seg[800004];
int n, k;
ll update(int node_index, int node_left, int node_right, int index, int value)
{
if (index < node_left || node_right < index) return seg[node_index]; // Out of bound
if(node_left == node_right){
return seg[node_index] = value;
}
else{
int mid = (node_left + node_right) / 2;
return seg[node_index] = update(node_index * 2, node_left, mid, index, value) |
update(node_index * 2 + 1, mid + 1, node_right, index, value);
}
}
ll query(int node_index, int node_left, int node_right, int query_left, int query_right)
{
if (query_right < node_left || node_right < query_left) return 0LL;
else if (query_left <= node_left && node_right <= query_right) return seg[node_index];
int mid = (node_left + node_right) / 2;
return query(node_index * 2, node_left, mid, query_left, query_right) | query(node_index * 2 + 1, mid + 1, node_right, query_left, query_right);
}
int main(void){
fastio;
memset(seg, 0, sizeof(seg));
cin >> n >> k;
for(int i = 1; i <= n; i++){
int t;
cin >> t;
update(1, 1, n, i, t);
}
for(int i = 1; i <= n; i++){
int start = i - 1;
int end = n;
while(start + 1 < end){
int mid = (start + end) / 2;
if (query(1, 1, n, i, mid) >= k) end = mid;
else start = mid;
}
if (query(1, 1, n, i, end) == k)
{
cout << i << " " << end;
return 0;
}
}
cout << -1 << "\n";
return 0;
}
Code 2 (비트마스킹 + 그리디)
#include <bits/stdc++.h>
#define fastio ios_base::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
int main(void){
fastio;
int n, k;
cin >> n >> k;
int cal = 0;
int s = 1;
for (int i = 1; i <= n; i++)
{
int t;
cin >> t;
if ((t | k) != k)
{
cal = 0;
s = i + 1;
}
else
{
cal |= t;
if (cal == k)
{
cout << s << " " << i << "\n";
return 0;
}
}
}
cout << -1 << "\n";
return 0;
}