@@ -206,7 +206,7 @@ module type CircuitInterface = sig
206206
207207 (* Mapreduce/Dependecy analysis related functions *)
208208 val is_decomposable : int -> int -> cbitstring cfun -> bool
209- val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun ) list
209+ val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun ) list * ( int * int )
210210 val permute : int -> (int -> int ) -> cbitstring cfun -> cbitstring cfun
211211
212212 (* Wraps the backend call to deal with args/inputs *)
@@ -320,6 +320,10 @@ module type CBackend = sig
320320 val is_splittable : int -> int -> deps -> bool
321321
322322 val are_independent : block_deps -> bool
323+
324+ val single_dep : deps -> bool
325+ (* Assumes single_dep *)
326+ val dep_range : deps -> int * int
323327 end
324328end
325329
@@ -425,11 +429,14 @@ module TestBack : CBackend = struct
425429 let get (r : reg ) (idx : int ) = r.(idx)
426430
427431 let permute (w : int ) (perm : int -> int ) (r : reg ) : reg =
432+ Format. eprintf " Applying permutation to reg of size %d with block size of %d@." (size_of_reg r) w;
428433 Array. init (size_of_reg r) (fun i ->
429- let block_idx, bit_idx = (i / w), (i mod w) in
430- let idx = (perm block_idx)* w + bit_idx in
431- r.(idx)
432- )
434+ let block_idx, bit_idx = perm (i / w), (i mod w) in
435+ if block_idx < 0 then None
436+ else
437+ let idx = block_idx* w + bit_idx in
438+ Some r.(idx)
439+ ) |> Array. filter_map (fun x -> x)
433440
434441
435442 (* Node operations *)
@@ -536,17 +543,17 @@ module TestBack : CBackend = struct
536543 | 0 -> true
537544 | 1 ->
538545 let blocks = block_deps_of_deps w_out d in
539- (* Format.eprintf "Checking block width...@."; *)
546+ Format. eprintf " Checking block width...@." ;
540547 Array. for_all (fun (_ , d ) ->
541548 if Map. is_empty d then true
542549 else
543550 let _, bits = Map. any d in
544551 Set. is_empty bits ||
545552 let base = Set. at_rank_exn 0 bits in
546- (* Format.eprintf "Base for current block: %d@." base; *)
553+ Format. eprintf " Base for current block: %d@." base;
547554 Set. for_all (fun bit ->
548555 let dist = bit - base in
549- (* Format.eprintf "Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in; *)
556+ Format. eprintf " Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in;
550557 0 < = dist && dist < w_in
551558 ) bits
552559 ) blocks
@@ -576,6 +583,28 @@ module TestBack : CBackend = struct
576583 true
577584 with BreakOut ->
578585 false
586+
587+
588+ let single_dep (d : deps ) : bool =
589+ match Set. cardinal
590+ (Array. fold_left (Set. union) Set. empty
591+ (Array. map (fun dep -> Map. keys dep |> Set. of_enum) d))
592+ with
593+ | 0 | 1 -> true
594+ | _ -> false
595+
596+ (* Assumes single_dep, returns range (bot, top) such that valid idxs are bot <= i < top *)
597+ let dep_range (d : deps ) : int * int =
598+ assert (single_dep d);
599+ let idxs =
600+ Array. fold_left (fun acc d ->
601+ Set. union (Map. fold Set. union d Set. empty) acc) Set. empty d
602+ in
603+ Format. eprintf " %a@." pp_deps d;
604+ Format. eprintf " Dep range for dependencies:@." ;
605+ Set. iter (fun i -> Format. eprintf " %d " i) idxs;
606+ Format. eprintf " @.Min: %d | Max: %d@." (Set. min_elt idxs) (Set. max_elt idxs);
607+ (Set. min_elt idxs, Set. max_elt idxs + 1 )
579608 end
580609
581610end
@@ -1272,7 +1301,7 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
12721301 let array_oflist (circs : circuit list ) (dfl : circuit ) (len : int ) : circuit =
12731302 let circs, inps = List. split circs in
12741303 let dif = len - List. length circs in
1275- Format. eprintf " Len, Dif in array_oflist: %d, %d@." len dif;
1304+ (* Format.eprintf "Len, Dif in array_oflist: %d, %d@." len dif; *)
12761305 let circs = circs @ (List. init dif (fun _ -> fst dfl)) in
12771306 let inps = if dif > 0 then inps @ [snd dfl] else inps in
12781307 let circs = List. map
@@ -1518,39 +1547,57 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
15181547 (* For more complex circuits, we might be able to simulate this with a int -> (int, int) map *)
15191548 let is_decomposable (in_w : width ) (out_w : width ) ((`CBitstring r , inps ) as c : cbitstring cfun ) : bool =
15201549 match inps with
1521- | {type_ =`CIBitstring w } :: [] when w mod in_w = 0 && Backend. size_of_reg r mod out_w = 0 ->
1550+ | {type_ =`CIBitstring w } :: [] when ( Backend. size_of_reg r mod out_w = 0 ) ->
15221551 let deps = Backend.Deps. deps_of_reg r in
1523- Backend.Deps. is_splittable in_w out_w deps
1552+ Backend.Deps. is_splittable in_w out_w deps &&
1553+ let base, top = Backend.Deps. dep_range deps in
1554+ let () = Format. eprintf " Passed backend check, checking width of deps (top - base = %d | in_w = %d)@." (top - base) in_w in
1555+ (top - base) mod in_w = 0
15241556 | _ ->
15251557 Format. eprintf " Failed decomposition type check@\n " ;
15261558 Format. eprintf " In_w: %d | Out_w : %d | Circ: %a" in_w out_w pp_circuit c;
15271559 false
15281560
1529- let split_renamer (n : count ) (in_w : width ) (inp : cinp ) : (cinp array) * (Backend.inp -> cbool_type option) =
1530- match inp with
1531- | {type_ = `CIBitstring w ; id} when w mod in_w = 0 ->
1561+ let split_renamer ?(range : (int * int) option ) (n : count ) (in_w : width ) (inp : cinp ) : (cinp array) * (Backend.inp -> cbool_type option) =
1562+ match range, inp with
1563+ | Some (start_idx , end_idx ), {type_ = `CIBitstring w ; id} when (end_idx - start_idx) mod in_w = 0 ->
1564+ let ids = Array. init n (fun i -> create (" split_" ^ (string_of_int i)) |> tag) in
1565+ Array. map (fun id -> {type_ = `CIBitstring in_w; id}) ids,
1566+ (fun (id_ , w ) ->
1567+ let w = w - start_idx in (* FIXME: check if this doesn't cause problems on the upper end *)
1568+ if id <> id_ || w < 0 || w > = end_idx then None else
1569+ let id_idx, bit_idx = (w / in_w), (w mod in_w) in
1570+ Some (Backend. input_node ~id: ids.(id_idx) bit_idx))
1571+ | None , {type_ = `CIBitstring w ; id} when w mod in_w = 0 ->
15321572 let ids = Array. init n (fun i -> create (" split_" ^ (string_of_int i)) |> tag) in
15331573 Array. map (fun id -> {type_ = `CIBitstring in_w; id}) ids,
15341574 (fun (id_ , w ) ->
15351575 if id <> id_ then None else
15361576 let id_idx, bit_idx = (w / in_w), (w mod in_w) in
15371577 Some (Backend. input_node ~id: ids.(id_idx) bit_idx))
1578+ | _ , {type_ = `CIBitstring w ; id} ->
1579+ Format. eprintf " Failed to build split renamer for n=%d in_w=%d w=%d" n in_w w;
1580+ Option. may (fun (bot , top ) -> Format. eprintf " range=(%d, %d)" bot top) range;
1581+ Format. eprintf " @." ;
1582+ assert false
15381583 | _ -> assert false
15391584
1540- let decompose (in_w : width ) (out_w : width ) ((`CBitstring r , inps ) as c : cbitstring cfun ) : cbitstring cfun list =
1585+ let decompose (in_w : width ) (out_w : width ) ((`CBitstring r , inps ) as c : cbitstring cfun ) : cbitstring cfun list * (int * int) =
15411586 if not (is_decomposable in_w out_w c) then
15421587 let deps = Backend.Deps. block_deps_of_reg out_w r in
15431588 Format. eprintf " Failed to decompose. in_w=%d out_w=%d Deps:@.%a" in_w out_w (Backend.Deps. pp_block_deps) deps;
15441589 assert false
15451590 else
1591+ (* TODO: don't repeat dependecy computation ? *)
1592+ let dep_range = Backend.Deps. dep_range (Backend.Deps. deps_of_reg r) in
15461593 let n = (Backend. size_of_reg r) / out_w in
15471594 let blocks = Array. init n (fun i ->
15481595 Backend. slice r (i* out_w) out_w) in
1549- let cinps, renamer = split_renamer n in_w (List. hd inps) in
1596+ let cinps, renamer = split_renamer ~range: dep_range n in_w (List. hd inps) in
15501597 Array. map2 (fun r inp ->
15511598 let r = Backend. applys renamer r in
15521599 (`CBitstring r, [inp])
1553- ) blocks cinps |> Array. to_list
1600+ ) blocks cinps |> Array. to_list, dep_range
15541601
15551602 let permute (w : width ) (perm : (int -> int) ) ((`CBitstring r , inps ): cbitstring cfun ) : cbitstring cfun =
15561603 `CBitstring (Backend. permute w perm r), inps
@@ -2164,13 +2211,13 @@ let circuit_permute (bsz: int) (perm: int -> int) (c: circuit) : circuit =
21642211 in
21652212 (permute bsz perm c :> circuit )
21662213
2167- let circuit_mapreduce ?(perm : (int -> int) option ) (c : circuit ) (w_in : width ) (w_out : width ) : circuit list =
2214+ let circuit_mapreduce ?(perm : (int -> int) option ) (c : circuit ) (w_in : width ) (w_out : width ) : circuit list * (int * int) =
21682215 let c = match c, perm with
21692216 | (`CBitstring _ , inps ) as c , None -> c
21702217 | (`CBitstring _ , inps ) as c , Some perm -> permute w_out perm c
21712218 | _ -> assert false
21722219 in
2173- (decompose w_in w_out c :> circuit list )
2220+ (decompose w_in w_out c :> circuit list * (int * int ) )
21742221
21752222type circuit = ExampleInterface .circuit
21762223type pstate = ExampleInterface.PState .pstate
0 commit comments