(**
 	Module: Signal	
	Description: signal definition and operations.
	@author WANG Haisheng	
	Created: 03/06/2013	Modified: 03/06/2013
*)

open Types;;
open Basic;;
open Value;;

(** Raised by the rate and signal operations of this module; the string
    argument describes the failure. *)
exception Signal_operation of string;;

(** Cache length passed to [add_memory] by the [delay] method of the
    [signal] class. *)
let delay_memory_length = 10000;;

(** Sample rate represented as a fraction [num/denom] in lowest terms.

    [new rate n d] divides [n] and [d] by their greatest common divisor.
    The sign is carried by the numerator: the stored denominator is kept
    positive, so that two rates denoting the same fraction have the same
    [num] and [denom] and are reported equal by [equal].

    @raise Signal_operation when [d] is 0. *)
class rate : int -> int -> rate_type = 
  fun (num_init : int) ->
    fun (denom_init : int) ->
      (* Euclid on non-negative arguments; [i2] must not be 0. *)
      let rec pgcd : int -> int -> int = 
        fun i1 -> fun i2 ->
          let r = i1 mod i2 in 
          if r = 0 then i2 else pgcd i2 r in
      (* Reduction factor.  It takes the sign of the denominator, so
         dividing by it both reduces the fraction and moves a negative
         sign from the denominator to the numerator. *)
      let factor = 
        if denom_init = 0 then 
          raise (Signal_operation "sample rate denominater = 0.")
        else 
          let g = pgcd (abs num_init) (abs denom_init) in
          if denom_init < 0 then (- g) else g in
      object (self)
        val _num = num_init / factor
        val _denom = denom_init / factor

        (** Numerator of the reduced fraction. *)
        method num = _num

        (** Denominator of the reduced fraction. *)
        method denom = _denom

        (** Integer quotient [num / denom]. *)
        method to_int = 
          self#num / self#denom

        (** Floating-point quotient of [num] by [denom]. *)
        method to_float = 
          (float_of_int self#num) /. (float_of_int self#denom)

        (** Textual form: numerator, a slash, denominator. *)
        method to_string = 
          (string_of_int self#num) ^ "/" ^ (string_of_int self#denom)

        (** [equal r] holds when [r] has the same reduced numerator and
            denominator as this rate. *)
        method equal : rate_type -> bool = 
          fun (r : rate_type) -> (self#num = r#num) && (self#denom = r#denom)

        (** [mul i] is this rate multiplied by [i], reduced again. *)
        method mul : int -> rate_type = 
          fun (i : int) -> new rate (self#num * i) self#denom

        (** [div i] is this rate divided by [i], reduced again.
            @raise Signal_operation when the new denominator is 0. *)
        method div : int -> rate_type = 
          fun (i : int) -> new rate self#num (self#denom * i)
      end
	  

(** Discrete-time signal: a frequency together with a function from time
    to value.

    [new signal freq f] builds a signal of frequency [freq] whose sample
    at time [t] is [f t].  The operations below return fresh signals
    whose sample functions call [at] on their operands each time they
    are evaluated.  Samples are cached only after [add_memory] has been
    called on the signal. *)
class signal : rate_type -> (time -> value_type) -> signal_type = 
  fun (freq_init : rate_type) ->
    fun (func_init : time -> value_type) ->
      object (self)
        (* Current sample function; [add_memory] replaces it with a
           caching wrapper around [func_init]. *)
        val mutable signal_func = func_init
        (* Length of the cache installed so far; 0 when there is none. *)
        val mutable memory_length = 0

        (** Frequency given at construction. *)
        method frequency = freq_init

        (** Current sample function of the signal. *)
        method at = signal_func

        (** [check_freq sl] folds over [sl], starting from the frequency
            of this signal, and returns the common frequency.  A
            frequency whose numerator is 0 is accepted next to any other
            frequency and yields to it.
            @raise Signal_operation when two frequencies with non-zero
            numerators differ. *)
        method private check_freq : signal_type list -> rate_type = 
          fun (sl : signal_type list) ->
            let check : rate_type -> signal_type -> rate_type = 
              fun (f : rate_type) ->
                fun (s : signal_type) ->
                  if f#equal s#frequency || s#frequency#num = 0 then f
                  else if f#num = 0 then s#frequency
                  else raise (Signal_operation "frequency not matched.") in
            List.fold_left check self#frequency sl

        (** [add_memory length] makes [at] cache computed samples.
            Nothing happens when a cache at least that long is already
            installed.  Otherwise a new table is created and [at] becomes
            a wrapper that looks the time up, computes the sample with
            the construction-time function on a miss, stores it, and
            drops the entry for time [t - length] when that time is not
            negative.  [length] must not be negative (checked with
            [assert]). *)
        method add_memory : int -> unit = 
          fun (length : int) ->
            assert (length >= 0);
            if memory_length >= length then ()
            else
              let memory = Hashtbl.create length in
              let func : time -> value = 
                fun (t : time) ->
                  try Hashtbl.find memory t
                  with Not_found ->
                    let result = func_init t in
                    let () = Hashtbl.replace memory t result in
                    let () = 
                      if (t - length) >= 0 then
                        Hashtbl.remove memory (t - length)
                      else () in
                    result in
              memory_length <- length;
              signal_func <- func

        (** [delay_by i t] is the sample at time [t - i] when that time
            is not negative.  When [t] is not negative but [t - i] is,
            the result is the [zero] of the sample at time 0.
            @raise Signal_operation when [t] is negative and [t - i] is
            negative too. *)
        method private delay_by : int -> time -> value = 
          fun i -> fun t ->
            if (t - i) >= 0 then
              self#at (t - i)
            else if t >= 0 then
              (* here t - i < 0: the delayed time lies before time 0 *)
              (self#at 0)#zero
            else raise (Signal_operation "Delay time < 0.")

        (** [prim1 func] is a signal with the frequency of this signal
            and [func] as sample function. *)
        method private prim1 : 
            (time -> value_type) -> signal_type = 
          fun (func : time -> value_type) ->
            let freq = self#frequency in
            new signal freq func 

        (** [prim2 func_binary s] is a signal whose sample at [t] is
            [func_binary t (s#at t)], with the frequency returned by
            [check_freq] for [s].
            @raise Signal_operation when the frequencies do not match. *)
        method private prim2 : 
            (time -> value_type -> value_type) -> signal_type -> signal_type = 
          fun (func_binary : time -> value_type -> value_type) ->
            fun (s : signal_type) ->
              let freq = self#check_freq [s] in
              let func = fun t -> (func_binary t) (s#at t) in
              new signal freq func

        (* Pointwise unary operations: each applies the value method of
           the same name to the sample at the same time. *)
        method neg = self#prim1 (fun t -> (self#at t)#neg)
        method floor = self#prim1 (fun t -> (self#at t)#floor)
        method ceil = self#prim1 (fun t -> (self#at t)#ceil)
        method rint = self#prim1 (fun t -> (self#at t)#rint)
        method sin = self#prim1 (fun t -> (self#at t)#sin)
        method asin = self#prim1 (fun t -> (self#at t)#asin)
        method cos = self#prim1 (fun t -> (self#at t)#cos)
        method acos = self#prim1 (fun t -> (self#at t)#acos)
        method tan = self#prim1 (fun t -> (self#at t)#tan)
        method atan = self#prim1 (fun t -> (self#at t)#atan)
        method exp = self#prim1 (fun t -> (self#at t)#exp)
        method sqrt = self#prim1 (fun t -> (self#at t)#sqrt)
        method ln = self#prim1 (fun t -> (self#at t)#ln)
        method lg = self#prim1 (fun t -> (self#at t)#lg)
        method int = self#prim1 (fun t -> (self#at t)#int)
        method float = self#prim1 (fun t -> (self#at t)#float)
        method abs = self#prim1 (fun t -> (self#at t)#abs)

        (* Pointwise binary operations: each applies the value method of
           the same name to the sample of this signal, with the sample
           of the argument signal at the same time as operand. *)
        method add = self#prim2 (fun t -> (self#at t)#add)
        method sub = self#prim2 (fun t -> (self#at t)#sub)
        method mul = self#prim2 (fun t -> (self#at t)#mul)
        method div = self#prim2 (fun t -> (self#at t)#div)
        method power = self#prim2 (fun t -> (self#at t)#power)
        method _and = self#prim2 (fun t -> (self#at t)#_and)
        method _or = self#prim2 (fun t -> (self#at t)#_or)
        method _xor = self#prim2 (fun t -> (self#at t)#_xor)
        method atan2 = self#prim2 (fun t -> (self#at t)#atan2)
        method _mod = self#prim2 (fun t -> (self#at t)#_mod)
        method fmod = self#prim2 (fun t -> (self#at t)#fmod)
        method remainder = self#prim2 (fun t -> (self#at t)#remainder)
        method gt = self#prim2 (fun t -> (self#at t)#gt)
        method lt = self#prim2 (fun t -> (self#at t)#lt)
        method geq = self#prim2 (fun t -> (self#at t)#geq)
        method leq = self#prim2 (fun t -> (self#at t)#leq)
        method eq = self#prim2 (fun t -> (self#at t)#eq)
        method neq = self#prim2 (fun t -> (self#at t)#neq)
        method max = self#prim2 (fun t -> (self#at t)#max)
        method min = self#prim2 (fun t -> (self#at t)#min)
        method shl = self#prim2 (fun t -> (self#at t)#shl)
        method shr = self#prim2 (fun t -> (self#at t)#shr)

        (** [delay s] is this signal delayed, at each time [t], by the
            integer value of [s] at [t] (see [delay_by]).  A cache of
            [delay_memory_length] samples is installed on this signal.
            @raise Signal_operation when the frequencies do not match. *)
        method delay : signal_type -> signal_type =
          fun (s : signal_type) ->
            let freq = self#check_freq [s] in
            let () = self#add_memory delay_memory_length in
            let func : time -> value_type = 
              fun (t : time) ->
                let i = (s#at t)#to_int in
                self#delay_by i t  in
            new signal freq func

        (** This signal delayed by one sample (see [delay_by]); a cache
            of one sample is installed on this signal. *)
        method mem : signal_type = 
          let freq = self#frequency in
          let () = self#add_memory 1 in
          let func = fun (t : time) -> self#delay_by 1 t in
          new signal freq func

        (** [rdtable s_size s_index] reads this signal as a table: the
            sample at [t] is the sample of this signal at the integer
            value of [s_index] at [t].  The integer value of [s_size] at
            time 0 is used as cache length for this signal.
            @raise Signal_operation when the frequency of [s_index]
            does not match. *)
        method rdtable : signal_type -> signal_type -> signal_type = 
          fun (s_size : signal_type) ->
            fun (s_index : signal_type) ->
              let freq = self#check_freq [s_index] in
              let () = self#add_memory ((s_size#at 0)#to_int) in
              let func : time -> value_type = fun t -> 
                self#at ((s_index#at t)#to_int) in
              new signal freq func

        (** [rwtable init wstream windex rindex] is a read/write table.
            The integer value of this signal at time 0 is used as cache
            length for [init] and [wstream].  The sample at time [ti] is
            the value of [wstream] at the latest time between [ti] and 0
            at which [windex] equals the read index (the integer value
            of [rindex] at [ti]), or [init] at the read index when there
            is no such time.
            @raise Signal_operation when the frequencies do not match,
            or when a sample is requested at a negative time. *)
        method rwtable : signal_type -> signal_type -> 
          signal_type -> signal_type -> signal_type = 
            fun init -> fun wstream -> fun windex -> fun rindex ->
              let freq = self#check_freq [init; wstream; windex; rindex] in
              let () = init#add_memory ((self#at 0)#to_int) in
              let () = wstream#add_memory ((self#at 0)#to_int) in
              let func : time -> value_type = fun (ti : time) -> 
                (* Walks back in time looking for the last write to
                   index [i]. *)
                let rec table : time -> index -> value_type = 
                  fun t -> fun i -> 
                    if t > 0 then
                      (if i = (windex#at t)#to_int then (wstream#at t)
                      else table (t - 1) i)
                    else if t = 0 then
                      (if i = (windex#at 0)#to_int then (wstream#at 0)
                      else init#at i)
                    else
                      (* time 0 is valid; only negative times fail *)
                      raise (Signal_operation "signal time should be >= 0") in
                table ti ((rindex#at ti)#to_int) in
              new signal freq func

        (** [select2 s_first s_second] uses the integer value of this
            signal at [t] as selector: 0 gives the sample of [s_first],
            1 gives the sample of [s_second].
            @raise Signal_operation when the frequencies do not match,
            or when the selector is neither 0 nor 1. *)
        method select2 : signal_type -> signal_type -> signal_type =
          fun s_first -> 
            fun s_second ->
              let freq = self#check_freq [s_first; s_second] in
              let func : time -> value_type = 
                fun t -> let i = (self#at t)#to_int in
                if i = 0 then s_first#at t
                else if i = 1 then s_second#at t
                else raise (Signal_operation "select2 index 0|1.") in
              new signal freq func

        (** [select3 s_first s_second s_third] uses the integer value of
            this signal at [t] as selector: 0, 1 and 2 give the sample
            of [s_first], [s_second] and [s_third] respectively.
            @raise Signal_operation when the frequencies do not match,
            or when the selector is not 0, 1 or 2. *)
        method select3 : 
            signal_type -> signal_type -> signal_type -> signal_type =
          fun s_first -> fun s_second -> fun s_third ->
            let freq = self#check_freq [s_first; s_second; s_third] in
            let func : time -> value_type = 
              fun t -> let i = (self#at t)#to_int in
              if i = 0 then s_first#at t
              else if i = 1 then s_second#at t
              else if i = 2 then s_third#at t
              else raise (Signal_operation "select3 index 0|1|2.") in
            new signal freq func

        (** [prefix s_init] is this signal delayed by one sample, with
            the sample of [s_init] at time 0 as first sample.  A cache
            of one sample is installed on this signal.  The result has
            the frequency of this signal; the frequency of [s_init] is
            not checked.
            @raise Signal_operation when a sample is requested at a
            negative time. *)
        method prefix : signal_type -> signal_type =
          fun (s_init : signal_type) ->
            let () = self#add_memory 1 in
            let func : time -> value_type = 
              fun t ->
                if t = 0 then s_init#at 0
                else if t > 0 then self#at (t - 1) 
                else raise (Signal_operation "prefix time < 0.") in
            new signal self#frequency func

        (** [vectorize s_size] groups samples into vectors.  With [size]
            the integer value of [s_size] at time 0, the sample at [t]
            is a vector of [size] elements whose element [i] is taken
            from the sample of this signal at [size * t + i].  The
            frequency of the result is the frequency of this signal
            divided by [size].
            @raise Signal_operation when [size] is not positive. *)
        method vectorize : signal_type -> signal_type =
          fun s_size ->
            let size = (s_size#at 0)#to_int in
            if size <= 0 then      
              raise (Signal_operation "Vectorize: size <= 0.")
            else 
              let freq = self#frequency#div size in
              let func : time -> value_type = 
                fun t ->
                  let vec = fun i -> (self#at (size * t + i))#get in
                  new value (Vec (new vector size vec)) in
              new signal freq func

        (** Flattens a signal of vectors.  [size] is the size of the
            vector found at time 0; the sample at [t] is element
            [t mod size] of the vector at time [t / size].  The frequency
            of the result is the frequency of this signal multiplied by
            [size].
            @raise Signal_operation when the sample at time 0, or a
            sample read later, is not a vector. *)
        method serialize : signal_type = 
          let size = 
            match (self#at 0)#get with
            | Vec vec -> vec#size
            | _ -> raise (Signal_operation "Serialize: scalar input.") in
          let freq = self#frequency#mul size in
          let func : time -> value_type = 
            fun t -> 
              match (self#at (t/size))#get with
              | Vec vec -> new value (vec#nth (t mod size))
              | _ -> raise (Signal_operation 
                              "Serialize: signal type not consistent.") in
          new signal freq func

        (** [vconcat s] concatenates, at each time, the vector sample of
            this signal with the vector sample of [s]; the elements of
            this signal come first.
            @raise Signal_operation when the frequencies do not match,
            or when one of the two samples is not a vector. *)
        method vconcat : signal_type -> signal_type = 
          fun s ->
            let freq = self#check_freq [s] in
            let func : time -> value_type = 
              fun t ->
                match ((self#at t)#get, (s#at t)#get) with
                | (Vec vec1, Vec vec2) ->
                    let size1 = vec1#size in
                    let size2 = vec2#size in
                    let size = size1 + size2 in
                    let vec = fun i -> 
                      if i < size1 then vec1#nth i
                      else vec2#nth (i - size1) in
                    new value (Vec (new vector size vec))
                | _ -> raise (Signal_operation "Vconcat: scalar.") in
            new signal freq func

        (** [vpick s_index] picks, at each time [t], the element of the
            vector sample of this signal whose position is the integer
            value of [s_index] at [t].
            @raise Signal_operation when the frequencies do not match,
            or when the sample of this signal is not a vector. *)
        method vpick : signal_type -> signal_type = 
          fun s_index ->
            let freq = self#check_freq [s_index] in
            let func : time -> value_type = 
              fun t -> 
                let i = (s_index#at t)#to_int in
                match (self#at t)#get with
                | Vec vec -> new value (vec#nth i)
                | _ -> raise (Signal_operation "Vpick: scalar.") in
            new signal freq func

      end;;