1919(`OrderOption.QUASI_RANDOM`) in the dataloader constructor's `order` argument.
2020'''
2121
def select_buffer(buffer, batch_slot, count):
    """Select the relevant sub-part of a buffer for a given batch slot.

    ``buffer`` may be ``None`` (propagated unchanged), a tuple of buffers
    (each element is selected recursively), or an indexable allocation
    where ``buffer[batch_slot]`` is the per-slot region; only the first
    ``count`` entries of that region (the actual batch size) are returned.
    """
    if buffer is None:
        return None
    if isinstance(buffer, tuple):
        # Recurse into each component of a composite allocation.
        parts = [select_buffer(part, batch_slot, count) for part in buffer]
        return tuple(parts)
    return buffer[batch_slot][:count]
32+
2233class EpochIterator (Thread ):
2334 def __init__ (self , loader : 'Loader' , order : Sequence [int ]):
2435 super ().__init__ (daemon = True )
@@ -33,6 +44,10 @@ def __init__(self, loader: 'Loader', order: Sequence[int]):
3344 self .terminate_event = Event ()
3445 self .memory_context = self .loader .memory_manager .schedule_epoch (
3546 batches )
47+
48+ if IS_CUDA :
49+ self .current_stream = ch .cuda .current_stream ()
50+
3651 try :
3752 self .memory_context .__enter__ ()
3853 except MemoryError as e :
@@ -44,23 +59,13 @@ def __init__(self, loader: 'Loader', order: Sequence[int]):
4459
4560 self .storage_state = self .memory_context .state
4661
47- self .memory_bank_per_stage = defaultdict (list )
48-
4962 self .cuda_streams = [(ch .cuda .Stream () if IS_CUDA else None )
5063 for _ in range (self .loader .batches_ahead + 2 )]
5164
52- # Allocate all the memory
53- memory_allocations = {}
54- for (p_id , p ) in self .loader .pipelines .items ():
55- memory_allocations [p_id ] = p .allocate_memory (self .loader .batch_size ,
56- self .loader .batches_ahead + 2 )
57-
58- # Assign each memory bank to the pipeline stage it belongs to
59- for s_ix , banks in self .loader .memory_bank_keys_per_stage .items ():
60- for (pipeline_name , op_id ) in banks :
61- self .memory_bank_per_stage [s_ix ].append (
62- memory_allocations [pipeline_name ][op_id ]
63- )
65+ self .memory_allocations = self .loader .graph .allocate_memory (
66+ self .loader .batch_size ,
67+ self .loader .batches_ahead + 2
68+ )
6469
6570 self .start ()
6671
@@ -77,6 +82,7 @@ def run(self):
7782 self .current_batch_slot = (
7883 slot + 1 ) % (self .loader .batches_ahead + 2 )
7984 result = self .run_pipeline (b_ix , ixes , slot , events [slot ])
85+ # print("RES", b_ix, "ready")
8086 to_output = (slot , result )
8187 while True :
8288 try :
@@ -88,23 +94,24 @@ def run(self):
8894 if self .terminate_event .is_set ():
8995 return
9096 if IS_CUDA :
97+ # print("SUB", b_ix)
9198 # We were able to submit this batch
9299 # Therefore it means that the user must have entered the for loop for
93100 # (batch_slot - batch_ahead + 1) % (batches ahead + 2)
94101 # Therefore batch_slot - batch_ahead must have all its work submitted
95102 # We will record an event of all the work submitted on the main stream
96103 # and make sure no one overwrites the data until they are done
97- just_finished_slot = (slot - self .loader .batches_ahead ) % (self .loader .batches_ahead + 2 )
104+ just_finished_slot = (slot - self .loader .batches_ahead - 1 ) % (self .loader .batches_ahead + 2 )
105+ # print("JFS", just_finished_slot)
98106 event = ch .cuda .Event ()
99- event .record (ch . cuda . default_stream () )
107+ event .record (self . current_stream )
100108 events [just_finished_slot ] = event
101109 b_ix += 1
102110
103111 except StopIteration :
104112 self .output_queue .put (None )
105113
106114 def run_pipeline (self , b_ix , batch_indices , batch_slot , cuda_event ):
107- # print(b_ix, batch_indices)
108115 self .memory_context .start_batch (b_ix )
109116 args = []
110117 if IS_CUDA :
@@ -114,28 +121,35 @@ def run_pipeline(self, b_ix, batch_indices, batch_slot, cuda_event):
114121 ctx = nullcontext ()
115122 first_stage = False
116123
124+
125+ code , outputs = self .loader .code
117126 with ctx :
118127 if IS_CUDA :
119128 if cuda_event :
120129 cuda_event .wait ()
121- for stage , banks in self .memory_bank_per_stage .items ():
122- args .insert (0 , batch_indices )
123- for bank in banks :
124- if bank is not None :
125- if isinstance (bank , tuple ):
126- bank = tuple (x [batch_slot ] for x in bank )
127- else :
128- bank = bank [batch_slot ]
129- args .append (bank )
130- args .append (self .metadata )
131- args .append (self .storage_state )
132- code = self .loader .code_per_stage [stage ]
133- result = code (* args )
134- args = list (result )
135- if first_stage :
136- first_stage = False
137- self .memory_context .end_batch (b_ix )
138- return tuple (x [:len (batch_indices )] for x in args )
130+
131+ args = {
132+ 'batch_indices' : batch_indices ,
133+ 'storage_state' : self .storage_state ,
134+ 'metadata' : self .metadata ,
135+ ** {
136+ f'memory_{ k } ' :select_buffer (v , batch_slot , len (batch_indices ))
137+ for (k , v ) in self .memory_allocations ['operation' ].items ()
138+ },
139+ ** {
140+ f'shared_memory_{ k } ' : select_buffer (v , batch_slot , len (batch_indices ))
141+ for (k , v ) in self .memory_allocations ['shared' ].items ()
142+ }
143+ }
144+
145+ for stage_code , define_outputs in code :
146+ results = stage_code (** args )
147+ for node_id , result in zip (define_outputs , results ):
148+ args [f'result_{ node_id } ' ] = result
149+ pass
150+
151+ result = tuple (args [f'result_{ x } ' ] for x in outputs )
152+ return result
139153
140154 def __next__ (self ):
141155 result = self .output_queue .get ()
@@ -146,7 +160,7 @@ def __next__(self):
146160 if IS_CUDA :
147161 stream = self .cuda_streams [slot ]
148162 # We wait for the copy to be done
149- ch . cuda . current_stream () .wait_stream (stream )
163+ self . current_stream .wait_stream (stream )
150164 return result
151165
152166 def __iter__ (self ):
0 commit comments