@@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
7
7
use std:: { fs, io, mem, str, thread} ;
8
8
9
9
use rustc_ast:: attr;
10
+ use rustc_ast:: expand:: autodiff_attrs:: AutoDiffItem ;
10
11
use rustc_data_structures:: fx:: { FxHashMap , FxIndexMap } ;
11
12
use rustc_data_structures:: jobserver:: { self , Acquired } ;
12
13
use rustc_data_structures:: memmap:: Mmap ;
@@ -40,7 +41,7 @@ use tracing::debug;
40
41
use super :: link:: { self , ensure_removed} ;
41
42
use super :: lto:: { self , SerializedModule } ;
42
43
use super :: symbol_export:: symbol_name_for_instance_in_crate;
43
- use crate :: errors:: ErrorCreatingRemarkDir ;
44
+ use crate :: errors:: { AutodiffWithoutLto , ErrorCreatingRemarkDir } ;
44
45
use crate :: traits:: * ;
45
46
use crate :: {
46
47
CachedModuleCodegen , CodegenResults , CompiledModule , CrateInfo , ModuleCodegen , ModuleKind ,
@@ -118,6 +119,7 @@ pub struct ModuleConfig {
118
119
pub merge_functions : bool ,
119
120
pub emit_lifetime_markers : bool ,
120
121
pub llvm_plugins : Vec < String > ,
122
+ pub autodiff : Vec < config:: AutoDiff > ,
121
123
}
122
124
123
125
impl ModuleConfig {
@@ -266,6 +268,7 @@ impl ModuleConfig {
266
268
267
269
emit_lifetime_markers : sess. emit_lifetime_markers ( ) ,
268
270
llvm_plugins : if_regular ! ( sess. opts. unstable_opts. llvm_plugins. clone( ) , vec![ ] ) ,
271
+ autodiff : if_regular ! ( sess. opts. unstable_opts. autodiff. clone( ) , vec![ ] ) ,
269
272
}
270
273
}
271
274
@@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
389
392
390
393
fn generate_lto_work < B : ExtraBackendMethods > (
391
394
cgcx : & CodegenContext < B > ,
395
+ autodiff : Vec < AutoDiffItem > ,
392
396
needs_fat_lto : Vec < FatLtoInput < B > > ,
393
397
needs_thin_lto : Vec < ( String , B :: ThinBuffer ) > ,
394
398
import_only_modules : Vec < ( SerializedModule < B :: ModuleBuffer > , WorkProduct ) > ,
@@ -397,11 +401,19 @@ fn generate_lto_work<B: ExtraBackendMethods>(
397
401
398
402
if !needs_fat_lto. is_empty ( ) {
399
403
assert ! ( needs_thin_lto. is_empty( ) ) ;
400
- let module =
404
+ let mut module =
401
405
B :: run_fat_lto ( cgcx, needs_fat_lto, import_only_modules) . unwrap_or_else ( |e| e. raise ( ) ) ;
406
+ if cgcx. lto == Lto :: Fat {
407
+ let config = cgcx. config ( ModuleKind :: Regular ) ;
408
+ module = unsafe { module. autodiff ( cgcx, autodiff, config) . unwrap ( ) } ;
409
+ }
402
410
// We are adding a single work item, so the cost doesn't matter.
403
411
vec ! [ ( WorkItem :: LTO ( module) , 0 ) ]
404
412
} else {
413
+ if !autodiff. is_empty ( ) {
414
+ let dcx = cgcx. create_dcx ( ) ;
415
+ dcx. handle ( ) . emit_fatal ( AutodiffWithoutLto { } ) ;
416
+ }
405
417
assert ! ( needs_fat_lto. is_empty( ) ) ;
406
418
let ( lto_modules, copy_jobs) = B :: run_thin_lto ( cgcx, needs_thin_lto, import_only_modules)
407
419
. unwrap_or_else ( |e| e. raise ( ) ) ;
@@ -1021,6 +1033,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
1021
1033
/// Sent from a backend worker thread.
1022
1034
WorkItem { result : Result < WorkItemResult < B > , Option < WorkerFatalError > > , worker_id : usize } ,
1023
1035
1036
+ /// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
1037
+ AddAutoDiffItems ( Vec < AutoDiffItem > ) ,
1038
+
1024
1039
/// The frontend has finished generating something (backend IR or a
1025
1040
/// post-LTO artifact) for a codegen unit, and it should be passed to the
1026
1041
/// backend. Sent from the main thread.
@@ -1348,6 +1363,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
1348
1363
1349
1364
// This is where we collect codegen units that have gone all the way
1350
1365
// through codegen and LLVM.
1366
+ let mut autodiff_items = Vec :: new ( ) ;
1351
1367
let mut compiled_modules = vec ! [ ] ;
1352
1368
let mut compiled_allocator_module = None ;
1353
1369
let mut needs_link = Vec :: new ( ) ;
@@ -1459,9 +1475,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
1459
1475
let needs_thin_lto = mem:: take ( & mut needs_thin_lto) ;
1460
1476
let import_only_modules = mem:: take ( & mut lto_import_only_modules) ;
1461
1477
1462
- for ( work, cost) in
1463
- generate_lto_work ( & cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
1464
- {
1478
+ for ( work, cost) in generate_lto_work (
1479
+ & cgcx,
1480
+ autodiff_items. clone ( ) ,
1481
+ needs_fat_lto,
1482
+ needs_thin_lto,
1483
+ import_only_modules,
1484
+ ) {
1465
1485
let insertion_index = work_items
1466
1486
. binary_search_by_key ( & cost, |& ( _, cost) | cost)
1467
1487
. unwrap_or_else ( |e| e) ;
@@ -1596,6 +1616,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
1596
1616
main_thread_state = MainThreadState :: Idle ;
1597
1617
}
1598
1618
1619
+ Message :: AddAutoDiffItems ( mut items) => {
1620
+ autodiff_items. append ( & mut items) ;
1621
+ }
1622
+
1599
1623
Message :: CodegenComplete => {
1600
1624
if codegen_state != Aborted {
1601
1625
codegen_state = Completed ;
@@ -2070,6 +2094,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
2070
2094
drop ( self . coordinator . sender . send ( Box :: new ( Message :: CodegenComplete :: < B > ) ) ) ;
2071
2095
}
2072
2096
2097
/// Hands a batch of autodiff items to the coordinator.
///
/// `items` is wrapped in `Message::AddAutoDiffItems`, boxed, and sent over
/// `self.coordinator.sender`. This method returns nothing to the caller.
+ pub ( crate ) fn submit_autodiff_items ( & self , items : Vec < AutoDiffItem > ) {
2098
// The value returned by `send` is passed straight to `drop`, so a failed
// send is not reported from here.
+ drop ( self . coordinator . sender . send ( Box :: new ( Message :: < B > :: AddAutoDiffItems ( items) ) ) ) ;
2099
+ }
2100
+
2073
2101
/// Forwards to `self.shared_emitter_main.check`, passing `sess` through.
pub ( crate ) fn check_for_errors ( & self , sess : & Session ) {
2074
2102
// NOTE(review): the meaning of the `false` argument is not visible in this
// chunk — presumably it selects a non-blocking check; confirm against the
// definition of `check` on the shared emitter.
self . shared_emitter_main . check ( sess, false ) ;
2075
2103
}
0 commit comments