// Mannattan_minimum_spanning_tree is the edge-adding step: it appends candidate edges to the global edge list e[].
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
// Common modulus; not referenced by the code in this file.
const ll mod = 1000000007LL;
// Capacity of the point-indexed global arrays.
const int maxd = 200010;
// "Empty" sentinel for Fenwick-tree minima.
const int inf = 0x3f3f3f3f;
// A 2-D point tagged with its id. Points are ordered lexicographically:
// by x first, with y breaking ties.
struct Point {
    int x, y, id;
    bool operator<(const Point &other) const
    {
        if (x != other.x) return x < other.x;
        return y < other.y;
    }
} p[maxd], tmp[maxd];
// A weighted edge (u, v) with weight d. Edges are ordered by weight so that
// std::sort arranges them for Kruskal's algorithm.
//
// Capacity: one call to Mannattan_minimum_spanning_tree on C points appends
// at most 4*(C-1) edges (at most one per point per direction, and never one
// for the first point scanned). main() calls it once per row (C = m) and once
// per column (C = n), so up to ~8*n*m edges are stored. Points occupy p[1..n*m]
// with n*m < maxd, so e[] needs 8*maxd slots. The previous 4*maxd could
// overflow. e[] is used 1-indexed.
struct Edge
{
    int u, v, d;
    bool operator < (const Edge &b) const
    { return d < b.d; }
} e[maxd << 3];
int tot;        // number of edges currently stored in e[] (1-indexed)
int fa[maxd];   // union-find parent links
int t[maxd];    // Fenwick tree: smallest key stored per node
int pos[maxd];  // Fenwick tree: point index that produced t[] at that node
// Union-find "find": returns the root of x's set and re-points every node on
// the path from x directly at that root (path compression).
// Written iteratively. The union in main() (fa[x] = y) is unranked, so a parent
// chain can grow as long as the number of points. The previous recursive
// version could then exhaust the call stack.
int gf(int x){
    int root = x;
    while (fa[root] != root) root = fa[root];
    // Second pass: compress the path walked above.
    while (fa[x] != root) {
        int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}
// Fenwick-tree insertion walking toward lower indices. Each visited node keeps
// the smaller of its current key and v; when v wins, pp is recorded as the
// owner of that key.
void update(int x,int v,int pp)
{
    while (x != 0) {
        if (t[x] > v) {
            t[x] = v;
            pos[x] = pp;
        }
        x &= x - 1;  // drop the lowest set bit
    }
}
// Fenwick-tree lookup walking toward higher indices, up to m. It returns the
// owner (from pos[]) of the smallest key among the visited nodes, or -1 when
// none holds a key below inf.
int query(int x,int m)
{
    int best = inf;
    int owner = -1;
    while (x <= m) {
        if (t[x] < best) {
            best = t[x];
            owner = pos[x];
        }
        x += x & -x;  // move to the next covering node
    }
    return owner;
}
// Manhattan distance |dx| + |dy| between two points.
int dis(Point a,Point b)
{
    const int dx = abs(a.x - b.x);
    const int dy = abs(a.y - b.y);
    return dx + dy;
}
void Mannattan_minimum_spanning_tree(int n, Point p[])
{
int a[maxd],b[maxd];
for(int dir = 0; dir < 4;dir++)
{
if(dir == 1 || dir == 3) for(int i = 0;i < n;i++) swap(p[i].x,p[i].y);
else if(dir == 2) for(int i = 0;i < n;i++) p[i].x = -p[i].x;
sort(p,p+n);
for(int i = 0;i < n;i++) a[i] = b[i] = p[i].y - p[i].x;
sort(b,b+n);
int m = unique(b,b+n) - b;
for(int i = 1;i <= m;i++) t[i] = inf,pos[i] = -1;
for(int i = n-1 ;i >= 0;i--)
{
int pos = lower_bound(b,b+m,a[i]) - b + 1;
int ans = query(pos,m);
if(ans != -1) e[++tot] = (Edge) {p[i].id,p[ans].id,dis(p[i],p[ans])};
//addedge(p[i].id,p[ans].id,dis(p[i],p[ans]));
update(pos,p[i].x+p[i].y,i);
}
}
}
int n,m;
// Reads test cases until the input is exhausted or malformed. Each case is:
//   - two integers n and m;
//   - n*m x-coordinates, then n*m y-coordinates, both in row-major order.
// Cell (i, j) becomes point p[(i-1)*m + j], whose id is that same index.
// Candidate edges are generated separately for every row and every column.
// Kruskal's algorithm then sums the weights of the edges that join different
// components, and the total is printed.
int main()
{
    // freopen("a.in","r",stdin);
    // freopen("k.out","w",stdout);
    // Require both values. Testing only against EOF would loop forever if
    // scanf stopped on malformed input and returned 0.
    while(scanf("%d %d",&n,&m) == 2)
    {
        // NOTE(review): p[] and fa[] hold maxd entries, so this assumes
        // n*m < maxd — confirm against the problem limits.
        for(int i = 1;i<=n;i++)
            for(int j = 1;j<=m;j++)
            {
                p[(i-1)*m+j].id = (i-1)*m+j;
                scanf("%d",&p[(i-1)*m+j].x);
            }
        for(int i = 1;i<=n;i++)
            for(int j = 1;j<=m;j++)
                scanf("%d",&p[(i-1)*m+j].y);
        tot = 0;
        // Candidate edges among the points of each row (tmp is scratch,
        // because the edge generator reorders and transforms its input)...
        for(int i = 1;i<=n;i++)
        {
            int C = 0;
            for(int j = 1;j<=m;j++)
                tmp[C++] = p[(i-1)*m+j];
            Mannattan_minimum_spanning_tree(C,tmp);
        }
        // ...and among the points of each column.
        for(int j = 1;j<=m;j++)
        {
            int C = 0;
            for(int i = 1;i<=n;i++)
                tmp[C++] = p[(i-1)*m+j];
            Mannattan_minimum_spanning_tree(C,tmp);
        }
        // Kruskal: take edges by increasing weight, keep those that merge two sets.
        for(int i=1;i<=n*m;i++) fa[i] = i;
        sort(e+1,e+tot+1);
        ll ans = 0ll;
        for(int i = 1;i <= tot;i++)
        {
            int x = gf(e[i].u) , y = gf(e[i].v);
            if(x!=y)
            {
                fa[x] = y;
                ans += e[i].d;
            }
        }
        printf("%lld\n",ans);
    }
}