将军令
可以树形\(dp\),可以贪心。
不得不承认确实连贪心想都没想,好像我没学过这个算法一样。
先说贪心做法
一个点被守护的状态对于这个点是等效的,也就是不管你驻扎在哪,只要把它守住,那这个点根本不管你驻扎在哪(对它来说没有区别)。
所以我们想让守住他的点能守住更多的点。
所以我们可以先按深度排序,然后对于每个点都在他的\(k\)级父亲上驻扎,然后让他的\(k\)级父亲把它能覆盖的点都覆盖了。
特别地,如果不存在\(k\)级父亲,那就驻扎在根节点。
至于证明,本人不太会,但是这个做法好像乍一看看不出来啥毛病。
引用大佬一个证明
考虑被控制节点x,如果x可以向上移动一,并且仍然能控制x移动前能控制的点,就把x向上移动,显然,这样做是不会使答案变差的。
我们不知道x向上移动后能控制多少点,但显然,x有可能控制更多的点,这时把x移动到不能移动为止,x就成为了最优答案。
引用完了发现证明也挺简单的
考场上确实可以多想一点,想到贪心之后可以搞个排来验证一下正确性。
code
#include <cstring>
#include <algorithm>
#include <cstdio>
#define mp make_pair
#define R register int
#define int long
#define printf Ruusupuu = printf
int Ruusupuu ;
using namespace std ;
typedef long long L ;
typedef long double D ;
typedef unsigned long long G ;
typedef pair< int , int > PI ;
const int N = 1e5 + 10 ;
const int M = 4e1 + 10 ;
const int Inf = 0x3f3f3f3f ;
inline int mn( int a , int b ) { return a > b ? b : a ; }
inline int read(){
int w = 0 ; bool fg = 0 ; char ch = getchar() ;
while( ch < '0' || ch > '9' ) fg |= ( ch == '-' ) , ch = getchar() ;
while( ch >= '0' && ch <= '9' ) w = ( w << 1 ) + ( w << 3 ) + ( ch ^ '0' ) , ch = getchar() ;
return fg ? -w : w ;
}
int n , k , x , y , f [N][M] , fa [N] ;
int head [N] , to [N << 1] , net [N << 1] , fr [N << 1] , cnt = 1 ;
bool fg [N << 1] , lit [N] ;
#define add( f , t ) fr [++ cnt] = f , to [cnt] = t , net [cnt] = head [f] , head [f] = cnt
struct node{ int index , dep ; } a [N] ;
inline bool cmp( node a , node b ){ return a.dep > b.dep ; }
void sc(){
n = read() , k = read() , read() ; memset( head , -1 , sizeof( head ) ) ;
for( R i = 1 ; i < n ; i ++ ) x = read() , y = read() , add( x , y ) , add( y , x ) ;
}
void dfs( int x ){
for( R i = head [x] ; ~i ; i = net [i] ){
if( fg [i] ) continue ;
fg [i] = fg [i ^ 1] = 1 ;
int y = to [i] ; fa [y] = x ;
a [y].index = y , a [y].dep = a [x].dep + 1 , dfs( y ) ;
}
}
inline void dfsp( int x , int fa , int ks ){
if( ks < 0 ) return ;
lit [x] = 1 ;
for( R i = head [x] ; ~i ; i = net [i] ){
int y = to [i] ;
if( y == fa ) continue ;
// printf( "%ld %ld\n" , x , y ) ;
dfsp( y , x , ks - 1 ) ;
}
}
void work(){
a [1].dep = 1 , a [1].index = 1 , dfs( 1 ) ;
sort( a + 1 , a + 1 + n , cmp ) ;
// for( R i = 1 ; i <= n ; i ++ ) printf( "%ld %ld\n" , a [i].index , a [i].dep ) ;
int ans = 0 ;
for( R i = 1 ; i <= n ; i ++ ){
if( lit [a [i].index] ) continue ;
int kk = k , x = a [i].index ;
while( x != 1 && kk ) kk -- , x = fa [x] ;
//printf( "SSS%ld %ld\n" , a [i].index , x ) ;
dfsp ( x , x , k ) , ans ++ ;
} printf( "%ld\n" , ans ) ;
}
signed main(){
sc() ;
work() ;
return 0 ;
}
重头戏:树形\(dp\)
开眼界了,学到了如何从孙子,重孙子等非直接儿子节点转移。
先推一下\(k=1\)的时候的式子。
状态定义:\(f[i][0]\)表示当前节点可以覆盖当前节点的子树和当前节点的父亲时最少驻扎多少。
\(f[i][1]\)表示当前节点可以覆盖当前节点的子树时最少驻扎多少。
\(f[i][2]\)表示当前节点可以覆盖当前节点的孙树时最少驻扎多少。(孙树就是子节点的子树)
类似于小胖收皇宫,不过那个的状态定义不如这个可扩展性高。
转移方程画一画不难推出来,我就不赘述了,注意就是转移\(f[i][1]\)的时候至少有一个子节点要选择覆盖父亲节点。
\(k=2\)的时候同理设计五个状态,粘一个让我懂的\(blog\)。
不过我写的和他的有点不同,因为保证是少选一个的操作本弱懒得搞花里胡哨一个式子解决,所以就加了点辅助的东西。
推出来式子之后发现和\(k=1\)有很多相似之处,一般就可以找到规律打正解了。
如果你再推一个\(k=3\),那么一切会变得更加明了。
粘一个本弱打的\(k=1~3\)的代码
code
#include <cstring>
#include <algorithm>
#include <cstdio>
#define mp make_pair
#define R register int
#define int long
#define printf Ruusupuu = printf
int Ruusupuu ;
using namespace std ;
typedef long long L ;
typedef long double D ;
typedef unsigned long long G ;
typedef pair< int , int > PI ;
const int N = 1e5 + 10 ;
const int Inf = 0x3f3f3f3f ;
inline int mn( int a , int b ) { return a > b ? b : a ; }
inline int read(){
int w = 0 ; bool fg = 0 ; char ch = getchar() ;
while( ch < '0' || ch > '9' ) fg |= ( ch == '-' ) , ch = getchar() ;
while( ch >= '0' && ch <= '9' ) w = ( w << 1 ) + ( w << 3 ) + ( ch ^ '0' ) , ch = getchar() ;
return fg ? -w : w ;
}
int n , k , x , y , f [N][9] ;
int head [N] , to [N << 1] , net [N << 1] , fr [N << 1] , cnt = 1 ;
bool fg [N << 1] ;
#define add( f , t ) fr [++ cnt] = f , to [cnt] = t , net [cnt] = head [f] , head [f] = cnt
void sc(){
n = read() , k = read() , read() ; memset( head , -1 , sizeof( head ) ) ;
for( R i = 1 ; i < n ; i ++ ) x = read() , y = read() , add( x , y ) , add( y , x ) ;
}
void dfs1( int x ){
int f0 = 0 , f1 = 0 , f2 = 0 , dlt = Inf ;
bool cg = 0 ;
for( R i = head [x] ; ~i ; i = net [i] ){
if( fg [i] ) continue ;
fg [i] = fg [i ^ 1] = 1 ;
int y = to [i] ; dfs1( y ) ;
f0 += mn( f [y][0] , mn( f [y][1] , f [y][2] ) ) ;
if( f [y][0] <= f [y][1] ) f1 += f [y][0] , cg = 1 ;
else f1 += f [y][1] ;
dlt = mn( dlt , f [y][0] - f [y][1] ) ;
f2 += mn( f [y][0] , f [y][1] ) ;
}
f [x][0] = f0 + 1 , f [x][1] = f1 + ( cg ? 0 : dlt ) , f [x][2] = f2 ;
}
void dfs2( int x ){
int f0 = 0 , f1 = 0 , f2 = 0 , f3 = 0 , f4 = 0 , dlt1 = Inf , dlt2 = Inf ;
bool cg1 = 0 , cg2 = 0 ;
for( R i = head [x] ; ~i ; i = net [i] ){
if( fg [i] ) continue ;
fg [i] = fg [i ^ 1] = 1 ;
int y = to [i] ; dfs2( y ) ;
f0 += mn( mn( mn( f [y][0] , f [y][1] ) , mn( f [y][2] , f [y][3] ) ) , f [y][4] ) ;
int ff1 = mn( mn( f [y][0] , f [y][1] ) , mn( f [y][2] , f [y][3] ) ) ;
if( ff1 == f [y][0] ) f1 += f [y][0] , cg1 = 1 ;
else f1 += ff1 ;
dlt1 = mn( dlt1 , f [y][0] - ff1 ) ;
int ff2 = mn( mn( f [y][0] , f [y][1] ) , f [y][2] ) ;
if( ff2 == f [y][1] ) f2 += f [y][1] , cg2 = 1 ;
else f2 += ff2 ;
dlt2 = mn( dlt2 , f [y][1] - ff2 ) ;
f3 += ff2 , f4 += ff1 ;
}
f [x][0] = f0 + 1 , f [x][1] = mn( f [x][0] , f1 + ( cg1 ? 0 : dlt1 ) ) ;
f [x][2] = mn( f [x][1] , f2 + ( cg2 ? 0 : dlt2 ) ) ;
f [x][3] = mn( f [x][2] , f3 ) , f [x][4] = mn( f [x][3] , f4 ) ;
}
void dfs3( int x ){
int f0 = 0 , f1 = 0 , f2 = 0 , f3 = 0 , f4 = 0 , f5 = 0 , f6 = 0 , dlt1 = Inf , dlt2 = Inf , dlt3 = Inf ;
bool cg1 = 0 , cg2 = 0 , cg3 = 0 ;
for( R i = head [x] ; ~i ; i = net [i] ){
if( fg [i] ) continue ;
fg [i] = fg [i ^ 1] = 1 ;
int y = to [i] ; dfs3( y ) ;
f0 += mn( mn( mn( f [y][0] , f [y][1] ) , mn( f [y][2] , f [y][3] ) ) , mn( mn( f [y][4] , f [y][5] ) , f [y][6] ) ) ;
int ff1 = mn( mn( mn( f [y][0] , f [y][1] ) , mn( f [y][2] , f [y][3] ) ) , mn( f [y][4] , f [y][5] ) ) ;
if( ff1 == f [y][0] ) f1 += f [y][0] , cg1 = 1 ;
else f1 += ff1 ;
dlt1 = mn( dlt1 , f [y][0] - ff1 ) ;
int ff2 = mn( mn( mn( f [y][0] , f [y][1] ) , mn( f [y][2] , f [y][3] ) ) , f [y][4] ) ;
if( ff2 == f [y][1] ) f2 += f [y][1] , cg2 = 1 ;
else f2 += ff2 ;
dlt2 = mn( dlt2 , f [y][1] - ff2 ) ;
int ff3 = mn( mn( f [y][0] , f [y][1] ) , mn( f [y][2] , f [y][3] ) ) ;
if( ff3 == f [y][2] ) f3 += f [y][2] , cg3 = 1 ;
else f3 += ff3 ;
dlt3 = mn( dlt3 , f [y][2] - ff3 ) ;
f4 += ff3 , f5 += ff2 , f6 += ff1 ;
}
f [x][0] = f0 + 1 , f [x][1] = mn( f [x][0] , f1 + ( cg1 ? 0 : dlt1 ) ) ;
f [x][2] = mn( f [x][1] , f2 + ( cg2 ? 0 : dlt2 ) ) ;
f [x][3] = mn( f [x][2] , f3 + ( cg3 ? 0 : dlt3 ) ) ;
f [x][4] = mn( f [x][3] , f4 ) , f [x][5] = mn( f [x][4] , f5 ) , f [x][6] = mn( f [x][5] , f6 ) ;
}
void work(){
if( k == 0 ) printf( "%ld\n" , n ) ;
else if( k == 1 ) dfs1( 1 ) , printf( "%ld\n" , mn( f [1][0] , f [1][1] ) ) ;
else if( k == 2 ) dfs2( 1 ) , printf( "%ld\n" , mn( f [1][0] , mn( f [1][1] , f [1][2] ) ) ) ;
else if( k == 3 ) dfs3( 1 ) , printf( "%ld\n" , mn( mn( f [1][0] , f [1][1] ) , mn( f [1][2] , f [1][3] ) ) ) ;
}
signed main(){
sc() ;
work() ;
return 0 ;
}
可以获得\(90pts\)的好成绩,不过都推到这了,规律不就很显然了吗。
然后打出正解,甚至根本没有题解说的那么麻烦,很短就解决了。
code
#include <cstring>
#include <algorithm>
#include <cstdio>
#define mp make_pair
#define R register int
#define int long
#define printf Ruusupuu = printf
int Ruusupuu ;
using namespace std ;
typedef long long L ;
typedef long double D ;
typedef unsigned long long G ;
typedef pair< int , int > PI ;
const int N = 1e5 + 10 ;
const int M = 4e1 + 10 ;
const int Inf = 0x3f3f3f3f ;
inline int mn( int a , int b ) { return a > b ? b : a ; }
inline int read(){
int w = 0 ; bool fg = 0 ; char ch = getchar() ;
while( ch < '0' || ch > '9' ) fg |= ( ch == '-' ) , ch = getchar() ;
while( ch >= '0' && ch <= '9' ) w = ( w << 1 ) + ( w << 3 ) + ( ch ^ '0' ) , ch = getchar() ;
return fg ? -w : w ;
}
int n , k , x , y , f [N][M] ;
int head [N] , to [N << 1] , net [N << 1] , fr [N << 1] , cnt = 1 ;
bool fg [N << 1] ;
#define add( f , t ) fr [++ cnt] = f , to [cnt] = t , net [cnt] = head [f] , head [f] = cnt
void sc(){
n = read() , k = read() , read() ; memset( head , -1 , sizeof( head ) ) ;
for( R i = 1 ; i < n ; i ++ ) x = read() , y = read() , add( x , y ) , add( y , x ) ;
}
inline int catchmin( int x , int fr , int to ){
int ans = Inf ;
for( R i = fr ; i <= to ; i ++ ) ans = mn( ans , f [x][i] ) ;
return ans ;
}
void Gan( int x ){
int ff [M] = { 0 } , dlt [M] ;
bool cg [M] = { 0 } ; memset( dlt , 0x3f , sizeof( dlt ) ) ;
for( R i = head [x] ; ~i ; i = net [i] ){
if( fg [i] ) continue ;
fg [i] = fg [i ^ 1] = 1 ;
int y = to [i] ; Gan( y ) ;
ff [0] += catchmin( y , 0 , 2 * k ) ;
for( R j = 1 ; j <= k ; j ++ ){
int fff = catchmin( y , 0 , ( 2 * k ) - j ) ;
if( fff == f [y][j - 1] ) ff [j] += f [y][j - 1] , cg [j] = 1 ;
else ff [j] += fff ;
dlt [j] = mn( dlt [j] , f [y][j - 1] - fff ) ;
ff [2 * k - j + 1] += fff ;
}
} f [x][0] = ff [0] + 1 ;
for( R i = 1 ; i <= k ; i ++ ) f [x][i] = mn( f [x][i - 1] , ff [i] + ( cg [i] ? 0 : dlt [i] ) ) ;
for( R i = k + 1 ; i <= 2 * k ; i ++ ) f [x][i] = mn( f [x][i - 1] , ff [i] ) ;
}
void work(){
if( k == 0 ) printf( "%ld\n" , n ) ;
else Gan( 1 ) , printf( "%ld\n" , catchmin( 1 , 0 , k ) ) ;
}
signed main(){
sc() ;
work() ;
return 0 ;
}