@@ -244,16 +244,14 @@ struct LatestMonitorState {
244244 /// which we haven't yet completed. We're allowed to reload with those as well, at least until
245245 /// they're completed.
246246 persisted_monitor_id : u64 ,
247- /// The latest serialized `ChannelMonitor` that we told LDK we persisted.
248- persisted_monitor : Vec < u8 > ,
249- /// A set of (monitor id, serialized `ChannelMonitor`)s which we're currently "persisting",
247+ /// The latest `ChannelMonitor` that we told LDK we persisted, if any .
248+ persisted_monitor : Option < channelmonitor :: ChannelMonitor < TestChannelSigner > > ,
249+ /// A set of (monitor id, `ChannelMonitor`)s which we're currently "persisting",
250250 /// from LDK's perspective.
251- pending_monitors : Vec < ( u64 , Vec < u8 > ) > ,
251+ pending_monitors : Vec < ( u64 , channelmonitor :: ChannelMonitor < TestChannelSigner > ) > ,
252252}
253253
254254struct TestChainMonitor {
255- pub logger : Arc < dyn Logger > ,
256- pub keys : Arc < KeyProvider > ,
257255 pub persister : Arc < TestPersister > ,
258256 pub chain_monitor : Arc <
259257 chainmonitor:: ChainMonitor <
@@ -277,15 +275,13 @@ impl TestChainMonitor {
277275 chain_monitor : Arc :: new ( chainmonitor:: ChainMonitor :: new (
278276 None ,
279277 broadcaster,
280- logger. clone ( ) ,
278+ logger,
281279 feeest,
282280 Arc :: clone ( & persister) ,
283281 Arc :: clone ( & keys) ,
284282 keys. get_peer_storage_key ( ) ,
285283 false ,
286284 ) ) ,
287- logger,
288- keys,
289285 persister,
290286 latest_monitors : Mutex :: new ( new_hash_map ( ) ) ,
291287 }
@@ -295,20 +291,19 @@ impl chain::Watch<TestChannelSigner> for TestChainMonitor {
295291 fn watch_channel (
296292 & self , channel_id : ChannelId , monitor : channelmonitor:: ChannelMonitor < TestChannelSigner > ,
297293 ) -> Result < chain:: ChannelMonitorUpdateStatus , ( ) > {
298- let mut ser = VecWriter ( Vec :: new ( ) ) ;
299- monitor. write ( & mut ser) . unwrap ( ) ;
300294 let monitor_id = monitor. get_latest_update_id ( ) ;
301295 let res = self . chain_monitor . watch_channel ( channel_id, monitor) ;
296+ let mon = self . persister . take_latest_monitor ( & channel_id) ;
302297 let state = match res {
303298 Ok ( chain:: ChannelMonitorUpdateStatus :: Completed ) => LatestMonitorState {
304299 persisted_monitor_id : monitor_id,
305- persisted_monitor : ser . 0 ,
300+ persisted_monitor : Some ( mon ) ,
306301 pending_monitors : Vec :: new ( ) ,
307302 } ,
308303 Ok ( chain:: ChannelMonitorUpdateStatus :: InProgress ) => LatestMonitorState {
309304 persisted_monitor_id : monitor_id,
310- persisted_monitor : Vec :: new ( ) ,
311- pending_monitors : vec ! [ ( monitor_id, ser . 0 ) ] ,
305+ persisted_monitor : None ,
306+ pending_monitors : vec ! [ ( monitor_id, mon ) ] ,
312307 } ,
313308 Ok ( chain:: ChannelMonitorUpdateStatus :: UnrecoverableError ) => panic ! ( ) ,
314309 Err ( ( ) ) => panic ! ( ) ,
@@ -324,37 +319,15 @@ impl chain::Watch<TestChannelSigner> for TestChainMonitor {
324319 ) -> chain:: ChannelMonitorUpdateStatus {
325320 let mut map_lock = self . latest_monitors . lock ( ) . unwrap ( ) ;
326321 let map_entry = map_lock. get_mut ( & channel_id) . expect ( "Didn't have monitor on update call" ) ;
327- let latest_monitor_data = map_entry
328- . pending_monitors
329- . last ( )
330- . as_ref ( )
331- . map ( |( _, data) | data)
332- . unwrap_or ( & map_entry. persisted_monitor ) ;
333- let deserialized_monitor =
334- <( BlockHash , channelmonitor:: ChannelMonitor < TestChannelSigner > ) >:: read (
335- & mut & latest_monitor_data[ ..] ,
336- ( & * self . keys , & * self . keys ) ,
337- )
338- . unwrap ( )
339- . 1 ;
340- deserialized_monitor
341- . update_monitor (
342- update,
343- & & TestBroadcaster { txn_broadcasted : RefCell :: new ( Vec :: new ( ) ) } ,
344- & & FuzzEstimator { ret_val : atomic:: AtomicU32 :: new ( 253 ) } ,
345- & self . logger ,
346- )
347- . unwrap ( ) ;
348- let mut ser = VecWriter ( Vec :: new ( ) ) ;
349- deserialized_monitor. write ( & mut ser) . unwrap ( ) ;
350322 let res = self . chain_monitor . update_channel ( channel_id, update) ;
323+ let mon = self . persister . take_latest_monitor ( & channel_id) ;
351324 match res {
352325 chain:: ChannelMonitorUpdateStatus :: Completed => {
353326 map_entry. persisted_monitor_id = update. update_id ;
354- map_entry. persisted_monitor = ser . 0 ;
327+ map_entry. persisted_monitor = Some ( mon ) ;
355328 } ,
356329 chain:: ChannelMonitorUpdateStatus :: InProgress => {
357- map_entry. pending_monitors . push ( ( update. update_id , ser . 0 ) ) ;
330+ map_entry. pending_monitors . push ( ( update. update_id , mon ) ) ;
358331 } ,
359332 chain:: ChannelMonitorUpdateStatus :: UnrecoverableError => panic ! ( ) ,
360333 }
@@ -914,9 +887,7 @@ pub fn do_test<Out: Output + MaybeSend + MaybeSync>(
914887 $broadcaster. clone( ) ,
915888 logger. clone( ) ,
916889 $fee_estimator. clone( ) ,
917- Arc :: new( TestPersister {
918- update_ret: Mutex :: new( mon_style[ $node_id as usize ] . borrow( ) . clone( ) ) ,
919- } ) ,
890+ Arc :: new( TestPersister :: new( mon_style[ $node_id as usize ] . borrow( ) . clone( ) ) ) ,
920891 Arc :: clone( & keys_manager) ,
921892 ) ) ;
922893
@@ -966,9 +937,7 @@ pub fn do_test<Out: Output + MaybeSend + MaybeSync>(
966937 broadcaster. clone ( ) ,
967938 logger. clone ( ) ,
968939 Arc :: clone ( fee_estimator) ,
969- Arc :: new ( TestPersister {
970- update_ret : Mutex :: new ( ChannelMonitorUpdateStatus :: Completed ) ,
971- } ) ,
940+ Arc :: new ( TestPersister :: new ( ChannelMonitorUpdateStatus :: Completed ) ) ,
972941 Arc :: clone ( keys) ,
973942 ) ) ;
974943
@@ -983,30 +952,35 @@ pub fn do_test<Out: Output + MaybeSend + MaybeSync>(
983952 let mut monitors = new_hash_map ( ) ;
984953 let mut old_monitors = old_monitors. latest_monitors . lock ( ) . unwrap ( ) ;
985954 for ( channel_id, mut prev_state) in old_monitors. drain ( ) {
986- let ( mon_id, serialized_mon) = if use_old_mons % 3 == 0 {
955+ let old_mon =
956+ prev_state. persisted_monitor . map ( |m| ( prev_state. persisted_monitor_id , m) ) ;
957+ let ( mon_id, mon) = if use_old_mons % 3 == 0 {
987958 // Reload with the oldest `ChannelMonitor` (the one that we already told
988959 // `ChannelManager` we finished persisting).
989- ( prev_state . persisted_monitor_id , prev_state . persisted_monitor )
960+ old_mon . expect ( "no persisted monitor to reload" )
990961 } else if use_old_mons % 3 == 1 {
991962 // Reload with the second-oldest `ChannelMonitor`
992- let old_mon = ( prev_state. persisted_monitor_id , prev_state . persisted_monitor ) ;
993- prev_state . pending_monitors . drain ( .. ) . next ( ) . unwrap_or ( old_mon )
963+ prev_state. pending_monitors . drain ( .. ) . next ( ) . or ( old_mon )
964+ . expect ( "no monitor to reload" )
994965 } else {
995966 // Reload with the newest `ChannelMonitor`
996- let old_mon = ( prev_state. persisted_monitor_id , prev_state . persisted_monitor ) ;
997- prev_state . pending_monitors . pop ( ) . unwrap_or ( old_mon )
967+ prev_state. pending_monitors . pop ( ) . or ( old_mon )
968+ . expect ( "no monitor to reload" )
998969 } ;
999970 // Use a different value of `use_old_mons` if we have another monitor (only for node B)
1000971 // by shifting `use_old_mons` one in base-3.
1001972 use_old_mons /= 3 ;
1002- let mon = <( BlockHash , ChannelMonitor < TestChannelSigner > ) >:: read (
1003- & mut & serialized_mon[ ..] ,
973+ // Serialize and deserialize the monitor to verify round-trip correctness.
974+ let mut ser = VecWriter ( Vec :: new ( ) ) ;
975+ mon. write ( & mut ser) . unwrap ( ) ;
976+ let ( _, deserialized_mon) = <( BlockHash , ChannelMonitor < TestChannelSigner > ) >:: read (
977+ & mut & ser. 0 [ ..] ,
1004978 ( & * * keys, & * * keys) ,
1005979 )
1006980 . expect ( "Failed to read monitor" ) ;
1007- monitors. insert ( channel_id, mon . 1 ) ;
981+ monitors. insert ( channel_id, deserialized_mon ) ;
1008982 // Update the latest `ChannelMonitor` state to match what we just told LDK.
1009- prev_state. persisted_monitor = serialized_mon ;
983+ prev_state. persisted_monitor = Some ( mon ) ;
1010984 prev_state. persisted_monitor_id = mon_id;
1011985 // Wipe any `ChannelMonitor`s which we never told LDK we finished persisting,
1012986 // considering them discarded. LDK should replay these for us as they're stored in
@@ -1053,7 +1027,7 @@ pub fn do_test<Out: Output + MaybeSend + MaybeSync>(
10531027 $monitor. chain_monitor. channel_monitor_updated( * channel_id, id) . unwrap( ) ;
10541028 if id >= state. persisted_monitor_id {
10551029 state. persisted_monitor_id = id;
1056- state. persisted_monitor = data;
1030+ state. persisted_monitor = Some ( data) ;
10571031 }
10581032 }
10591033 }
@@ -1981,10 +1955,11 @@ pub fn do_test<Out: Output + MaybeSend + MaybeSync>(
19811955
19821956 let complete_first = |v : & mut Vec < _ > | if !v. is_empty ( ) { Some ( v. remove ( 0 ) ) } else { None } ;
19831957 let complete_second = |v : & mut Vec < _ > | if v. len ( ) > 1 { Some ( v. remove ( 1 ) ) } else { None } ;
1958+ type PendingMonitors = Vec < ( u64 , channelmonitor:: ChannelMonitor < TestChannelSigner > ) > ;
19841959 let complete_monitor_update =
19851960 |monitor : & Arc < TestChainMonitor > ,
19861961 chan_funding,
1987- compl_selector : & dyn Fn ( & mut Vec < ( u64 , Vec < u8 > ) > ) -> Option < ( u64 , Vec < u8 > ) > | {
1962+ compl_selector : & dyn Fn ( & mut PendingMonitors ) -> Option < ( u64 , channelmonitor :: ChannelMonitor < TestChannelSigner > ) > | {
19881963 if let Some ( state) = monitor. latest_monitors . lock ( ) . unwrap ( ) . get_mut ( chan_funding) {
19891964 assert ! (
19901965 state. pending_monitors. windows( 2 ) . all( |pair| pair[ 0 ] . 0 < pair[ 1 ] . 0 ) ,
@@ -1994,7 +1969,7 @@ pub fn do_test<Out: Output + MaybeSend + MaybeSync>(
19941969 monitor. chain_monitor . channel_monitor_updated ( * chan_funding, id) . unwrap ( ) ;
19951970 if id > state. persisted_monitor_id {
19961971 state. persisted_monitor_id = id;
1997- state. persisted_monitor = data;
1972+ state. persisted_monitor = Some ( data) ;
19981973 }
19991974 }
20001975 }
@@ -2010,7 +1985,7 @@ pub fn do_test<Out: Output + MaybeSend + MaybeSync>(
20101985 monitor. chain_monitor . channel_monitor_updated ( * chan_id, id) . unwrap ( ) ;
20111986 if id > state. persisted_monitor_id {
20121987 state. persisted_monitor_id = id;
2013- state. persisted_monitor = data;
1988+ state. persisted_monitor = Some ( data) ;
20141989 }
20151990 }
20161991 }
0 commit comments