| import torch |
| import numpy as np |
| from functools import reduce, partial |
| from operator import mul |
| from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations |
|
|
|
|
def chain_functions(*functions):
    """Compose ``functions`` into one callable that applies them left to right.

    ``chain_functions(f, g, h)(v)`` evaluates ``h(g(f(v)))``. When no
    functions are given, the returned callable returns its argument unchanged.
    """

    def _apply_all(initial):
        value = initial
        for fn in functions:
            value = fn(value)
        return value

    return _apply_all
|
|
|
|
def remove_fx_parametrisation(fx):
    """Remove every parametrization from ``fx`` and all of its submodules.

    Each registered parametrization is removed through
    ``torch.nn.utils.parametrize.remove_parametrizations`` with its default
    arguments. The module is modified in place, and the same object is
    returned for convenience.
    """

    def _strip(module):
        if is_parametrized(module):
            # Snapshot the names first: each removal mutates the container
            # we would otherwise be iterating over.
            for name in tuple(module.parametrizations):
                remove_parametrizations(module, name)

    # ``apply`` visits every submodule as well as ``fx`` itself.
    fx.apply(_strip)
    return fx
|
|
|
|
def get_chunks(keys, original_shapes):
    """Compute how a compact flat vector is split across ``keys``.

    Every entry is assumed to occupy ``prod(shape)`` consecutive elements,
    except the first key containing ``"U.original"``. For that 2-D matrix
    ``U``, the elements on or below the main diagonal are dropped from the
    compact layout, so its chunk shrinks by that many elements.

    Args:
        keys: Sequence of state-dict keys, in flat-vector order.
        original_shapes: Shapes matching ``keys`` one-to-one. 0-d shapes
            (scalars) count as a single element.

    Returns:
        A tuple ``(selected_chunks, position, U_matrix_shape,
        dimensions_not_need)``:

        - ``selected_chunks``: list of chunk sizes in the compact layout.
        - ``position``: index of the ``"U.original"`` key.
        - ``U_matrix_shape``: ``original_shapes[position]``.
        - ``dimensions_not_need``: numpy int array of the dropped elements'
          indices, expressed in the full (uncompacted) concatenated vector.

    Raises:
        ValueError: if no key contains ``"U.original"``, or the matching
            shape is not 2-D.
    """
    position = next(
        (i for i, key in enumerate(keys) if "U.original" in key), None
    )
    if position is None:
        raise ValueError("no key containing 'U.original' found in keys")

    # Initial value 1 so a 0-d (scalar) shape counts as one element instead of
    # making reduce() fail on an empty iterable.
    original_chunks = [reduce(mul, shape, 1) for shape in original_shapes]

    U_matrix_shape = original_shapes[position]
    if len(U_matrix_shape) != 2:
        raise ValueError(
            f"expected a 2-D shape for key {keys[position]!r}, "
            f"got {tuple(U_matrix_shape)}"
        )
    n_rows, n_cols = U_matrix_shape

    # Flat indices (inside U) of the lower triangle, diagonal included...
    lower_in_U = np.ravel_multi_index(
        np.tril_indices(n=n_rows, m=n_cols), U_matrix_shape
    )
    # ...shifted by the sizes of all preceding entries, so they index the full
    # concatenated vector.
    dimensions_not_need = lower_in_U + sum(original_chunks[:position])

    selected_chunks = (
        original_chunks[:position]
        + [original_chunks[position] - dimensions_not_need.size]
        + original_chunks[position + 1 :]
    )
    return selected_chunks, position, U_matrix_shape, dimensions_not_need
|
|
|
|
def vec2statedict(
    x: torch.Tensor,
    keys,
    original_shapes,
    selected_chunks,
    position,
    U_matrix_shape,
):
    """Rebuild a ``{key: tensor}`` dict from the compact flat vector ``x``.

    ``x`` is split into consecutive chunks of sizes ``selected_chunks``.
    The chunk at ``position`` fills the strictly upper-triangular entries of
    a zero matrix of shape ``U_matrix_shape``. Every other chunk is used
    as-is. Each chunk is then reshaped to the matching entry of
    ``original_shapes``.

    The chunk-size arguments are presumably the values returned by
    ``get_chunks`` for the same ``keys``/``original_shapes``; verify at the
    call site.

    Args:
        x: Flat tensor whose length equals ``sum(selected_chunks)``.
        keys: Output keys, in the same order as the chunks.
        original_shapes: Target shape for each chunk. 0-d shapes yield
            scalar tensors.
        selected_chunks: Chunk sizes used to split ``x``.
        position: Index of the chunk that holds the packed ``U`` entries.
        U_matrix_shape: 2-D shape ``(rows, cols)`` of ``U``.

    Returns:
        A dict mapping each key to its reshaped tensor. The tensors are built
        from ``x`` with differentiable ops, so gradients flow back to ``x``.
    """
    chunks = list(torch.split(x, selected_chunks))

    # Flat indices (inside U) of the entries strictly above the main diagonal.
    upper_in_U = np.ravel_multi_index(
        np.triu_indices(n=U_matrix_shape[0], k=1, m=U_matrix_shape[1]),
        U_matrix_shape,
    )
    U = x.new_zeros(reduce(mul, U_matrix_shape, 1))
    U[upper_in_U] = chunks[position]
    chunks[position] = U

    # reshape(tuple(shape)) rather than reshape(*shape): unpacking a 0-d shape
    # would call reshape() with no arguments, which torch rejects.
    return {
        key: chunk.reshape(tuple(shape))
        for key, chunk, shape in zip(keys, chunks, original_shapes)
    }
|
|