[中山OI2011]小W的问题 题解
题目地址:BZOJ:Problem 2441. — [中山市选2011]小W的问题
题目描述
有一天,小W找了一个笛卡尔坐标系,并在上面选取了N个整点。他发现通过这些整点能够画出很多个“W”出来。具体来说,对于五个不同的点(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5),如果满足:
- x1 < x2 < x3 < x4 < x5
- y1 > y3 > y2
- y5 > y3 > y4
则称它们构成一个“W”形。
现在,小W想统计“W”形的个数,也就是满足上面条件的五元点组个数。你能帮助他吗?
输入输出格式
输入格式:
第一行包含一个整数N,表示点的个数。
下面N行每行两个整数,第i+1行为(xi, yi),表示第i个点的坐标。
输出格式:
仅包含一行,为“W”形个数模1 000 000 007的值。
输入输出样例
输入样例#1:
6 1 10 2 1 3 5 4 6 5 1 6 10
输出样例#1:
3
说明
对于100%的数据满足N ≤ 200 000,0 ≤ xi ≤ 10^9,0 ≤ yi ≤ 10^9
题解
首先我们考虑将一整个“W”拆成左右两个“V”形分别计算。我们以左高右低的“V”形为例。考虑一个点右上区域中的点数,这个点数可以认为是V形最左侧的可选点数。而对于一个V形右侧的点,能构成的V形数量显然就是左侧所有比它低的点的右上点之和。如果看到这里你还是感觉很懵,请看下面的图片。
到这里,我们知道,我们要求的就是对于每个点,它左边任意一点的左上方的点的数量的和。我们先将点按y坐标排序,从小到大,每一次计算一层相同y坐标的点。我们用线段树维护x坐标(预先离散化)在这个区间内的点的左上方的点的数量的和。处理一个y坐标的时候,对于每一个点,把大于它x坐标的点的线段树区间-1,因为这个点已经低于之后处理的那些点了,无法成为V形左边的那个点。然后计算低于这个点的和,作为这个点能构成V形的数量。最后把这个点左边的点的个数加入线段树这个点x坐标的位置,和在这之前算过的-1合并作为这个点的左上点数量。需要注意的是,一层y坐标处理时,先对全部点处理-1,再对全部点计算答案,最后再对全部点计算左上点数量。这是因为同y坐标不能互为左上点,如果边计算答案边计算左上点数量,可能会算进去同y坐标的点的答案。
另一边左低右高的V形可以反着推一下,或者看代码的计算方法。
最后的答案为把同一个点两个V形的种类乘起来的和,原理是乘法原理。
这里我们需要对线段树进行魔改,因为有一些x坐标处没有点被计算过,这个位置的值还是负值,不应该加进去,所以要统计一下当前区间内有多少x坐标是算过的,维护一个len。细节看代码吧。
实话说,这个题属于“我知道我要求啥,但是我不知道怎么求要求的这个东西”的一类题目,看别人题解看的一脸懵逼,思路还是在写这篇文章的时候逐渐清晰的。考场上估计就是一脸不可做了。如果对过程还是很懵逼,在这篇文章下留言,我尽可能回答你的问题,并且改进题解叙述。
代码
// Code by KSkun, 2018/3
#include <cstdio>
#include <cstring>
#include <algorithm>
typedef long long LL;
inline char fgc() {
static char buf[100000], *p1 = buf, *p2 = buf;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1++;
}
inline int readint() {
register int res = 0, neg = 1;
char c = fgc();
while(c < '0' || c > '9') {
if(c == '-') neg = -1;
c = fgc();
}
while(c >= '0' && c <= '9') {
res = res * 10 + c - '0';
c = fgc();
}
return res * neg;
}
// variables
const int MO = 1e9 + 7, MAXN = 200005, INF = 2e9;
int n;
LL f[MAXN][2], ans = 0;
struct Point {
int x, y, id, nx;
} pts[MAXN];
LL tmp[MAXN], ttot = 0;
inline bool cmpx(Point a, Point b) {
return a.x < b.x;
}
inline bool cmpy(Point a, Point b) {
return a.y < b.y;
}
// seg tree
struct Node {
LL val, add, len;
} tr[MAXN << 2];
inline void apply(int o, LL v) {
tr[o].add = (tr[o].add + v) % MO;
tr[o].val = (tr[o].val + v * tr[o].len) % MO;
}
inline void pushdown(int o, int l, int r) {
if(l == r) return;
int lch = o << 1, rch = o << 1 | 1;
if(tr[o].add != 0) {
apply(lch, tr[o].add);
apply(rch, tr[o].add);
tr[o].add = 0;
}
}
inline void merge(int o, int l, int r) {
if(l == r) return;
int lch = o << 1, rch = o << 1 | 1;
tr[o].val = 0;
tr[o].len = tr[lch].len + tr[rch].len;
if(tr[lch].len) tr[o].val = (tr[o].val + tr[lch].val) % MO;
if(tr[rch].len) tr[o].val = (tr[o].val + tr[rch].val) % MO;
}
inline void add(int o, int l, int r, int ll, int rr, LL v) {
if(ll > rr) return;
pushdown(o, l, r);
if(l >= ll && r <= rr) {
apply(o, v);
return;
}
int lch = o << 1, rch = o << 1 | 1, mid = (l + r) >> 1;
if(ll <= mid) add(lch, l, mid, ll, rr, v);
if(rr > mid) add(rch, mid + 1, r, ll, rr, v);
merge(o, l, r);
}
inline void add(int o, int l, int r, int x, LL v) {
pushdown(o, l, r);
if(l == r) {
tr[o].len = 1;
tr[o].val = (tr[o].add + v) % MO;
return;
}
int lch = o << 1, rch = o << 1 | 1, mid = (l + r) >> 1;
if(x <= mid) add(lch, l, mid, x, v);
if(x > mid) add(rch, mid + 1, r, x, v);
merge(o, l, r);
}
inline LL query(int o, int l, int r, int ll, int rr) {
if(ll > rr || !tr[o].len) return 0;
pushdown(o, l, r);
if(l >= ll && r <= rr) {
return tr[o].val;
}
int lch = o << 1, rch = o << 1 | 1, mid = (l + r) >> 1;
LL res = 0;
if(ll <= mid) res = (res + query(lch, l, mid, ll, rr)) % MO;
if(rr > mid) res = (res + query(rch, mid + 1, r, ll, rr)) % MO;
return res;
}
inline void add(int l, int r, LL v) {
add(1, 1, n, l, r, v);
}
inline void add(int x, LL v) {
add(1, 1, n, x, v);
}
inline LL query(int l, int r) {
return query(1, 1, n, l, r);
}
int main() {
n = readint();
for(int i = 1; i <= n; i++) {
pts[i].x = readint();
pts[i].y = readint();
tmp[++ttot] = pts[i].x;
}
tmp[++ttot] = INF;
std::sort(tmp + 1, tmp + ttot + 1);
std::sort(pts + 1, pts + n + 1, cmpx);
for(int i = 1; i <= n; i++) {
pts[i].nx = std::upper_bound(tmp + 1, tmp + ttot + 1, pts[i].x) - tmp;
pts[i].x = std::lower_bound(tmp + 1, tmp + ttot + 1, pts[i].x) - tmp;
pts[i].id = i;
}
std::sort(pts + 1, pts + n + 1, cmpy);
for(int i = 1; i <= n; i++) {
int j = i;
while(j < n && pts[i].y == pts[j + 1].y) j++;
for(int k = i; k <= j; k++) {
add(pts[k].nx, n, -1);
}
for(int k = i; k <= j; k++) {
f[pts[k].id][0] = query(1, pts[k].x - 1);
}
for(int k = i; k <= j; k++) {
add(pts[k].id, pts[k].x - 1);
}
i = j;
}
memset(tr, 0, sizeof(tr));
for(int i = 1; i <= n; i++) {
int j = i;
while(j < n && pts[i].y == pts[j + 1].y) j++;
for(int k = i; k <= j; k++) {
add(1, pts[k].x - 1, -1);
}
for(int k = i; k <= j; k++) {
f[pts[k].id][1] = query(pts[k].nx, n);
}
for(int k = i; k <= j; k++) {
add(pts[k].id, n - pts[k].nx + 1);
}
i = j;
}
for(int i = 1; i <= n; i++) {
ans = (ans + f[i][0] * f[i][1] % MO) % MO;
}
printf("%lld", ans);
return 0;
}