@@ -404,20 +404,20 @@ pub trait BaseStateWithExtensions<S: BaseState> {
404
404
///
405
405
/// Provides the correct answer regardless if the extension is already present
406
406
/// in the TLV data.
407
- fn try_get_new_account_len < V : Extension + VariableLenPack > (
407
+ fn try_get_new_account_len_for_variable_len_extension_from_new_extension_len < V : Extension > (
408
408
& self ,
409
- new_extension : & V ,
409
+ new_extension_len : usize ,
410
410
) -> Result < usize , ProgramError > {
411
411
// get the new length used by the extension
412
- let new_extension_len = add_type_and_length_to_len ( new_extension . get_packed_len ( ) ? ) ;
412
+ let new_extension_tlv_len = add_type_and_length_to_len ( new_extension_len ) ;
413
413
let tlv_info = get_tlv_data_info ( self . get_tlv_data ( ) ) ?;
414
414
// If we're adding an extension, then we must have at least BASE_ACCOUNT_LENGTH
415
415
// and account type
416
416
let current_len = tlv_info
417
417
. used_len
418
418
. saturating_add ( BASE_ACCOUNT_AND_TYPE_LENGTH ) ;
419
419
let new_len = if tlv_info. extension_types . is_empty ( ) {
420
- current_len. saturating_add ( new_extension_len )
420
+ current_len. saturating_add ( new_extension_tlv_len )
421
421
} else {
422
422
// get the current length used by the extension
423
423
let current_extension_len = self
@@ -426,10 +426,29 @@ pub trait BaseStateWithExtensions<S: BaseState> {
426
426
. unwrap_or ( 0 ) ;
427
427
current_len
428
428
. saturating_sub ( current_extension_len)
429
- . saturating_add ( new_extension_len )
429
+ . saturating_add ( new_extension_tlv_len )
430
430
} ;
431
431
Ok ( adjust_len_for_multisig ( new_len) )
432
432
}
433
+
434
+ /// Calculate the new expected size if the state allocates the required
435
+ /// number of bytes for the given extension type.
436
+ fn try_get_new_account_len < V : Extension + Pod > ( & self ) -> Result < usize , ProgramError > {
437
+ self . try_get_new_account_len_for_variable_len_extension_from_new_extension_len :: < V > (
438
+ pod_get_packed_len :: < V > ( ) ,
439
+ )
440
+ }
441
+
442
+ /// Calculate the new expected size if the state allocates the required
443
+ /// number of bytes for the given variable-length extension type.
444
+ fn try_get_new_account_len_for_variable_len_extension < V : Extension + VariableLenPack > (
445
+ & self ,
446
+ new_extension : & V ,
447
+ ) -> Result < usize , ProgramError > {
448
+ self . try_get_new_account_len_for_variable_len_extension_from_new_extension_len :: < V > (
449
+ new_extension. get_packed_len ( ) ?,
450
+ )
451
+ }
433
452
}
434
453
435
454
/// Encapsulates owned immutable base state data (mint or account) with possible extensions
@@ -1178,6 +1197,58 @@ impl Extension for AccountPaddingTest {
1178
1197
const TYPE : ExtensionType = ExtensionType :: AccountPaddingTest ;
1179
1198
}
1180
1199
1200
+ /// Packs a fixed-length extension into a TLV space
1201
+ ///
1202
+ /// This function reallocates the account as needed to accommodate for the
1203
+ /// change in space.
1204
+ ///
1205
+ /// If the extension already exists, it will overwrite the existing extension
1206
+ /// if `overwrite` is `true`, otherwise it will return an error.
1207
+ ///
1208
+ /// If the extension does not exist, it will reallocate the account and write
1209
+ /// the extension into the TLV buffer.
1210
+ ///
1211
+ /// NOTE: Since this function deals with fixed-size extensions, it does not
1212
+ /// handle _decreasing_ the size of an account's data buffer, like the function
1213
+ /// `alloc_and_serialize_variable_len_extension` does.
1214
+ pub fn alloc_and_serialize < S : BaseState , V : Default + Extension + Pod > (
1215
+ account_info : & AccountInfo ,
1216
+ new_extension : & V ,
1217
+ overwrite : bool ,
1218
+ ) -> Result < ( ) , ProgramError > {
1219
+ let previous_account_len = account_info. try_data_len ( ) ?;
1220
+ let ( new_account_len, extension_already_exists) = {
1221
+ let data = account_info. try_borrow_data ( ) ?;
1222
+ let state = StateWithExtensions :: < S > :: unpack ( & data) ?;
1223
+ let new_account_len = state. try_get_new_account_len :: < V > ( ) ?;
1224
+ let extension_already_exists = state. get_extension_bytes :: < V > ( ) . is_ok ( ) ;
1225
+ ( new_account_len, extension_already_exists)
1226
+ } ;
1227
+
1228
+ if extension_already_exists {
1229
+ if !overwrite {
1230
+ return Err ( TokenError :: ExtensionAlreadyInitialized . into ( ) ) ;
1231
+ } else {
1232
+ // Overwrite the extension
1233
+ let mut buffer = account_info. try_borrow_mut_data ( ) ?;
1234
+ let mut state = StateWithExtensionsMut :: < S > :: unpack ( & mut buffer) ?;
1235
+ let extension = state. get_extension_mut :: < V > ( ) ?;
1236
+ * extension = * new_extension;
1237
+ }
1238
+ } else {
1239
+ // Realloc the account, then write the new extension
1240
+ account_info. realloc ( new_account_len, false ) ?;
1241
+ let mut buffer = account_info. try_borrow_mut_data ( ) ?;
1242
+ if previous_account_len <= BASE_ACCOUNT_LENGTH {
1243
+ set_account_type :: < S > ( * buffer) ?;
1244
+ }
1245
+ let mut state = StateWithExtensionsMut :: < S > :: unpack ( & mut buffer) ?;
1246
+ let extension = state. init_extension :: < V > ( false ) ?;
1247
+ * extension = * new_extension;
1248
+ }
1249
+ Ok ( ( ) )
1250
+ }
1251
+
1181
1252
/// Packs a variable-length extension into a TLV space
1182
1253
///
1183
1254
/// This function reallocates the account as needed to accommodate for the
@@ -1186,7 +1257,7 @@ impl Extension for AccountPaddingTest {
1186
1257
///
1187
1258
/// NOTE: Unlike the `reallocate` instruction, this function will reduce the
1188
1259
/// size of an account if it has too many bytes allocated for the given value.
1189
- pub fn alloc_and_serialize < S : BaseState , V : Extension + VariableLenPack > (
1260
+ pub fn alloc_and_serialize_variable_len_extension < S : BaseState , V : Extension + VariableLenPack > (
1190
1261
account_info : & AccountInfo ,
1191
1262
new_extension : & V ,
1192
1263
overwrite : bool ,
@@ -1195,7 +1266,8 @@ pub fn alloc_and_serialize<S: BaseState, V: Extension + VariableLenPack>(
1195
1266
let ( new_account_len, extension_already_exists) = {
1196
1267
let data = account_info. try_borrow_data ( ) ?;
1197
1268
let state = StateWithExtensions :: < S > :: unpack ( & data) ?;
1198
- let new_account_len = state. try_get_new_account_len ( new_extension) ?;
1269
+ let new_account_len =
1270
+ state. try_get_new_account_len_for_variable_len_extension ( new_extension) ?;
1199
1271
let extension_already_exists = state. get_extension_bytes :: < V > ( ) . is_ok ( ) ;
1200
1272
( new_account_len, extension_already_exists)
1201
1273
} ;
@@ -2282,7 +2354,9 @@ mod test {
2282
2354
let current_len = state. try_get_account_len ( ) . unwrap ( ) ;
2283
2355
assert_eq ! ( current_len, Mint :: LEN ) ;
2284
2356
let new_len = state
2285
- . try_get_new_account_len :: < VariableLenMintTest > ( & variable_len)
2357
+ . try_get_new_account_len_for_variable_len_extension :: < VariableLenMintTest > (
2358
+ & variable_len,
2359
+ )
2286
2360
. unwrap ( ) ;
2287
2361
assert_eq ! (
2288
2362
new_len,
@@ -2297,19 +2371,25 @@ mod test {
2297
2371
2298
2372
// Reduce the extension size
2299
2373
let new_len = state
2300
- . try_get_new_account_len :: < VariableLenMintTest > ( & small_variable_len)
2374
+ . try_get_new_account_len_for_variable_len_extension :: < VariableLenMintTest > (
2375
+ & small_variable_len,
2376
+ )
2301
2377
. unwrap ( ) ;
2302
2378
assert_eq ! ( current_len. checked_sub( new_len) . unwrap( ) , 1 ) ;
2303
2379
2304
2380
// Increase the extension size
2305
2381
let new_len = state
2306
- . try_get_new_account_len :: < VariableLenMintTest > ( & big_variable_len)
2382
+ . try_get_new_account_len_for_variable_len_extension :: < VariableLenMintTest > (
2383
+ & big_variable_len,
2384
+ )
2307
2385
. unwrap ( ) ;
2308
2386
assert_eq ! ( new_len. checked_sub( current_len) . unwrap( ) , 1 ) ;
2309
2387
2310
2388
// Maintain the extension size
2311
2389
let new_len = state
2312
- . try_get_new_account_len :: < VariableLenMintTest > ( & variable_len)
2390
+ . try_get_new_account_len_for_variable_len_extension :: < VariableLenMintTest > (
2391
+ & variable_len,
2392
+ )
2313
2393
. unwrap ( ) ;
2314
2394
assert_eq ! ( new_len, current_len) ;
2315
2395
}
@@ -2382,7 +2462,8 @@ mod test {
2382
2462
let key = Pubkey :: new_unique ( ) ;
2383
2463
let account_info = ( & key, & mut data) . into_account_info ( ) ;
2384
2464
2385
- alloc_and_serialize :: < Mint , _ > ( & account_info, & variable_len, false ) . unwrap ( ) ;
2465
+ alloc_and_serialize_variable_len_extension :: < Mint , _ > ( & account_info, & variable_len, false )
2466
+ . unwrap ( ) ;
2386
2467
let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len ( value_len) ;
2387
2468
assert_eq ! ( data. len( ) , new_account_len) ;
2388
2469
let state = StateWithExtensions :: < Mint > :: unpack ( data. data ( ) ) . unwrap ( ) ;
@@ -2395,12 +2476,18 @@ mod test {
2395
2476
2396
2477
// alloc again succeeds with "overwrite"
2397
2478
let account_info = ( & key, & mut data) . into_account_info ( ) ;
2398
- alloc_and_serialize :: < Mint , _ > ( & account_info, & variable_len, true ) . unwrap ( ) ;
2479
+ alloc_and_serialize_variable_len_extension :: < Mint , _ > ( & account_info, & variable_len, true )
2480
+ . unwrap ( ) ;
2399
2481
2400
2482
// alloc again fails without "overwrite"
2401
2483
let account_info = ( & key, & mut data) . into_account_info ( ) ;
2402
2484
assert_eq ! (
2403
- alloc_and_serialize:: <Mint , _>( & account_info, & variable_len, false ) . unwrap_err( ) ,
2485
+ alloc_and_serialize_variable_len_extension:: <Mint , _>(
2486
+ & account_info,
2487
+ & variable_len,
2488
+ false ,
2489
+ )
2490
+ . unwrap_err( ) ,
2404
2491
TokenError :: ExtensionAlreadyInitialized . into( )
2405
2492
) ;
2406
2493
}
@@ -2429,7 +2516,8 @@ mod test {
2429
2516
let key = Pubkey :: new_unique ( ) ;
2430
2517
let account_info = ( & key, & mut data) . into_account_info ( ) ;
2431
2518
2432
- alloc_and_serialize :: < Mint , _ > ( & account_info, & variable_len, false ) . unwrap ( ) ;
2519
+ alloc_and_serialize_variable_len_extension :: < Mint , _ > ( & account_info, & variable_len, false )
2520
+ . unwrap ( ) ;
2433
2521
let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH
2434
2522
+ add_type_and_length_to_len ( value_len)
2435
2523
+ add_type_and_length_to_len ( size_of :: < MetadataPointer > ( ) ) ;
@@ -2447,12 +2535,18 @@ mod test {
2447
2535
2448
2536
// alloc again succeeds with "overwrite"
2449
2537
let account_info = ( & key, & mut data) . into_account_info ( ) ;
2450
- alloc_and_serialize :: < Mint , _ > ( & account_info, & variable_len, true ) . unwrap ( ) ;
2538
+ alloc_and_serialize_variable_len_extension :: < Mint , _ > ( & account_info, & variable_len, true )
2539
+ . unwrap ( ) ;
2451
2540
2452
2541
// alloc again fails without "overwrite"
2453
2542
let account_info = ( & key, & mut data) . into_account_info ( ) ;
2454
2543
assert_eq ! (
2455
- alloc_and_serialize:: <Mint , _>( & account_info, & variable_len, false ) . unwrap_err( ) ,
2544
+ alloc_and_serialize_variable_len_extension:: <Mint , _>(
2545
+ & account_info,
2546
+ & variable_len,
2547
+ false ,
2548
+ )
2549
+ . unwrap_err( ) ,
2456
2550
TokenError :: ExtensionAlreadyInitialized . into( )
2457
2551
) ;
2458
2552
}
@@ -2488,7 +2582,8 @@ mod test {
2488
2582
let key = Pubkey :: new_unique ( ) ;
2489
2583
let account_info = ( & key, & mut data) . into_account_info ( ) ;
2490
2584
let variable_len = VariableLenMintTest { data : vec ! [ 1 , 2 ] } ;
2491
- alloc_and_serialize :: < Mint , _ > ( & account_info, & variable_len, true ) . unwrap ( ) ;
2585
+ alloc_and_serialize_variable_len_extension :: < Mint , _ > ( & account_info, & variable_len, true )
2586
+ . unwrap ( ) ;
2492
2587
2493
2588
let state = StateWithExtensions :: < Mint > :: unpack ( data. data ( ) ) . unwrap ( ) ;
2494
2589
let extension = state. get_extension :: < MetadataPointer > ( ) . unwrap ( ) ;
@@ -2505,7 +2600,8 @@ mod test {
2505
2600
let variable_len = VariableLenMintTest {
2506
2601
data : vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 ] ,
2507
2602
} ;
2508
- alloc_and_serialize :: < Mint , _ > ( & account_info, & variable_len, true ) . unwrap ( ) ;
2603
+ alloc_and_serialize_variable_len_extension :: < Mint , _ > ( & account_info, & variable_len, true )
2604
+ . unwrap ( ) ;
2509
2605
2510
2606
let state = StateWithExtensions :: < Mint > :: unpack ( data. data ( ) ) . unwrap ( ) ;
2511
2607
let extension = state. get_extension :: < MetadataPointer > ( ) . unwrap ( ) ;
@@ -2522,7 +2618,8 @@ mod test {
2522
2618
let variable_len = VariableLenMintTest {
2523
2619
data : vec ! [ 7 , 6 , 5 , 4 , 3 , 2 , 1 ] ,
2524
2620
} ;
2525
- alloc_and_serialize :: < Mint , _ > ( & account_info, & variable_len, true ) . unwrap ( ) ;
2621
+ alloc_and_serialize_variable_len_extension :: < Mint , _ > ( & account_info, & variable_len, true )
2622
+ . unwrap ( ) ;
2526
2623
2527
2624
let state = StateWithExtensions :: < Mint > :: unpack ( data. data ( ) ) . unwrap ( ) ;
2528
2625
let extension = state. get_extension :: < MetadataPointer > ( ) . unwrap ( ) ;
0 commit comments