DPG demo version from 2019-08-20
[dpg] / tensor / tmult.m
1 function result = tmult(A,U,d)
2 %
3 % TMULT - tensor multiply (S = A x_i U) of A by U along dimension i
4 %
5 % usage: S = tmult( A, U, i )
6
7 %
8 % Copyright 2008  W. Scott Hoge  (wsh032580 at proton dot me)
9 %
10 % Licensed under the terms of the MIT License 
11 % (https://opensource.org/licenses/MIT)
12
13 sz = size(A);
14 tmp = U * local_unfold(A,d);
15
16 sz(d) = size(U,1);
17 N = length(sz);
18
19 result = permute( reshape( tmp, sz( [ d:-1:1 N:-1:(d+1) ] ) ), ...
20                   [ d:-1:1 N:-1:(d+1) ] );
21
22 %%%% the following was used with CFW's permute version of the unfolding
23 %
24 % result = permute(  ...
25 %     reshape( tmp, sz([ d:N 1:(d-1) ]) ), ...
26 %     [  (N-d+2):N 1:(N-d+1) ] );
27
28 function result = local_unfold(A,d)
29 %
30 % UNFOLD - reduces the dimension of a tensor by "unfolding" along one direction
31 %          (performs unfolding as described by deLathauwer in SIMAX 21(4):1253.
32 %
33 % Example: given a 3D tensor, A, 
34 %
35 % unfold(A,2) =>        I1
36 %                     +----+ +----+ +----+
37 %                  I2 |    | |    | |    |
38 %                     |    | |    | |    |
39 %                     +----+ +----+ +----+
40 %                        \ ____|____ /
41 %                              I3
42 %
43 % for a 3D tensor, 
44 %
45 % unfold(A,1) =  [ A(:,1,:)  A(:,2,:)  ... A(:,n2,:)  ],
46 % unfold(A,2) = ( A(:,:,1)' A(:,:,2)' ... A(:,:,n3)' ),
47 % unfold(A,3) = ( A(1,:,:)' A(2,:,:)' ... A(n1,:,:)' )
48 %
49 % To fold the matrix back (i.e. undo the unfold), one *must* know the
50 % size of the final matrix and use the same permutation as the unfolding.  
51 % An example:
52 %    sz = size(A);
53 %    A == permute( reshape( unfold(A,d) , sz([ d:-1:1 N:-1:(d+1) ]) ), ...
54 %                  [  d:-1:1 N:-1:(d+1) ] );
55 %
56
57
58 sz = size(A);
59 N = length(sz);
60 if (d > length(sz)), return; end;
61
62 result = reshape( permute( A, [ d:-1:1 N:-1:(d+1) ]), ...
63                   [ sz(d) prod(sz([1:(d-1) (d+1):N])) ] );
64
65 % unfold(A,1) = reshape( permute( A, [ 1 3 2 ]), [sz(1) prod(sz(3:2)) ] )
66 % unfold(A,2) = reshape( permute( A, [ 2 1 3 ]), [sz(2) prod(sz([1 3])) ] )
67 % unfold(A,3) = reshape( permute( A, [ 3 2 1 ]), [sz(3) prod(sz([1 2])) ] )
68