# Theory Tensor

(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Tensor›

theory Tensor
imports Main
begin

(* A tensor is represented by a pair (ds, v): a dimension list ds and an entry
   list v whose length is the product of the dimensions. The representing set
   is non-empty because for every ds there is a list of length prod_list ds
   (Ex_list_of_length); typedef requires this proof. *)
typedef 'a tensor = "{t::nat list × 'a list. length (snd t) = prod_list (fst t)}"
by (simp add: Ex_list_of_length)

(* The dimension list of a tensor: the first component of its representation. *)
definition dims::"'a tensor ⇒ nat list" where
"dims A = fst (Rep_tensor A)"

(* The entries of a tensor as a flat list: the second component of its
   representation. Its length is prod_list (dims A), see length_vec. *)
definition vec::"'a tensor ⇒ 'a list" where
"vec A = snd (Rep_tensor A)"

(* Builds a tensor from a dimension list d and an entry list v. dims and vec
   recover d and v only when length v = prod_list d (dims_tensor, vec_tensor);
   otherwise the pair lies outside the representing set and nothing is stated
   about the result. *)
definition tensor_from_vec::"nat list ⇒ 'a list ⇒ 'a tensor" where
"tensor_from_vec d v = Abs_tensor (d,v)"

(* Inverse lemmas for tensor_from_vec: if the entry list has the length
   required by the dimensions, dims and vec return the original arguments.
   Both are declared as simp rules. *)
lemma
assumes "length v = prod_list d"
shows dims_tensor[simp]: "dims (tensor_from_vec d v) = d"
and   vec_tensor[simp]:  "vec (tensor_from_vec d v) = v"
by (simp add: Abs_tensor_inverse assms dims_def tensor_from_vec_def vec_def)+

(* Every tensor is rebuilt from its own dimensions and entries (simp rule). *)
lemma tensor_from_vec_simp[simp]: "tensor_from_vec (dims A) (vec A) = A"
by (simp add: Rep_tensor_inverse Tensor.vec_def dims_def tensor_from_vec_def)

(* The typedef invariant, stated for dims/vec: the entry list of any tensor
   has length equal to the product of its dimensions. *)
lemma length_vec: "length (vec A) = prod_list (dims A)"
by (metis (mono_tags, lifting) Rep_tensor Tensor.vec_def dims_def mem_Collect_eq)

(* Extensionality: tensors with equal dimensions and equal entry lists are
   equal. Declared as an introduction rule. *)
lemma tensor_eqI[intro]:
assumes "dims A = dims B" and "vec A = vec B"
shows "A=B"
by (metis assms tensor_from_vec_simp)

(* The order of a tensor: the number of its dimensions. *)
abbreviation order::"'a tensor ⇒ nat" where
"order t == length (dims t)"

(* is ⊲ ds: the index list is has exactly one component per dimension in ds,
   and each component is strictly smaller than the corresponding dimension. *)
inductive valid_index::"nat list ⇒ nat list ⇒ bool" (infix "⊲" 50) where
Nil: "[] ⊲ []" |
Cons: "is ⊲ ds ⟹ i<d ⟹ i#is ⊲ d#ds"

(* Case-analysis (elimination) rules for valid indices: one for an arbitrary
   dimension list and one for the dimensions of a tensor. *)
inductive_cases valid_indexE[elim]: "is ⊲ ds"
inductive_cases valid_index_dimsE[elim]: "is ⊲ dims A"

(* A valid index has as many components as there are dimensions. *)
lemma valid_index_length: "is ⊲ ds ⟹ length is = length ds"
by (induction rule:valid_index.induct; auto)

(* Each component of a valid index is below the corresponding dimension. *)
lemma valid_index_lt: "is ⊲ ds ⟹ m<length ds ⟹ is!m < ds!m"
proof (induction arbitrary:m rule:valid_index.induct)
case Nil
then show ?case by auto
next
case Cons
then show ?case by (metis gr0_conv_Suc length_Cons linorder_neqE_nat not_less_eq nth_Cons' nth_Cons_Suc)
qed

(* Converse of valid_index_length and valid_index_lt: equal lengths plus a
   componentwise strict bound yield a valid index. *)
lemma valid_indexI:
assumes "length is = length ds" and "⋀m. m<length ds ⟹ is!m < ds!m"
shows "is ⊲ ds"
using assms proof (induction "is" arbitrary:ds)
case Nil
then show ?case by (metis length_0_conv valid_index.simps)
next
case (Cons a "is" ds)
(* ds is non-empty because it has the same length as a # is. *)
then obtain d ds' where "ds = d # ds'" by (metis length_Suc_conv)
then have "is ⊲ ds'" using Cons by (metis length_Cons less_irrefl linorder_neqE_nat not_less_eq nth_Cons_Suc)
then show ?case using Cons.prems(2) ‹ds = d # ds'› valid_index.Cons by fastforce
qed

(* Appending valid indices gives a valid index for the appended dimensions. *)
lemma valid_index_append:
assumes is1_valid:"is1 ⊲ ds1" and is2_valid:"is2 ⊲ ds2"
shows "is1 @ is2 ⊲ ds1 @ ds2"
apply (rule valid_indexI[of "is1 @ is2" "ds1 @ ds2"])
unfolding nth_append
using valid_index_lt[OF is2_valid] valid_index_lt[OF is1_valid] valid_index_length[OF is1_valid] valid_index_length[OF is2_valid] length_append
by (auto simp add: ‹length is1 = length ds1›)

(* Valid indices are exactly the lists that are pointwise strictly below ds. *)
lemma valid_index_list_all2_iff: "is ⊲ ds ⟷ list_all2 (<) is ds"
by (metis list_all2_conv_all_nth list_all2_nthD valid_indexI valid_index_length valid_index_lt)

(* The i-th consecutive chunk of length l of xs: the elements at positions
   l*i to l*i+l-1. The chunk is shorter, or empty, if xs ends earlier. *)
definition fixed_length_sublist::"'a list ⇒ nat ⇒ nat ⇒ 'a list" where
"fixed_length_sublist xs l i = (take l (drop (l*i) xs))"

(* lookup_base ds v is: the entry at index is of the flat entry list v for the
   dimension list ds. The first index component i selects the i-th chunk of
   length prod_list of the remaining dimensions, so the first index varies
   slowest. When all dimensions are consumed, the head of the remaining list
   is returned. Only index and dimension lists of equal length are covered
   by the equations. *)
fun lookup_base::"nat list ⇒ 'a list ⇒ nat list ⇒ 'a" where
lookup_base_Nil: "lookup_base [] v [] = hd v" |
lookup_base_Cons: "lookup_base (d # ds) v (i # is) =
lookup_base ds (fixed_length_sublist v (prod_list ds) i) is"

(* The entry of tensor A at index is. The lemmas below characterise it for
   valid indices is ⊲ dims A. *)
definition lookup::"'a tensor ⇒ nat list ⇒ 'a" where
"lookup A = lookup_base (dims A) (vec A)"

(* The flat entry list for dimensions ds whose entry at each index is given by
   e. Indices are enumerated with the first component varying slowest,
   matching the layout read by lookup_base. *)
fun tensor_vec_from_lookup::"nat list ⇒ (nat list ⇒ 'a) ⇒ 'a list" where
tensor_vec_from_lookup_Nil: "tensor_vec_from_lookup [] e = [e []]" |
tensor_vec_from_lookup_Cons: "tensor_vec_from_lookup (d # ds) e = concat (map (λi. tensor_vec_from_lookup ds (λis. e (i # is))) [0..<d])"

(* The tensor with dimensions ds whose entries are given by the function e. *)
definition tensor_from_lookup::"nat list ⇒ (nat list ⇒ 'a) ⇒ 'a tensor" where
"tensor_from_lookup ds e = tensor_from_vec ds (tensor_vec_from_lookup ds e)"

(* Concatenating the first a chunks of length d of v yields the prefix of
   length a*d, provided v has at least a*d elements. *)
lemma concat_parts_leq:
assumes "a * d ≤ length v"
shows "concat (map (fixed_length_sublist v d) [0..<a]) = take (a*d) v"
using assms proof (induction a)
case 0
then show ?case by simp
next
case (Suc a)
then have "concat (map (fixed_length_sublist v d) [0..<a]) = take (a * d) v" by auto
then have "concat (map (fixed_length_sublist v d) [0..<Suc a]) =
take (a * d) v @ fixed_length_sublist v d a" using fixed_length_sublist_def by auto
then show ?case using Suc by (metis add.commute mult.commute mult_Suc take_add fixed_length_sublist_def)
qed

(* If v consists of exactly a chunks of length d, concatenating those chunks
   gives back v. This is concat_parts_leq with take (length v) v = v. *)
lemma concat_parts_eq:
assumes "a * d = length v"
shows "concat (map (fixed_length_sublist v d) [0..<a]) = v"
by (simp add: concat_parts_leq assms)

(* If e agrees with lookup_base ds v on every valid index, and v has the
   length required by ds, then tensor_vec_from_lookup rebuilds exactly v. *)
lemma tensor_lookup_base:
assumes "length v = prod_list ds"
and "⋀is. is ⊲ ds ⟹ lookup_base ds v is = e is"
shows "tensor_vec_from_lookup ds e = v"
using assms proof (induction ds arbitrary:v e)
case Nil
then show ?case unfolding tensor_vec_from_lookup.simps
by (metis One_nat_def Tensor.lookup_base_Nil length_0_conv length_Suc_conv list.sel(1) prod_list.Nil valid_index.Nil)
next
case (Cons a ds)
then have "a * prod_list ds = length v" by auto
(* For each first index i < a, the sub-vector built from e (i # _) is the
   i-th chunk of v. *)
{
fix i assume "i<a"
then have "prod_list ds * (i+1) ≤ length v" using ‹a * prod_list ds = length v› using discrete mult.commute mult_le_mono1 by metis
have "⋀is'. is' ⊲ ds ⟹ e (i # is') = lookup_base ds (fixed_length_sublist v (prod_list ds) i) is'"
using ‹i<a› by (metis Cons.prems(2) Tensor.lookup_base_Cons valid_index.simps)
then have "tensor_vec_from_lookup ds (λis'. e (i # is')) = fixed_length_sublist v (prod_list ds) i"
using Cons using ‹prod_list ds * (i + 1) ≤ length v› by (simp add: Cons.IH fixed_length_sublist_def)
}
then show ?case unfolding tensor_vec_from_lookup_Cons lookup_base_Cons
using   concat_parts_eq[OF ‹a * prod_list ds = length v›]
atLeastLessThan_iff map_eq_conv set_upt Cons by (metis (no_types, lifting))
qed

(* Any function agreeing with lookup A on the valid indices of A rebuilds A. *)
lemma tensor_lookup:
assumes "⋀is. is ⊲ dims A ⟹ lookup A is = e is"
shows "tensor_from_lookup (dims A) e = A"
using tensor_lookup_base lookup_def length_vec tensor_from_lookup_def by (metis assms tensor_from_vec_simp)

(* Concatenating lists that all have length l gives length (number of lists) * l. *)
lemma concat_equal_length:
assumes "⋀xs. xs∈set xss ⟹ length xs = l"
shows "length (concat xss) = length xss*l"
using assms by (induction xss;auto)

(* Variant of concat_equal_length for a family f 0, ..., f (a-1) of lists of
   length d. *)
lemma concat_equal_length_map:
assumes "⋀i. i<a ⟹ length (f i) = d"
shows "length (concat (map (λi. f i) [0..<a])) = a*d"
using assms by (induction a;auto)

(* When all lists in xss have length d, the i-th chunk of length d of their
   concatenation is the i-th list. *)
lemma concat_parts:
assumes "⋀xs. xs∈set xss ⟹ length xs = d" and "i<length xss"
shows "fixed_length_sublist (concat xss) d i = xss ! i"
using assms proof (induction xss arbitrary:i)
case Nil
then show ?case by simp
next
case (Cons xs xss)
then have "length (concat xss) = length xss * d" by (simp add: Cons.prems(1) concat_equal_length)
show ?case
proof (cases i)
case 0
then have "fixed_length_sublist (concat (xs # xss)) d i = xs"
unfolding fixed_length_sublist_def by (simp add: Cons.prems(1))
then show ?thesis using 0 by auto
next
case (Suc i')
then have "fixed_length_sublist (concat xss) d i' = xss ! i'" using Cons by auto
then show ?thesis unfolding fixed_length_sublist_def using Suc Cons.prems(1) by auto
qed
qed

(* Variant of concat_parts for a family f 0, ..., f (a-1) of lists of length d:
   the i-th chunk of their concatenation is f i. *)
lemma concat_parts':
assumes "⋀i. i<a ⟹ length (f i) = d"
and "i<a"
shows "fixed_length_sublist (concat (map (λi. f i) [0..<a])) d i = f i"
using assms proof (induction a)
case 0
then show ?case by simp
next
case (Suc a)
then have "(⋀i. i < a ⟹ length (f i) = d)" by auto
then have "length (concat (map f [0..<a])) = a*d" using concat_equal_length_map by auto
show ?case
proof (cases "i=a")
assume "i=a"
(* The last chunk starts right after the first a*d elements. *)
then have "fixed_length_sublist (concat (map f [0..<Suc a])) d i = f a"
by (simp add: Suc.prems(1) ‹length (concat (map f [0..<a])) = a * d› fixed_length_sublist_def)
then show ?case using ‹i=a› by auto
next
assume "i≠a"
(* Earlier chunks lie entirely inside the prefix covered by the IH. *)
then have "fixed_length_sublist (concat (map f [0..<a])) d i = f i"
"concat (map f [0..<Suc a]) = concat (map f [0..<a]) @ f a" using Suc by auto
show ?case unfolding ‹concat (map f [0..<Suc a]) = concat (map f [0..<a]) @ f a›
unfolding fixed_length_sublist_def drop_append
using  ‹length (concat (map f [0..<a])) = a * d› ‹fixed_length_sublist (concat (map f [0..<a])) d i = f i›
using append_assoc append_eq_conv_conj append_take_drop_id assms(1) assms(2)  fixed_length_sublist_def
by metis
qed
qed

(* The generated entry list always has the length required by ds. *)
lemma length_tensor_vec_from_lookup:
"length (tensor_vec_from_lookup ds e) = prod_list ds"
by (induction ds arbitrary:e; auto simp add: concat_equal_length_map)

(* Reading back a generated entry list at a valid index returns the value of
   the generating function there. *)
lemma lookup_tensor_vec:
assumes "is⊲ds"
shows "lookup_base ds (tensor_vec_from_lookup ds e) is = e is"
using assms proof (induction arbitrary:e rule:valid_index.induct)
case Nil
then show ?case by simp
next
case (Cons "is" ds i d e)
then show ?case unfolding tensor_vec_from_lookup_Cons lookup_base_Cons
by (simp add: length_tensor_vec_from_lookup concat_parts'[of d "λi. tensor_vec_from_lookup ds (λis. e (i # is))" "prod_list ds" i] ‹i < d›)
qed

(* lookup inverts tensor_from_lookup on valid indices. *)
lemma lookup_tensor_from_lookup:
assumes "is⊲ds"
shows "lookup (tensor_from_lookup ds e) is = e is"
unfolding lookup_def tensor_from_lookup_def
by (simp add: lookup_tensor_vec assms length_tensor_vec_from_lookup)

(* tensor_from_lookup keeps the requested dimensions; the side condition of
   dims_tensor is discharged by length_tensor_vec_from_lookup. *)
lemma dims_tensor_from_lookup: "dims (tensor_from_lookup ds e) = ds"
unfolding tensor_from_lookup_def
by (simp add: length_tensor_vec_from_lookup)

(* Equal tensors built from lookup functions force the functions to agree on
   every valid index. *)
lemma tensor_lookup_cong:
assumes "tensor_from_lookup ds e1 = tensor_from_lookup ds e2"
and "is⊲ds"
shows "e1 is = e2 is" using assms lookup_tensor_from_lookup by metis

(* Lookup functions that agree on all valid indices build the same tensor;
   values at invalid indices are irrelevant. *)
lemma tensor_from_lookup_eqI:
assumes "⋀is. is⊲ds ⟹ e1 is = e2 is"
shows "tensor_from_lookup ds e1 = tensor_from_lookup ds e2"
by (metis assms lookup_tensor_vec length_tensor_vec_from_lookup tensor_lookup_base tensor_from_lookup_def)

(* Extensionality via lookup: tensors with equal dimensions that agree on every
   valid index are equal. *)
lemma tensor_lookup_eqI:
assumes "dims A = dims B" and "⋀is. is⊲(dims A) ⟹ lookup A is = lookup B is"
shows "A = B" by (metis assms(1) assms(2) tensor_lookup)

end