Coverage for src / mppy / matrix.py: 78%
556 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-13 09:54 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-13 09:54 +0200
1from abc import ABC, abstractmethod
2from types import NotImplementedType
3from typing import cast, overload
5import numba
6import numpy as np
8from .constants import e, eps
9from .scalar import Scalar
11"""
12================== Base Class Matrix ==================
13"""
class MpMatrix(ABC):
    """
    Common abstract base of every Max-Plus matrix implementation.

    Dense and Sparse hold the same mathematical object and only differ in how
    the entries are stored. Deriving both from this class lets functions be
    annotated with one shared type, e.g. ``def my_func(matrix: "MpMatrix")``,
    and accept either implementation transparently (or any future one).
    Every concrete subclass has to provide the members declared abstract here.
    """

    # ------------------------------------------------------------------
    # Required arithmetic
    # ------------------------------------------------------------------

    @abstractmethod
    def __add__(self, other: "MpMatrix") -> "MpMatrix":
        """Return ``self + other``."""

    @abstractmethod
    def __matmul__(self, other: "MpMatrix") -> "MpMatrix":
        """Return the matrix product ``self @ other``."""

    @abstractmethod
    def __rmatmul__(self, other: "MpMatrix") -> "MpMatrix":
        """Return the matrix product ``other @ self``."""

    @abstractmethod
    def __mul__(self, other: "Scalar | np.number") -> "MpMatrix":
        """Return ``self * other`` for a scalar ``other``."""

    @abstractmethod
    def __rmul__(self, other: "Scalar | np.number") -> "MpMatrix":
        """Return ``other * self`` for a scalar ``other``."""

    @abstractmethod
    def __pow__(self, n: int) -> "MpMatrix":
        """Return ``self`` raised to the integer power ``n``."""

    @abstractmethod
    def __eq__(self, other: "MpMatrix") -> bool:
        """Return True when both matrices hold the same entries."""

    @abstractmethod
    def transpose(self) -> "MpMatrix":
        """Return the transposed matrix."""

    @property
    @abstractmethod
    def T(self):
        """Shorthand property for ``transpose()``."""

    @property
    @abstractmethod
    def shape(self) -> tuple[int, int]:
        """``(rows, columns)`` of the matrix."""

    # ------------------------------------------------------------------
    # Required constructors
    # ------------------------------------------------------------------

    @staticmethod
    @abstractmethod
    def identity(n: int) -> "MpMatrix":
        """Build the ``n x n`` identity matrix."""

    @staticmethod
    @abstractmethod
    def zeros(n: int, m: int) -> "MpMatrix":
        """Build the ``n x m`` zero matrix."""
87"""
88================== Dense Matrix ==================
89"""
class Dense(MpMatrix):
    """
    Implementation of the MpMatrix ABC for dense matrices. A 2D float64 numpy
    array stores every value explicitly, including eps entries.
    """

    def __init__(self, other):
        """
        Wraps ``other`` as a dense matrix.
        :param other: anything ``np.asarray`` can turn into a 2D array
        :raises ValueError: if the data is not two-dimensional
        NOTE: ``np.asarray`` does not copy an input that already is a float64
        ndarray, so such an array is shared with the caller.
        """
        self.__value = np.asarray(other, dtype=np.float64)
        if self.__value.ndim != 2:
            raise ValueError("Dense Matrix is only defined for 2D")

    @classmethod
    def vector(cls, data):
        """
        Shortcut to create a dense column vector from 1D data, so callers do not
        have to write the nested 2D list by hand.
        :param data: 1D list or np.ndarray
        :return: n x 1 Dense matrix; values are copied, not shared with ``data``
        :raises ValueError: if ``data`` is not a list/np.ndarray or is not 1D
        """
        if not isinstance(data, (np.ndarray, list)):
            raise ValueError("Data needs to be a list or np.ndarray")
        val = data if isinstance(data, np.ndarray) else np.asarray(data)
        if val.ndim != 1:
            raise ValueError("Vectors are only defined for 1D")
        # reshape (instead of building [[k] for k in val]) also handles an empty
        # input, which previously collapsed to a 1D array and was rejected
        return Dense(np.array(val, dtype=np.float64).reshape(-1, 1))

    @classmethod
    def identity(cls, n) -> "Dense":
        """
        Creates the n x n identity matrix: e on the diagonal, eps elsewhere.
        :param n: the dimension of the matrix
        :return: the identity matrix
        """
        id_mat = np.identity(n)
        id_mat[id_mat == 0] = eps
        id_mat[id_mat == 1] = e
        return Dense(id_mat)

    @classmethod
    def zeros(cls, n, m) -> "Dense":
        """
        Creates the n x m zero matrix (every cell is eps).
        :param n: the rows
        :param m: the columns
        :return: the zero matrix
        """
        return Dense(np.full((n, m), eps))

    def raw(self) -> np.ndarray:
        """
        The underlying numpy array (not a copy).
        :return: the numpy ndarray
        """
        return self.__value

    def __add__(self, other) -> "Dense | NotImplementedType":
        """
        Adds two matrices (element-wise maximum).
        :param other: Dense or Sparse matrix of the same shape
        :return: always Dense, because Dense + Sparse results in a dense matrix
        :raises ValueError: on a shape mismatch
        """
        if isinstance(other, Dense):
            if other.shape != self.shape:
                raise ValueError(
                    f"Shape mismatch. Operand A has shape {self.shape} and operand B has shape {other.shape}"
                )
            return Dense(np.maximum(self.__value, other.__value))
        elif isinstance(other, Sparse):
            return dense_sparse_add(other, self)
        else:
            return NotImplemented

    def __pow__(self, n) -> "Dense":
        """
        Matrix power by repeated squaring.
        :param n: non-negative exponent
        :raises ValueError: for negative n or a non-square matrix
        """
        if n < 0:
            raise ValueError("Negative exponent not supported for matrices")
        if self.shape[0] != self.shape[1]:
            raise ValueError(
                "Shape mismatch. Pow is only implemented for square matrices"
            )
        return cast(
            Dense, pow_by_repeated_squaring(self, n, Dense.identity(self.shape[0]))
        )

    @property
    def T(self) -> "Dense":
        """
        Transposes the matrix
        :return: the transposed matrix
        """
        return self.transpose()

    def transpose(self) -> "Dense":
        """
        Transposes the matrix.
        :return: a new matrix; it owns its data, so writing into the result
            never changes ``self`` (a plain ``.T`` view would)
        """
        return Dense(self.__value.T.copy())

    def to_sparse(self) -> "Sparse":
        """
        Converts the matrix to a sparse matrix.
        :return: the Sparse (CSR) version of this matrix
        """
        # quoted annotation: Sparse is defined further down the module
        return Sparse(self.raw())

    def __eq__(self, other) -> bool:
        """
        Compares with another Dense matrix, a 2D np.ndarray, or a Sparse matrix
        (which is converted to dense first). Anything else compares unequal.
        :param other: the object to compare with
        :return: True if equal, False otherwise
        TODO: It might be slightly faster to implement a more specialized eq for dense x sparse.
        """
        if isinstance(other, Dense):
            return np.array_equal(self.__value, other.__value)
        elif isinstance(other, np.ndarray):
            # early exit, since it has to be 2D
            if other.ndim != 2:
                return False
            return np.array_equal(self.__value, other)
        elif isinstance(other, Sparse):
            return self == other.to_dense()
        return False

    def __getitem__(self, key):
        """
        Gets an item by key.
        :param key: a tuple[int, int] returns the exact value in that cell;
            an int returns a copy of the whole row (eps entries included)
        :return: either the row or the cell value
        :raises IndexError: if the key is out of bounds
        :raises ValueError: for any other key type
        """
        if isinstance(key, (int, np.integer)):
            if key < 0 or key >= self.shape[0]:
                raise IndexError
            # copy, so mutating the returned row cannot silently change the
            # matrix (same behaviour as Sparse.__getitem__)
            return self.__value[key].copy()
        elif isinstance(key, tuple):
            (row, col) = key
            if row < 0 or row >= self.shape[0] or col < 0 or col >= self.shape[1]:
                raise IndexError
            return self.__value[row, col]
        else:
            raise ValueError

    def __setitem__(self, key, value):
        """
        Sets a cell or a whole row.
        :param key: tuple[int, int] to set a cell, or an int to set a row
        :param value: a number when setting a cell; a list/np.ndarray with one
            entry per column when setting a row
        :raises IndexError: if the key is out of bounds
        :raises ValueError: on a wrong row length or an unsupported key/value combination
        """
        if isinstance(value, (np.ndarray, list)) and isinstance(key, (int, np.integer)):
            if key < 0 or key >= self.shape[0]:
                raise IndexError
            if len(value) != self.shape[1]:
                raise ValueError(
                    f"Shape missmatch: Needed {self.shape[1]} values, but got {len(value)}"
                )
            self.__value[key] = value
        elif isinstance(value, (int, float, np.number)) and isinstance(key, tuple):
            (row, col) = key
            if row < 0 or row >= self.shape[0] or col < 0 or col >= self.shape[1]:
                raise IndexError
            self.__value[row, col] = value
        else:
            raise ValueError

    def __mul__(self, other) -> "Dense":
        """
        Scalar multiplication (the scalar is added to every cell).
        :param other: either instance of Scalar class or a "raw" number (int, float, np.number)
        :return: the Dense matrix after multiplication, or NotImplemented
        """
        if not isinstance(other, Scalar) and not isinstance(
            other, (int, float, np.number)
        ):
            return NotImplemented
        val = other.val() if isinstance(other, Scalar) else np.float64(other)
        return Dense(val + self.__value)

    def __rmul__(self, other):
        """
        Scalar multiplication is commutative, so this delegates to __mul__.
        We flip it! (╯°□°)╯ ┻━┻
        """
        return self.__mul__(other)

    @overload
    def __matmul__(self, other: "Dense") -> "Dense":
        pass

    @overload
    def __matmul__(self, other: "Sparse") -> "Dense":
        pass

    def __matmul__(self, other) -> "Dense | NotImplementedType":
        """
        Matrix multiplication.
        :param other: Dense or Sparse matrix
        :return: always Dense; NotImplemented for non-matrix operands
        :raises ValueError: if the inner dimensions do not match
        """
        # type check first: a non-matrix has no .shape and must yield
        # NotImplemented (-> TypeError) instead of an AttributeError
        if not isinstance(other, (Dense, Sparse)):
            return NotImplemented
        if self.shape[1] != other.shape[0]:
            raise ValueError(
                "shape mismatch. Second dimension of A needs to match first dimension of B"
            )
        if isinstance(other, Dense):
            return dense_dense_mult(self, other)
        return dense_sparse_mult(self, other)

    @overload
    def __rmatmul__(self, other: "Dense") -> "Dense":
        pass

    @overload
    def __rmatmul__(self, other: "Sparse") -> "Dense":
        pass

    def __rmatmul__(self, other) -> "Dense | Sparse | NotImplementedType":
        """
        Reflected matrix multiplication: computes ``other @ self``.
        :param other: Dense or Sparse matrix
        :return: the product, or NotImplemented for non-matrix operands
        """
        if isinstance(other, (Sparse, Dense)):
            return other.__matmul__(self)
        else:
            return NotImplemented

    @property
    def shape(self) -> tuple[int, int]:
        """(rows, columns) of the matrix."""
        return self.__value.shape

    def __str__(self) -> str:
        """One line per row, each terminated by a newline."""
        result = ""
        for k in self.__value:
            result += f"{k}\n"
        return result
337"""
338================== Sparse Matrix ==================
339"""
class Sparse(MpMatrix):
    """
    Implementation of the MpMatrix ABC for sparse matrices. Entries are kept in
    CSR (compressed sparse row) format so that eps entries take no memory:
    ``__value`` holds the stored values, ``__colInd`` their column indices and
    ``__rowPtr[r]:__rowPtr[r + 1]`` is the slice of both that belongs to row r.
    Cells that are not stored read back as eps.

    The CSR arrays are treated as immutable: ``__mul__`` and ``from_csr`` may
    share them between instances, so every update builds fresh arrays.
    """

    def __init__(self, other):
        """
        Creates a CSR matrix from 2D data (array or nested list).
        Cells for which ``np.isfinite`` is False (eps, but also +inf and NaN)
        are not stored.
        :param other: 2D array-like
        :raises ValueError: if the data is not two-dimensional
        """
        val = other if isinstance(other, np.ndarray) else np.asarray(other)
        if val.ndim != 2:
            raise ValueError("Sparse Matrix is only defined for 2D")
        self.__original_shape = val.shape
        val = val.astype(np.float64, copy=False)
        # Vectorised instead of a per-cell Python loop. np.nonzero and boolean
        # masking both walk the array in row-major order, so the column indices
        # come out sorted within each row as CSR requires. The cumulative row
        # counts also give a correct row_ptr for 0-row and 0-column input.
        finite = np.isfinite(val)
        self.__colInd = np.nonzero(finite)[1].astype(np.int32)
        self.__rowPtr = np.concatenate(
            ([0], np.cumsum(finite.sum(axis=1)))
        ).astype(np.int32)
        self.__value = val[finite]

    @classmethod
    def vector(cls, data):
        """
        Shortcut to create a sparse column vector from 1D data, so callers do
        not have to write the nested 2D list by hand.
        :param data: 1D list or np.ndarray
        :return: n x 1 Sparse matrix
        :raises ValueError: if ``data`` is not a list/np.ndarray or is not 1D
        """
        if not isinstance(data, (np.ndarray, list)):
            raise ValueError("Data needs to be a list or np.ndarray")
        val = data if isinstance(data, np.ndarray) else np.asarray(data)
        if val.ndim != 1:
            raise ValueError("Vectors are only defined for 1D")
        # reshape also handles an empty input, which [[k] for k in val] turned
        # into a 1D array that was then rejected
        return Sparse(val.reshape(-1, 1))

    @classmethod
    def from_csr(cls, col_ind, row_ptr, value, original_shape) -> "Sparse":
        """
        Packs raw CSR arrays into a sparse matrix without validation.
        NOTE: ``np.asarray`` does not copy arrays that already have the target
        dtype, so such arrays are shared with the caller.
        :param col_ind: the col indices
        :param row_ptr: the row pointer (length rows + 1)
        :param value: the non eps values
        :param original_shape: the shape of the matrix
        :return: Sparse matrix
        """
        instance = cls.__new__(cls)
        instance.__colInd = np.asarray(col_ind, dtype=np.int32)
        instance.__rowPtr = np.asarray(row_ptr, dtype=np.int32)
        instance.__value = np.asarray(value, dtype=np.float64)
        instance.__original_shape = original_shape
        return instance

    @classmethod
    def identity(cls, n) -> "Sparse":
        """
        Creates the n x n identity matrix (e on the diagonal).
        :param n: the dimension of the matrix
        :return: Identity matrix in Sparse format
        """
        col_ind = list(range(n))
        value = [e] * n
        row_ptr = list(range(n + 1))
        return cls.from_csr(col_ind, row_ptr, value, (n, n))

    @classmethod
    def zeros(cls, n, m) -> "Sparse":
        """
        Creates the n x m zero matrix (nothing stored, every cell is eps).
        :param n: number of rows
        :param m: number of columns
        :return: Zeros matrix in Sparse format
        """
        return cls.from_csr([], [0] * (n + 1), [], (n, m))

    def __add__(self, other) -> "Sparse | Dense | NotImplementedType":
        """
        Adds two matrices (element-wise maximum).
        :param other: Sparse or Dense matrix of the same shape
        :return: Sparse + Sparse -> Sparse, Sparse + Dense -> Dense
        """
        if isinstance(other, Sparse):
            return sparse_sparse_add(self, other)
        elif isinstance(other, Dense):
            return dense_sparse_add(self, other)
        else:
            return NotImplemented

    def __pow__(self, n) -> "Sparse":
        """
        Matrix power by repeated squaring.
        :param n: non-negative exponent
        :raises ValueError: for negative n or a non-square matrix
        """
        if n < 0:
            raise ValueError("Negative exponent not supported for matrices")
        if self.shape[0] != self.shape[1]:
            raise ValueError(
                "Shape mismatch. Pow is only implemented for square matrices"
            )
        return cast(
            Sparse, pow_by_repeated_squaring(self, n, Sparse.identity(self.shape[0]))
        )

    def __eq__(self, other) -> bool:
        """
        Checks matrices for equality by comparing shape and CSR arrays.
        Dense and 2D np.ndarray operands are converted to Sparse first.
        :param other: Sparse, Dense or 2D np.ndarray
        :return: True if equal, False otherwise
        """
        if isinstance(other, Dense):
            return self == other.to_sparse()
        elif isinstance(other, np.ndarray):
            if other.ndim != 2:
                return False
            return self == Sparse(other)
        elif isinstance(other, Sparse):
            return (
                self.shape == other.shape
                and np.array_equal(self.__colInd, other.__colInd)
                and np.array_equal(self.__rowPtr, other.__rowPtr)
                and np.array_equal(self.__value, other.__value)
            )
        return False

    @property
    def T(self) -> "Sparse":
        """
        Transposes the sparse matrix
        :return: The transposed sparse matrix
        """
        return self.transpose()

    def transpose(self) -> "Sparse":
        """
        Transposes the sparse matrix.
        :return: Transposed sparse matrix (new arrays, sorted columns per row).
        """
        n_rows, n_cols = self.shape
        # original row index of every stored entry, in storage order
        entry_rows = np.repeat(np.arange(n_rows), np.diff(self.__rowPtr))
        # a stable sort by column keeps increasing row order inside each new
        # row, so the result is valid CSR with sorted column indices
        order = np.argsort(self.__colInd, kind="stable")
        per_new_row = np.bincount(self.__colInd, minlength=n_cols)
        new_row_ptr = np.concatenate(([0], np.cumsum(per_new_row)))
        return Sparse.from_csr(
            entry_rows[order], new_row_ptr, self.__value[order], (n_cols, n_rows)
        )

    def to_dense(self) -> "Dense":
        """
        Converts the sparse matrix into a dense matrix
        :return: The dense version of this matrix
        """
        (val, col_ind, row_ptr) = self.raw()
        (n_rows, n_cols) = self.shape
        result_array = sparse_to_ndarray(val, col_ind, row_ptr, n_rows, n_cols)
        return Dense(result_array)

    def __getitem__(self, key):
        """
        Gets an item by key.
        :param key: a tuple[int, int] returns the exact value in that cell;
            an int returns the whole row as a new dense array (unstored cells
            are -inf)
        :return: Either the row or the cell value
        :raises IndexError: if the key is out of bounds
        :raises ValueError: for any other key type
        """
        if isinstance(key, (int, np.integer)):
            if key < 0 or key >= self.__original_shape[0]:
                raise IndexError
            start = self.__rowPtr[key]
            end = self.__rowPtr[key + 1]
            result = np.full(self.__original_shape[1], -np.inf)
            result[self.__colInd[start:end]] = self.__value[start:end]
            return result
        elif isinstance(key, tuple):
            (row, col) = key
            if (
                row < 0
                or row >= self.__original_shape[0]
                or col < 0
                or col >= self.__original_shape[1]
            ):
                raise IndexError
            start = self.__rowPtr[row]
            end = self.__rowPtr[row + 1]
            cols = self.__colInd[start:end]
            idx = np.where(cols == col)[0]
            if len(idx) == 0:
                return eps
            else:
                return self.__value[start + idx[0]]
        else:
            raise ValueError

    def _replace_row(self, row, cols, vals) -> None:
        """
        Replaces the stored entries of one row with ``cols``/``vals``.
        Always builds new arrays and never writes into the existing ones,
        because those may be shared with other Sparse instances (``__mul__``
        hands its own col_ind/row_ptr to ``from_csr``, which does not copy).
        :param row: the row to replace
        :param cols: sorted column indices of the new entries
        :param vals: the new values, one per column index
        """
        start = self.__rowPtr[row]
        end = self.__rowPtr[row + 1]
        # fixed dtypes: an empty list would otherwise become float64 and turn
        # the whole column index array into floats
        cols = np.asarray(cols, dtype=np.int32)
        vals = np.asarray(vals, dtype=np.float64)
        self.__colInd = np.concatenate(
            (self.__colInd[:start], cols, self.__colInd[end:])
        )
        self.__value = np.concatenate(
            (self.__value[:start], vals, self.__value[end:])
        )
        new_row_ptr = self.__rowPtr.copy()
        new_row_ptr[row + 1 :] += len(cols) - (end - start)
        self.__rowPtr = new_row_ptr

    def __setitem__(self, key, value):
        """
        Sets a whole row (int key) or a single cell (tuple key).
        Non-finite values are not stored, i.e. they turn the cell(s) into eps.
        :param key: int row index or tuple[int, int] cell index
        :param value: list/np.ndarray with one entry per column for a row,
            a number for a cell
        :raises IndexError: if the key is out of bounds
        :raises ValueError: on a wrong value type or row length
        """
        if not isinstance(key, tuple):
            if key < 0 or key >= self.__original_shape[0]:
                raise IndexError
            if not isinstance(value, (np.ndarray, list)):
                raise ValueError(
                    "When setting a row, the value needs to be a list or np.ndarray"
                )
            if len(value) != self.__original_shape[1]:
                raise ValueError("New row needs to match column count of the matrix")
            row_values = np.asarray(value, dtype=np.float64)
            keep = np.isfinite(row_values)
            self._replace_row(key, np.nonzero(keep)[0], row_values[keep])
            return

        if not isinstance(value, (np.number, int, float)) and value != eps:
            raise ValueError("When setting a cell, the value needs to be a number")
        (row, col) = key
        if (
            row < 0
            or row >= self.__original_shape[0]
            or col < 0
            or col >= self.__original_shape[1]
        ):
            raise IndexError
        start = self.__rowPtr[row]
        end = self.__rowPtr[row + 1]
        cols = self.__colInd[start:end]
        vals = self.__value[start:end]
        hits = np.nonzero(cols == col)[0]
        if hits.size > 0:
            pos = hits[0]
            if np.isfinite(value):
                # copy first: ``vals`` is a view into the (possibly shared) array
                vals = vals.copy()
                vals[pos] = value
            else:
                cols = np.delete(cols, pos)
                vals = np.delete(vals, pos)
        else:
            if not np.isfinite(value):
                # cell already is eps, nothing to do
                return
            pos = np.count_nonzero(cols < col)
            cols = np.insert(cols, pos, col)
            vals = np.insert(vals, pos, value)
        self._replace_row(row, cols, vals)

    # __matmul__ itself can return different types depending on ``other``;
    # the overloads give each case one distinct return type for type checkers.

    @overload
    def __matmul__(self, other: "Sparse") -> "Sparse":
        pass

    @overload
    def __matmul__(self, other: "Dense") -> "Dense":
        pass

    def __matmul__(self, other) -> "Sparse | Dense | NotImplementedType":
        """
        Matrix multiplication.
        :param other: Sparse or Dense matrix
        :return: Sparse @ Sparse -> Sparse, Sparse @ Dense -> Dense,
            NotImplemented for non-matrix operands
        :raises ValueError: if the inner dimensions do not match
        """
        # type check first: a non-matrix has no .shape and must yield
        # NotImplemented (-> TypeError) instead of an AttributeError
        if not isinstance(other, (Sparse, Dense)):
            return NotImplemented
        if self.shape[1] != other.shape[0]:
            raise ValueError(
                "shape mismatch. Second dimension of A needs to match first dimension of B"
            )
        if isinstance(other, Sparse):
            return sparse_sparse_mult(self, other)
        return sparse_dense_mult(self, other)

    @overload
    def __rmatmul__(self, other: "Sparse") -> "Sparse":
        pass

    @overload
    def __rmatmul__(self, other: "Dense") -> "Dense":
        pass

    def __rmatmul__(self, other) -> "Sparse | Dense | NotImplementedType":
        """
        Reflected matrix multiplication: computes ``other @ self``.
        :param other: Sparse or Dense matrix
        :return: the product, or NotImplemented for non-matrix operands
        """
        if isinstance(other, (Sparse, Dense)):
            return other.__matmul__(self)
        else:
            return NotImplemented

    def __mul__(self, other) -> "Sparse":
        """
        Scalar multiplication (the scalar is added to every stored value).
        :param other: either instance of Scalar class or a "raw" number (int, float, np.number)
        :return: The Sparse Matrix after multiplication, or NotImplemented
        """
        if not isinstance(other, Scalar) and not isinstance(
            other, (int, float, np.number)
        ):
            return NotImplemented
        val = other.val() if isinstance(other, Scalar) else np.float64(other)
        if val == eps:
            (n, m) = self.__original_shape
            return Sparse.zeros(n, m)
        # col_ind/row_ptr are shared with self; safe because no method mutates
        # CSR arrays in place
        return Sparse.from_csr(
            self.__colInd, self.__rowPtr, val + self.__value, self.__original_shape
        )

    def __rmul__(self, other) -> "Sparse":
        """
        Calls self.__mul__(other)
        We flip it! (╯°□°)╯ ┻━┻
        """
        return self.__mul__(other)

    def raw(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        The raw CSR data (the internal arrays, not copies).
        :return: triple of csr data (values, col_ind, row_ptr)
        """
        return self.__value, self.__colInd, self.__rowPtr

    @property
    def shape(self) -> tuple[int, int]:
        """
        returns the shape
        :return: (rows, columns)
        """
        return self.__original_shape
735"""
736===================== Dense x Sparse Matrix Multiplication ==================
737"""
def dense_sparse_mult(self: Dense, other: Sparse) -> Dense:
    """
    Max-plus product of a Dense left operand and a Sparse right operand.

    The right operand is transposed first, so each CSR row of the transpose is
    one column of B; the numba kernel then combines every dense row with them.
    :param self: Matrix A (Dense)
    :param other: Matrix B (Sparse)
    :return: the product as a new Dense matrix
    """
    b_t_values, b_t_cols, b_t_ptr = other.T.raw()
    out_rows = self.shape[0]
    out_cols = other.shape[1]
    product = _dense_sparse_mult(
        self.raw(), b_t_values, b_t_cols, b_t_ptr, out_rows, out_cols
    )
    return Dense(product)
@numba.njit
def _dense_sparse_row_all_cols(
    row: np.ndarray, values: np.ndarray, col_ind: np.ndarray, row_ptr: np.ndarray
) -> np.ndarray:
    """
    Computes one full result row of a Dense x Sparse max-plus product.

    For every CSR row k of the given sparse data, the result entry k is the
    maximum of ``values[p] + row[col_ind[p]]`` over that row's stored entries,
    or eps when the CSR row is empty. dense_sparse_mult passes the CSR arrays of
    the transposed right operand, so CSR row k is column k of that operand.
    :param row: one row of the dense left operand
    :param values: CSR values
    :param col_ind: CSR column indices
    :param row_ptr: CSR row pointer
    :return: the computed row in dense format
    """
    n_out = len(row_ptr) - 1
    out = np.full(n_out, eps)
    for k in range(n_out):
        best = eps
        for p in range(row_ptr[k], row_ptr[k + 1]):
            candidate = values[p] + row[col_ind[p]]
            if candidate > best:
                best = candidate
        out[k] = best
    return out
@numba.njit(parallel=True)
def _dense_sparse_mult(
    self: np.ndarray,
    values: np.ndarray,
    col_ind: np.ndarray,
    row_ptr: np.ndarray,
    row_count: int,
    col_count: int,
) -> np.ndarray:
    """
    Numba kernel for Dense x Sparse that only takes numba-compatible arrays.

    The rows of the dense operand are processed in parallel (``prange``); each
    iteration writes a different result row.
    :param self: raw ndarray of the dense Matrix A
    :param values: CSR values of the (transposed) Matrix B
    :param col_ind: CSR col_ind of the (transposed) Matrix B
    :param row_ptr: CSR row_ptr of the (transposed) Matrix B
    :param row_count: shape[0] of Matrix A
    :param col_count: shape[1] of Matrix B
    :return: result of the multiplication as np.ndarray
    """
    out = np.full((row_count, col_count), eps)
    for i in numba.prange(row_count):
        out[i] = _dense_sparse_row_all_cols(self[i], values, col_ind, row_ptr)
    return out
811"""
812===================== Sparse x Dense Matrix Multiplication ==================
813"""
def sparse_dense_mult(self: Sparse, other: Dense) -> Dense:
    """
    Max-plus product of a Sparse left operand and a Dense right operand.
    :param self: Matrix A (Sparse)
    :param other: Matrix B (Dense)
    :return: the product as a new Dense matrix
    """
    a_values, a_cols, a_ptr = self.raw()
    product = _sparse_dense_mult(a_values, a_cols, a_ptr, other.raw())
    return Dense(product)
@numba.njit(parallel=True)
def _sparse_dense_mult(
    values: np.ndarray, col_ind: np.ndarray, row_ptr: np.ndarray, other: np.ndarray
) -> np.ndarray:
    """
    Numba kernel for Sparse x Dense (max-plus product).

    result[i, c] = max over the stored entries p of row i of A of
    ``values[p] + other[col_ind[p], c]``; rows of A without stored entries stay
    eps. The row kernel is written inline because nopython mode cannot call a
    plain (non-jitted) Python helper, which previously made every Sparse @ Dense
    product fail at compile time. The inline loop also handles empty rows of A,
    where a numpy max over zero rows would raise.
    :param values: CSR values of Matrix A
    :param col_ind: CSR col_ind of Matrix A
    :param row_ptr: CSR row_ptr of Matrix A
    :param other: raw ndarray of the dense Matrix B
    :return: result of the multiplication as np.ndarray
    """
    n_rows = len(row_ptr) - 1
    n_cols = other.shape[1]
    result = np.full((n_rows, n_cols), eps)
    # each parallel iteration writes only its own result row -> no races
    for i in numba.prange(n_rows):
        for p in range(row_ptr[i], row_ptr[i + 1]):
            v = values[p]
            src = other[col_ind[p]]
            for c in range(n_cols):
                candidate = v + src[c]
                if candidate > result[i, c]:
                    result[i, c] = candidate
    return result
835def _sparse_dense_row_all_cols(
836 values: np.ndarray, col_ind: np.ndarray, other: np.ndarray
837) -> np.ndarray:
838 return np.max(values[:, None] + other[col_ind, :], axis=0)
841"""
842================== Dense x Dense Matrix Multiplication ==================
843"""
def dense_dense_mult(self: Dense, other: Dense) -> Dense:
    """
    Performs Dense x Dense max-plus matrix multiplication:
    result[i, j] = max_k (a[i, k] + b[k, j]).

    Below 1_000_000 cells of the intermediate m x k x n array a single numpy
    broadcast is used. Above that, the m x n result is filled row by row,
    which only materialises a k x n intermediate at a time.
    :param self: A dense matrix (m x k)
    :param other: A dense matrix (k x n)
    :return: The result of the multiplication as a dense matrix (m x n)
    :raises ValueError: if the inner dimensions do not match
    """
    if self.shape[1] != other.shape[0]:
        raise ValueError(
            f"Shape missmatch. Operand A has {self.shape} and operand B has {other.shape}"
        )
    a = self.raw()
    b = other.raw()
    m, k = a.shape
    n = b.shape[1]

    # For very big dense matrices the broadcasting solution got slower as the
    # size increased, so large inputs fall back to filling the target matrix
    # row by row.
    # initial=-inf (the max-plus additive identity) makes an empty inner
    # dimension (k == 0) give an all-eps result instead of numpy raising
    # "zero-size array to reduction operation maximum".
    if m * k * n < 1_000_000:
        result = np.max(a[:, :, None] + b[None, :, :], axis=1, initial=-np.inf)
    else:
        result = np.full((m, n), -np.inf)
        for i in range(m):
            result[i] = np.max(a[i, :, None] + b, axis=0, initial=-np.inf)
    return Dense(result)
877"""
878===================== Sparse x Sparse Matrix Multiplication ==================
879"""
@numba.njit
def _sparse_sparse_mult_row_with_col(
    row_data: np.ndarray, row_col: np.ndarray, col_data: np.ndarray, col_row: np.ndarray
) -> np.float64:
    """
    Computes a single cell of a Sparse x Sparse max-plus product.

    Both index lists are sorted (CSR keeps column indices sorted), so they are
    merged with two cursors in O(len(row_col) + len(col_row)) instead of
    comparing every pair. Whenever both cursors point at the same index, the
    sum of the two values is a candidate for the maximum.
    :param row_data: values of one row of matrix A
    :param row_col: column indices of that row of A
    :param col_data: values of one row of the transposed matrix B (= one column of B)
    :param col_row: indices of those values (= row indices within B)
    :return: the max-plus cell value, eps when no index matches
    """
    best = eps
    i = 0
    j = 0
    len_a = len(row_col)
    len_b = len(col_row)
    while i < len_a and j < len_b:
        idx_a = row_col[i]
        idx_b = col_row[j]
        if idx_a < idx_b:
            # A is behind, let it catch up
            i += 1
        elif idx_a > idx_b:
            # B is behind, let it catch up
            j += 1
        else:
            candidate = row_data[i] + col_data[j]
            if candidate > best:
                best = candidate
            i += 1
            j += 1
    return best
@numba.njit(parallel=True)
def _sparse_sparse_mult_row_with_all_cols(
    row_data: np.ndarray,
    row_col: np.ndarray,
    values_b: np.ndarray,
    col_ind_b: np.ndarray,
    row_ptr_b: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Computes one full result row of a Sparse x Sparse max-plus product.

    Every CSR row r of the (transposed) matrix B is merged with the given row
    of A by _sparse_sparse_mult_row_with_col, in parallel over r. Only finite
    results are kept, so the output is already in compressed form.
    :param row_data: values of one row of matrix A
    :param row_col: column indices of that row of A
    :param values_b: CSR values of the transposed matrix B
    :param col_ind_b: CSR column indices of the transposed matrix B
    :param row_ptr_b: CSR row pointer of the transposed matrix B
    :return: (values, column indices) of the finite entries of the result row
    """
    n_targets = len(row_ptr_b) - 1
    data = np.full(n_targets, eps)
    cols = np.full(n_targets, -1)
    for r in numba.prange(n_targets):
        lo = row_ptr_b[r]
        hi = row_ptr_b[r + 1]
        cell = _sparse_sparse_mult_row_with_col(
            row_data, row_col, values_b[lo:hi], col_ind_b[lo:hi]
        )
        if np.isfinite(cell):
            data[r] = cell
            cols[r] = r
    return data[data != eps], cols[cols != -1]
def _sparse_sparse_mult(
    values_a, col_ind_a, row_ptr_a, values_b, col_ind_b, row_ptr_b
) -> tuple[np.ndarray, np.ndarray, list[int]]:
    """
    Sparse x Sparse product on raw CSR arrays (B must already be transposed).

    Runs in plain Python: the per-row results have different lengths and are
    collected in dynamic lists, which would cancel numba.jit's benefit. The
    per-row work itself is done by the numba kernel
    _sparse_sparse_mult_row_with_all_cols.
    :param values_a: values of matrix A
    :param col_ind_a: column indices of matrix A
    :param row_ptr_a: row pointer of matrix A
    :param values_b: values of the transposed matrix B
    :param col_ind_b: column indices of the transposed matrix B
    :param row_ptr_b: row pointer of the transposed matrix B
    :return: triple of new values, column indices and row pointer (list)
    """
    # Per-row chunks go into Python lists (amortised O(1) append, see
    # https://wiki.python.org/moin/TimeComplexity) and are concatenated once.
    value_chunks = []
    col_chunks = []
    row_ptr = [0]
    for start, end in zip(row_ptr_a[:-1], row_ptr_a[1:]):
        a_values = values_a[start:end]
        if len(a_values) == 0:
            # an empty row of A gives an empty row of the product
            row_ptr.append(row_ptr[-1])
            continue
        row_values, row_cols = _sparse_sparse_mult_row_with_all_cols(
            a_values, col_ind_a[start:end], values_b, col_ind_b, row_ptr_b
        )
        value_chunks.append(row_values)
        col_chunks.append(row_cols)
        row_ptr.append(row_ptr[-1] + len(row_values))

    new_values = np.concatenate(value_chunks) if value_chunks else np.asarray([])
    new_cols = np.concatenate(col_chunks) if col_chunks else np.asarray([])
    return new_values, new_cols, row_ptr
def sparse_sparse_mult(self: Sparse, other: Sparse) -> Sparse:
    """
    Sparse x Sparse max-plus product.

    Matrix B is transposed so its columns can be read as CSR rows; each result
    cell is then computed by merging the sorted index lists of one row of A and
    one column of B (see _sparse_sparse_mult). The raw result is packed into a
    new Sparse matrix.
    :param self: Matrix A
    :param other: Matrix B
    :return: Result Matrix
    :raises ValueError: if the inner dimensions do not match
    """
    if self.shape[1] != other.shape[0]:
        raise ValueError(
            f"Shape missmatch. Operand A has {self.shape} and operand B has {other.shape}"
        )
    lhs_csr = self.raw()
    rhs_t_csr = other.T.raw()
    new_values, new_cols, new_row_ptr = _sparse_sparse_mult(*lhs_csr, *rhs_t_csr)
    result_shape = (self.shape[0], other.shape[1])
    return Sparse.from_csr(new_cols, new_row_ptr, new_values, result_shape)
1041"""
1042================== Add Helper ==================
1043"""
def sparse_sparse_add(self: Sparse, other: Sparse) -> Sparse:
    """
    Max-plus addition (element-wise maximum) of two Sparse matrices.

    Row by row, the entries of both operands are merged in a column -> value
    dict, taking the maximum on shared columns, and written back in sorted
    column order.
    :param self: Matrix A
    :param other: Matrix B, same shape as A
    :return: the sum as a new Sparse matrix
    :raises ValueError: on a shape mismatch
    """
    if self.shape != other.shape:
        raise ValueError(
            f"Shape mismatch. Operand A has shape {self.shape} and operand B has shape {other.shape}"
        )
    a_vals, a_cols, a_ptr = self.raw()
    b_vals, b_cols, b_ptr = other.raw()
    out_vals = []
    out_cols = []
    out_ptr = [0]
    for r in range(len(a_ptr) - 1):
        a_lo, a_hi = a_ptr[r], a_ptr[r + 1]
        b_lo, b_hi = b_ptr[r], b_ptr[r + 1]
        # start with A's entries of this row ...
        merged: dict[int, float] = dict(zip(a_cols[a_lo:a_hi], a_vals[a_lo:a_hi]))
        # ... and fold in B's entries, keeping the maximum per column
        for c, v in zip(b_cols[b_lo:b_hi], b_vals[b_lo:b_hi]):
            merged[c] = max(merged.get(c, eps), v)
        for c in sorted(merged):
            out_cols.append(c)
            out_vals.append(merged[c])
        out_ptr.append(len(out_cols))
    return Sparse.from_csr(out_cols, out_ptr, out_vals, self.shape)
def dense_sparse_add(self: Sparse, other: Dense) -> Dense:
    """
    Max-plus addition (element-wise maximum) of a Sparse and a Dense matrix.

    Starts from a copy of the dense operand and raises every cell that the
    sparse operand stores to the larger of the two values.
    :param self: the Sparse operand
    :param other: the Dense operand, same shape
    :return: the sum as a new Dense matrix
    :raises ValueError: on a shape mismatch
    """
    if self.shape != other.shape:
        raise ValueError(
            f"Shape mismatch. Operand A has shape {self.shape} and operand B has shape {other.shape}"
        )
    sp_vals, sp_cols, sp_ptr = self.raw()
    out = other.raw().copy()
    for r, (lo, hi) in enumerate(zip(sp_ptr[:-1], sp_ptr[1:])):
        for c, v in zip(sp_cols[lo:hi], sp_vals[lo:hi]):
            out[r, c] = max(out[r, c], v)
    return Dense(out)
1093"""
1094================== Conversions ==================
1095"""
@numba.njit
def sparse_to_ndarray(val, col_ind, row_ptr, n_rows, n_cols) -> np.ndarray:
    """
    Expands raw CSR arrays into a dense 2D array; numba-friendly helper used by
    Sparse.to_dense, which wraps the result in a Dense matrix.
    :param val: the value list of csr
    :param col_ind: the col_ind list of csr
    :param row_ptr: the row_ptr list of csr
    :param n_rows: row count
    :param n_cols: column count
    :return: n_rows x n_cols np.ndarray, eps wherever nothing is stored
    """
    dense = np.full((n_rows, n_cols), eps)
    for r in range(n_rows):
        lo = row_ptr[r]
        hi = row_ptr[r + 1]
        for k in range(lo, hi):
            dense[r, col_ind[k]] = val[k]
    return dense
1116"""
1117================== Pow ==================
1118https://en.wikipedia.org/wiki/Exponentiation_by_squaring
1119"""
def pow_by_repeated_squaring(m: "MpMatrix", n: int, identity: "MpMatrix") -> "MpMatrix":
    """
    Computes m^n (repeated ``@`` products) by exponentiation by squaring.

    For max-plus matrices the exponent is a natural number, so unlike the
    textbook version there is no branch for n < 0; Dense.__pow__ and
    Sparse.__pow__ reject negative exponents before calling this.
    :param m: matrix to exponentiate; not modified
    :param n: the exponent
    :param identity: identity matrix of the same type and size as m; it is
        returned unchanged (same object) for n <= 0
    :return: the power m^n
    """
    result = identity
    base = m
    while n > 0:
        if n % 2 != 0:
            # plain ``x = x @ y`` instead of ``@=`` so the caller's matrices are
            # never updated in place, even if a subclass defines __imatmul__
            result = result @ base
            n -= 1
        n //= 2
        # square only if another factor is still needed; the old loop always
        # squared once more at the end and discarded that product
        if n > 0:
            base = base @ base
    return result