Source code for organ.solver

"""Main module for training and testing of the OrGAN.
"""

import numpy as np
import os
import time
import datetime
from itertools import repeat

import torch
import torch.nn.functional as F

from organ.models import CPGenerator, CPDiscriminator, \
    EdgeAwareGenerator, Discriminator
from organ.structure.models import Organization
from organ.data.organization_structure_dataset \
    import OrganizationStructureDataset
from organ.utils import MetricsAggregator, all_scores


class Normalizer:
    """Max-scaling normalizer working on NumPy arrays and torch tensors.

    Values are divided by the maximum seen during `fit` (either one
    global maximum or one maximum per feature column) and multiplied
    back by `reverse_transform`.

    Parameters
    ----------
    device : torch.device or str
        Device on which the tensor copy of the scale is stored.
    per_feature : bool
        If True, the maximum is taken along axis 0 (one scale per
        feature); otherwise a single scale is used for all values.
    """

    def __init__(self, device, per_feature=False):
        self.per_feature = per_feature
        self.device = device

    def fit(self, x: np.ndarray):
        """Estimate the scale from `x`.

        Sets `self.m` (NumPy scale) and `self.mt` (the same scale as
        a tensor on `self.device`).
        """
        if self.per_feature:
            scale = np.max(x, axis=0)
            # An all-zero column would otherwise give 0 / 0 = NaN in
            # `transform`; a unit scale keeps such columns at zero.
            scale = np.where(scale == 0, 1, scale)
        else:
            scale = np.max(x)
            if scale == 0:
                scale = scale + 1
        self.m = scale
        self.mt = torch.tensor(self.m).to(self.device)

    def transform(self, x):
        """Scale `x` (NumPy array or tensor) into the normalized domain."""
        if isinstance(x, np.ndarray):
            return x / self.m
        return x / self.mt

    def reverse_transform(self, x):
        """Map normalized `x` (NumPy array or tensor) back to the
        original domain."""
        if isinstance(x, np.ndarray):
            return x * self.m
        return x * self.mt


class Solver(object):
    """Class for training and testing the OrGAN model."""

    def __init__(self, config):
        """Constructor.

        Parameters
        ----------
        config : namespace, argparse.Namespace
            An object with configuration parameter values.
        """
        # Training problem specification
        #
        # Conditional generation (cGAN)
        self.conditional = config.conditional
        # Parametric generation
        self.parametric = config.parametric
        # Organization structure model (describing how
        # the organization should be evaluated)
        self.org_model = config.rules
        # Quality metrics
        self.org_metrics = MetricsAggregator(self.org_model)
        # Dataset
        self.data = OrganizationStructureDataset(
            load_params=self.parametric,
            load_cond=self.conditional)
        self.data.load(config.data_dir)

        # Models configuration (generator, discriminator,
        # approximator)
        # Dimensions of the generator input
        self.z_dim = config.z_dim
        # The number of node types
        self.m_dim = self.data.node_num_types
        # The number of edge types
        self.b_dim = self.data.edge_num_types
        # Dimensions of the fully-connected layers group
        # at the beginning of the generator
        self.g_conv_dim = config.g_conv_dim
        # G's edge convolution specification
        self.g_edge_conv_dim = config.g_edge_conv_dim
        # Specification of G's fully connected layers for parameter values
        self.g_params_fc_dim = config.g_params_fc_dim
        # Specification of the transformations implemented by the
        # discriminator and the approximator. Consists of three
        # components: graph_conv_dim (a list describing the graph
        # convolutions, in particular the sizes of node
        # representations), aux_dim (the number of features in the
        # global graph representation) and linear_dim (a list with the
        # numbers of neurons in a series of fully-connected layers)
        self.d_conv_dim = config.d_conv_dim
        # Specification of a fully-connected block at the end of the
        # discriminator
        self.d_fc_dim = config.d_fc_dim
        # Condition encoder parameters of the D
        self.d_cond_enc_dim = config.d_cond_enc_dim
        # Weight of the gradient penalty in the loss function
        self.lambda_gp = config.lambda_gp
        # Postprocessing method for the generated graphs
        self.post_method = config.post_method
        # List of organization structure metrics used during
        # training ('all' stands for all of them)
        self.metric = 'all'

        # Training process configuration
        #
        # Batch size
        self.batch_size = config.batch_size
        # Number of iterations (batches) in the training process
        self.num_iters = config.num_iters
        # Number of iterations (before the last one, num_iters) during
        # which the learning rate is being decayed
        self.num_iters_decay = config.num_iters_decay
        # Learning rate of the generator
        self.g_lr = config.g_lr
        # Learning rate of the discriminator
        self.d_lr = config.d_lr
        # Dropout (the same value is used everywhere, between
        # each pair of layers)
        self.dropout = config.dropout
        # Generator training period (every n_critic batches)
        self.n_critic = config.n_critic
        # beta1 for Adam (used for all the models)
        self.beta1 = config.beta1
        # beta2 for Adam (used for all the models)
        self.beta2 = config.beta2
        # Iteration to resume the training from. If the value is not 0,
        # all the models are loaded from savepoints and the training
        # is continued.
        self.resume_iters = config.resume_iters

        # Testing process configuration
        #
        # Which model should be tested (the model created after
        # test_iters training iterations).
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Directories
        #
        # Log directory (used only with Tensorboard)
        self.log_dir = config.log_dir
        # Directory to save the models to (they are also loaded from
        # this directory when the training has to be resumed)
        self.model_save_dir = config.model_save_dir
        # Directory to write samples during training
        self.samples_dir = config.samples_dir

        # Output periodicity
        #
        # How often the log is written (for Tensorboard)
        self.log_step = config.log_step
        # How often the models are saved
        self.model_save_step = config.model_save_step
        # How often the learning rates are revised.
        # See also `num_iters_decay`.
        self.lr_update_step = config.lr_update_step

        # Should we pretrain?
        self.pretrain = config.pretrain

        # For the log to be informative, it should contain quality
        # characteristics of only generated structures
        assert self.log_step % self.n_critic == 0

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Build normalizers for float features of the dataset
        if self.parametric:
            self.node_features_normalizer = Normalizer(self.device)
            self.node_features_normalizer.fit(self.data.node_params)
        if self.conditional:
            self.cond_normalizer = Normalizer(self.device,
                                              per_feature=True)
            self.cond_normalizer.fit(self.data.cond)

    def build_model(self):
        """Create neural models (generator, discriminator, and
        approximator) together with their optimizers.
        """
        print('Max nodes:', self.data.vertexes)
        print('Node types:', self.data.node_num_types, self.m_dim)
        print('Edge types:', self.data.edge_num_types, self.b_dim)

        plain = not self.parametric and not self.conditional

        if plain:
            self.G = EdgeAwareGenerator(self.g_conv_dim,
                                        self.g_edge_conv_dim,
                                        self.z_dim,
                                        self.data.vertexes,
                                        self.data.edge_num_types,
                                        self.dropout)
        else:
            self.G = CPGenerator(self.g_conv_dim,        # Graph encoding
                                 self.g_edge_conv_dim,   # Edge convs
                                 self.g_params_fc_dim,   # Parameters FC
                                 self.z_dim,
                                 self.data.condition_dim,
                                 self.data.vertexes,
                                 self.data.edge_num_types,
                                 self.data.features_per_node,
                                 self.dropout)

        # NOTE: The architectures of the discriminator and the
        # approximator are completely identical.
        if plain:
            self.D = Discriminator(self.d_conv_dim,
                                   self.m_dim,
                                   self.b_dim,
                                   self.dropout)
            self.V = Discriminator(self.d_conv_dim,
                                   self.m_dim,
                                   self.b_dim,
                                   self.dropout)
        else:
            self.D = CPDiscriminator(self.d_conv_dim,
                                     self.d_fc_dim,        # FC at the end
                                     self.d_cond_enc_dim,  # Condition enc.
                                     self.m_dim,
                                     self.b_dim,
                                     self.data.condition_dim,
                                     self.data.features_per_node,
                                     self.dropout)
            self.V = CPDiscriminator(self.d_conv_dim,
                                     self.d_fc_dim,        # FC at the end
                                     self.d_cond_enc_dim,  # Condition enc.
                                     self.m_dim,
                                     self.b_dim,
                                     self.data.condition_dim,
                                     self.data.features_per_node,
                                     self.dropout)

        # Optimizer of the generator
        self.g_optimizer = torch.optim.Adam(self.G.parameters(),
                                            self.g_lr,
                                            [self.beta1, self.beta2])
        # Optimizer of the discriminator
        self.d_optimizer = torch.optim.Adam(self.D.parameters(),
                                            self.d_lr,
                                            [self.beta1, self.beta2])
        # Optimizer of the approximator
        self.v_optimizer = torch.optim.Adam(self.V.parameters(),
                                            self.g_lr,  # SIC!
                                            [self.beta1, self.beta2])

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
        self.print_network(self.V, 'V')

        self.G.to(self.device)
        self.D.to(self.device)
        self.V.to(self.device)

    def load_pretrained(self):
        """Load pretrained models.

        Loading is best-effort: a model whose `pre-<code>.ckpt` file
        is missing or incompatible keeps its current weights.
        """
        for model_code, model in [('G', self.G),
                                  ('D', self.D),
                                  ('V', self.V)]:
            # if there are pre-trained models and they are compatible
            path = os.path.join(self.model_save_dir,
                                f'pre-{model_code}.ckpt')
            try:
                model.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage))  # noqa: E501
                print(f'Pretrained {model_code} has been loaded.')
            except Exception:
                print(f'Can"t load pre-trained {model_code} model, starting from scratch.')  # noqa: E501

    def print_network(self, model, name):
        """Print model description.

        Parameters
        ----------
        model : torch.Module
            Model to print.
        name : str
            Model name (only for readability purposes).
        """
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Load models from a savepoint.

        Load the state of all models (generator, discriminator, and
        approximator) from a savepoint, located at `model_save_dir`.

        Parameters
        ----------
        resume_iters : int
            Iteration number, to specify a model savepoint.
        """
        print('Loading the trained models from step {}...'.format(resume_iters))  # noqa: E501
        for model_code, model in [('G', self.G),
                                  ('D', self.D),
                                  ('V', self.V)]:
            path = os.path.join(self.model_save_dir,
                                '{}-{}.ckpt'.format(resume_iters,
                                                    model_code))
            # Load to CPU storages first; load_state_dict copies the
            # values into the parameters wherever they live.
            model.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage))  # noqa: E501

    def build_tensorboard(self):
        """Tensorboard logging initialization."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """Sets learning rate constants (for all the models).

        Parameters
        ----------
        g_lr : float
            Learning rate for the generator (and approximator).
        d_lr : float
            Learning rate for the discriminator.
        """
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.v_optimizer.param_groups:
            param_group['lr'] = g_lr  # SIC!

    def reset_grad(self):
        """Reset gradients of all optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.v_optimizer.zero_grad()

    def gradient_penalty(self, y, x):
        """Gradient penalty: mean of (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def label2onehot(self, labels, dim):
        """Transform labels into one-hot encoded vectors.

        Given tensor with integer values `labels` is extended by one
        dimension, in which these labels are converted into one-hot
        codes.

        Parameters
        ----------
        labels : torch.tensor (int64)
            Tensor with non-negative integer labels.
        dim : int
            Number of categories in `labels` tensor. This number
            becomes the size of the new dimension of the output tensor.
            The specified number must be greater than the max value of
            `labels`.

        Returns
        -------
        torch.tensor (float)
            Real-valued tensor, consisting of zeros and ones.
        """
        out = torch.zeros(list(labels.size()) + [dim]).to(self.device)
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out

    def sample_z(self, batch_size):
        """Form samples from the input distribution of the generator."""
        return np.random.normal(0, 1, size=(batch_size, self.z_dim))

    def postprocess(self, inputs, method, temperature=1.):
        """Postprocessing by one of the differentiable discretization
        methods.

        The method is used to transform matrices, describing edges of
        a graph (without activations) to a representation, where an
        edge can have only one type (or be marked as absent). In other
        words, the representation is transformed into one, consisting
        of ones and zeroes (almost).

        Parameters
        ----------
        inputs : torch.tensor, tuple [torch.tensor], list [torch.tensor]
            Input tensors to transform.
        method : str
            Transformation type: `soft_gumbel`, `hard_gumbel`,
            `softmax`. Any other value falls back to `softmax`.
        temperature : float
            Transformation parameter.

        Returns
        -------
        list [torch.tensor]
            The list of output tensors, with transformation applied to
            the last dimension. Every output has the shape of the
            respective input (including batches of one sample).
            If `inputs` was one tensor, the result is still a list,
            though one-element.
        """
        if isinstance(inputs, (list, tuple)):
            all_logits = inputs
        else:
            all_logits = [inputs]

        if method in ('soft_gumbel', 'hard_gumbel'):
            hard = method == 'hard_gumbel'
            return [F.gumbel_softmax(e_logits.contiguous().
                                     view(-1, e_logits.size(-1)) / temperature,  # noqa: E501
                                     hard=hard).view(e_logits.size())
                    for e_logits in all_logits]

        # NOTE: the outputs are returned as they are. Indexing away
        # a leading dimension of length 1 (as was done before) drops
        # the batch dimension of single-sample batches.
        return [F.softmax(e_logits / temperature, -1)
                for e_logits in all_logits]

    def postprocess_nodes(self, nodes_logits):
        """Transforms a list of node logits into richer form.

        Most code assumes, that the set of graph nodes is described
        by tensor vertexes x node_types. However, in the case of
        organization structures it turns out that a node and a node
        type are mostly synonyms (there can be at most one node of a
        given type). Therefore, generator returns only logits of
        presence of certain types of nodes, and this method transforms
        these logits into batch of vertex x node_types tensors,
        placing the values on diagonal and complementing the
        probability of node absence.

        Parameters
        ----------
        nodes_logits : pytorch.tensor
            Batch of logits for node presence, batch x vertexes.

        Returns
        -------
        torch.tensor
            Batch of specifications batch x vertexes x nodes.
        """
        nodes_sigm = torch.sigmoid(nodes_logits)
        nodes_hat = torch.diag_embed(nodes_sigm)
        # Column 0 accumulates the probability of node absence
        nodes_hat[:, :, 0] += (1 - nodes_sigm)
        return nodes_hat

    def reward(self, orgs):
        """Structural reward.

        The method calculates a vector of structural reward values
        for the given batch of organization descriptions. The
        definition of structural reward can be project-specific
        (the list of metrics is defined in `self.metric`) and
        relies on various metrics defined in `org_model` passed
        to the constructor.

        Parameters
        ----------
        orgs : list
            A list of organization specifications.

        Returns
        -------
        numpy.ndarray, shape (batch_size, 1)
            Batch of reward values.
        """
        return self.org_metrics.valid_scores(orgs).reshape(-1, 1)

    def _gradient_penalty_loss(self, a_tensor, x_tensor,
                               edges_hat, nodes_hat):
        """Gradient penalty on random interpolations between real and
        generated graphs.

        NOTE: currently unused (see `_discriminator_step`); it does
        not pass node parameters and conditions to the discriminator.
        """
        eps = torch.rand(a_tensor.size(0), 1, 1, 1).to(self.device)
        x_int0 = (eps * a_tensor + (1. - eps) * edges_hat).requires_grad_(True)  # noqa: E501
        x_int1 = (eps.squeeze(-1) * x_tensor +
                  (1. - eps.squeeze(-1)) * nodes_hat).requires_grad_(True)
        grad0, grad1 = self.D(x_int0, x_int1, None)
        return self.gradient_penalty(grad0, x_int0) + \
            self.gradient_penalty(grad1, x_int1)

    def _discriminator_step(self, a_tensor, x_tensor, params, cond, z,
                            is_training):
        """Compute discriminator losses and (optionally) update D.

        Returns a dict with the `D/loss_real` and `D/loss_fake` values.
        """
        # The gradient penalty is switched off.
        # NOTE: It doesn't account for parametric gradient
        use_gradient_penalty = False

        # Compute loss with real structures.
        logits_real = self.D(a_tensor, x_tensor,
                             params,  # node features
                             cond)    # condition
        # minimize: -log(D(real)) - log(1-D(G(z)))
        d_loss_real = -torch.mean(torch.log(torch.sigmoid(logits_real)))

        # Compute loss with fake structures.
        edges_hat, nodes_hat, params_hat = self._invoke_G(z, cond)
        logits_fake = self.D(edges_hat, nodes_hat,
                             params_hat,  # node features
                             cond)        # condition
        d_loss_fake = -torch.mean(torch.log(1 - torch.sigmoid(logits_fake)))  # noqa: E501

        # Compute loss for gradient penalty.
        if use_gradient_penalty:
            d_loss_gp = self._gradient_penalty_loss(a_tensor, x_tensor,
                                                    edges_hat, nodes_hat)
        else:
            d_loss_gp = 0.0

        if is_training:
            # Backward and optimize.
            d_loss = d_loss_fake + d_loss_real + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

        return {'D/loss_real': d_loss_real.item(),
                'D/loss_fake': d_loss_fake.item()}

    def _approximator_step(self, a_tensor, x_tensor, params, cond, orgs, z,
                           is_training):
        """Compute the approximator loss and (optionally) update V.

        Returns the organizations decoded from the generated batch
        and a dict with the `V/loss` value.
        """
        # Get a batch from the generator
        edges_hat, nodes_hat, params_hat = self._invoke_G(z, cond)

        # Score the real samples with the "black box" (real reward)
        reward_real = torch.from_numpy(self.reward(orgs)).to(self.device)
        # Score the generated batch with the "black box" (fake reward)
        fake_orgs = self._orgs_from(edges_hat, nodes_hat, params_hat, cond)
        reward_fake = torch.from_numpy(
            self.reward(fake_orgs)).to(self.device)

        # Value loss: the approximator has to reproduce the "validity"
        # scores of both real and generated samples
        value_proba_real = self.V(a_tensor, x_tensor, params, cond,
                                  torch.sigmoid)
        value_proba_fake = self.V(edges_hat, nodes_hat, params_hat, cond,
                                  torch.sigmoid)
        v_loss = torch.mean((value_proba_real - reward_real) ** 2 +
                            (value_proba_fake - reward_fake) ** 2)

        if is_training:
            self.reset_grad()
            v_loss.backward()
            self.v_optimizer.step()

        return fake_orgs, {'V/loss': v_loss.item()}

    def _generator_step(self, cond, z, is_training):
        """Compute generator losses and (optionally) update G.

        Returns a dict with the `G/loss_fake`, `G/loss_value` and
        `G/loss_soft` values.
        """
        # Get a batch from the generator
        edges_hat, nodes_hat, params_hat = self._invoke_G(z, cond)

        # Evaluate the plausibility from the discriminator's viewpoint
        logits_fake = self.D(edges_hat, nodes_hat,
                             params_hat,  # node features
                             cond)        # condition
        # minimize: - log(D(G(z))) (mimic real)
        g_loss_fake = -torch.mean(torch.log(torch.sigmoid(logits_fake)))

        # Evaluate the requirements described by the approximator
        value_proba_fake = self.V(edges_hat, nodes_hat,
                                  params_hat,  # node features
                                  cond,        # condition
                                  torch.sigmoid)
        # We want the generated samples to meet the criteria
        # approximated by V, i.e. V should output 1.0 for them
        g_loss_value = -torch.mean(torch.log(value_proba_fake))

        # Other differentiable characteristics of the generated
        # structure may be accounted for here
        if hasattr(self.org_model, 'soft_constraints'):
            # User-level function has to deal with non-normalized
            # values
            if params_hat is not None:
                params_hat_ = self.node_features_normalizer.\
                    reverse_transform(params_hat)
            else:
                params_hat_ = None
            if cond is not None:
                cond_ = self.cond_normalizer.reverse_transform(cond)
            else:
                cond_ = None
            g_loss_soft_constraints = self.org_model.soft_constraints(
                nodes_hat, edges_hat, params_hat_, cond_)
        else:
            g_loss_soft_constraints = torch.tensor(0.0,
                                                   device=self.device)

        # The generator loss is the sum of the implausibility loss
        # (g_loss_fake), the loss for violating the constraints of V
        # (g_loss_value) and the other losses
        g_loss = g_loss_fake + g_loss_value + g_loss_soft_constraints
        if is_training:
            # Backward and optimize.
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

        return {'G/loss_fake': g_loss_fake.item(),
                'G/loss_value': g_loss_value.item(),
                'G/loss_soft': g_loss_soft_constraints.item()}

    def _process_batch(self, a_tensor, x_tensor, params, cond, orgs, z,
                       critic=True, is_training=True):
        """Process one batch.

        Parameters
        ----------
        a_tensor, x_tensor, params, cond, orgs, z
            Batch components as returned by `_next_batch`.
        critic : bool
            Whether the approximator and the generator are processed
            in addition to the discriminator.
        is_training : bool
            If True, the models are put into training mode and their
            weights are updated; otherwise losses are only evaluated.

        Returns
        -------
        tuple
            Organizations (decoded from the generated batch if
            `critic` is True, the given `orgs` otherwise) and a dict
            of loss values.
        """
        # Set the networks into the required mode
        self.G.train(is_training)
        self.D.train(is_training)
        self.V.train(is_training)

        # Gradient tracking is switched only for the duration of this
        # call, so that an evaluation pass doesn't leave autograd
        # globally disabled.
        with torch.set_grad_enabled(is_training):
            # 1. Train the discriminator
            loss = self._discriminator_step(a_tensor, x_tensor, params,
                                            cond, z, is_training)
            # 2. Train the generator and approximator
            if critic:
                # 2.1 Train the approximator
                orgs, v_loss = self._approximator_step(
                    a_tensor, x_tensor, params, cond, orgs, z,
                    is_training)
                loss.update(v_loss)
                # 2.2 Train the generator
                loss.update(self._generator_step(cond, z, is_training))

        return orgs, loss

    def train(self):
        """Training cycle.

        Runs iterations from `resume_iters` (or 0) up to `num_iters`,
        periodically validating and logging, saving checkpoints,
        writing samples and decaying learning rates.
        """
        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)
        elif self.pretrain:
            print('Start pre-training...')
            self._pretrain()

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            step = i + 1
            critic = step % self.n_critic == 0

            # Get the next batch, prepare it and load to the device
            x_tensor, a_tensor, params, cond, orgs, z = self._next_batch('train')  # noqa: E501

            # Process the training batch, update the weights
            orgs, loss = self._process_batch(a_tensor, x_tensor,
                                             params, cond,
                                             orgs, z,
                                             critic=critic,
                                             is_training=True)

            # Validate and print the current quality of the models
            if step % self.log_step == 0:
                # Get a validation batch
                x_tensor, a_tensor, params, cond, orgs, z = self._next_batch('validation')  # noqa: E501

                # Process the validation batch (no weight updates)
                orgs, loss = self._process_batch(a_tensor, x_tensor,
                                                 params, cond,
                                                 orgs, z,
                                                 critic=critic,
                                                 is_training=False)

                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, step, self.num_iters)  # noqa: E501

                # Log update
                # 'orgs' is output of Fake Reward
                m0, m1 = all_scores(self.org_metrics, orgs, self.data,
                                    norm=True)
                m0 = {k: np.array(v)[np.nonzero(v)].mean()
                      for k, v in m0.items()}
                m0.update(m1)
                loss.update(m0)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, step)

            # Save model checkpoints.
            if step % self.model_save_step == 0:
                for model_code, model in [('G', self.G),
                                          ('D', self.D),
                                          ('V', self.V)]:
                    path = os.path.join(self.model_save_dir,
                                        '{}-{}.ckpt'.format(step,
                                                            model_code))
                    torch.save(model.state_dict(), path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))  # noqa: E501

            # Write samples of the most recent batch
            if step % 10000 == 0:
                if self.samples_dir is not None:
                    self._write_samples(
                        os.path.join(self.samples_dir,
                                     f'samples-{step}.txt'),
                        orgs[:self.batch_size])

            # Decay learning rates.
            if step % self.lr_update_step == 0 and \
                    step > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))  # noqa: E501

    def test(self):
        """Model testing.

        Restores the models saved at `test_iters`, generates one
        structure per item of the test batch and prints the metrics
        of the generated structures.
        """
        # Load the trained generator.
        self.restore_model(self.test_iters)
        self.G.eval()
        self.D.eval()
        self.V.eval()

        with torch.no_grad():
            # Actually, real samples are not needed for testing,
            # because we just want to evaluate the plausibility of
            # the generated structures
            #
            # Note, that testing code loads all models at once,
            # potentially it may result in memory problems.
            n, _, __, cond = self.data.next_test_batch()
            # The condition normalizer exists only for conditional
            # generation (same guard as in `_next_batch`)
            if self.conditional:
                cond = self.cond_normalizer.transform(cond)
                cond = torch.from_numpy(cond).to(self.device).float()

            z = self.sample_z(n.shape[0])
            z = torch.from_numpy(z).to(self.device).float()

            # Z-to-target
            edges_hat, nodes_hat, params_hat = self._invoke_G(z, cond)
            orgs = self._orgs_from(edges_hat, nodes_hat, params_hat, cond)

            # Log update
            m0, m1 = all_scores(self.org_metrics, orgs, self.data,
                                norm=True)
            m0 = {k: np.array(v)[np.nonzero(v)].mean()
                  for k, v in m0.items()}
            m0.update(m1)

            log = 'Testing on {} structures: '.format(n.shape[0])
            if m0:
                log += ', '.join(["{}: {:.4f}".format(tag, value)
                                  for tag, value in m0.items()])
            print(log)

    def generate(self, batch_size: int = 1, ctx=None):
        """Generate a batch of samples.

        The models are restored from the `test_iters` savepoint on
        every call.

        Parameters
        ----------
        batch_size : int
            Number of samples to generate.
        ctx
            Context for the samples to be generated. May be optional.
            A one-dimensional context (or a two-dimensional one with
            a single row) is repeated for every sample of the batch.

        Returns
        -------
        list
            Generated organizations.

        Raises
        ------
        ValueError
            If a two-dimensional `ctx` has neither 1 nor `batch_size`
            rows.
        """
        if ctx is not None:
            if not isinstance(ctx, np.ndarray):
                ctx = np.array(ctx)
            if ctx.ndim == 1:
                ctx = np.stack([ctx] * batch_size, axis=0)
            elif ctx.ndim == 2:
                if ctx.shape[0] == 1:
                    ctx = np.concatenate([ctx] * batch_size, axis=0)
                elif ctx.shape[0] != batch_size:
                    raise ValueError('For two-dimensional ctx, the first '
                                     'dimension must be 1 or match the '
                                     'batch size')
                # otherwise ctx is fine as it is
            ctx = self.cond_normalizer.transform(ctx)
            ctx = torch.from_numpy(ctx).to(self.device).float()

        # Load the trained generator.
        self.restore_model(self.test_iters)
        self.G.eval()
        self.D.eval()
        self.V.eval()

        with torch.no_grad():
            # Sample noise
            z = self.sample_z(batch_size)
            # Make tensor and pass to the device
            z = torch.from_numpy(z).to(self.device).float()
            # Z-to-target
            edges_hat, nodes_hat, params_hat = self._invoke_G(z, ctx)
            # Convert to organizations
            return self._orgs_from(edges_hat, nodes_hat, params_hat, ctx)

    def generate_valid(self, n: int, ctx=None, max_generate: int = 1000):
        """Generate valid organizations.

        Parameters
        ----------
        n : int
            The number of valid organizations to generate.
        ctx : np.ndarray
            Condition (context) features, (n_features, ).
        max_generate : int
            Maximal number of instances to generate. If the underlying
            model accuracy is low, it may take too much time to
            generate the required number of valid organizations. This
            parameter helps to control the process and stop generation
            even if the required count isn't achieved.

        Returns
        -------
        list
            The list of organizations, containing not more than `n`
            instances of `Organization` class.

        Raises
        ------
        ValueError
            If `ctx` is neither (n_features, ) nor (1, n_features).
        """
        if ctx is not None:
            if not isinstance(ctx, np.ndarray):
                ctx = np.array(ctx)
            single_context = ctx.ndim == 1 or \
                (ctx.ndim == 2 and ctx.shape[0] == 1)
            if not single_context:
                raise ValueError('ctx must be (n_features, ) or (1, n_features)')  # noqa: E501

        valid_orgs = []
        batch_size = 32
        n_generated = 0
        while len(valid_orgs) < n and n_generated < max_generate:
            candidates = self.generate(batch_size, ctx=ctx)
            valid_orgs.extend(org for org in candidates
                              if self.org_model.validness(org))
            n_generated += batch_size

        return valid_orgs[:n]

    def _pretrain(self):
        """Pretrain models.

        The approximator and the discriminator are shown one real
        batch (as good samples) against the output of the untrained
        generator (as bad samples); then the generator is fitted to
        map a fixed noise batch onto that real batch.
        """
        BATCH_SIZE = min(50, self.data.train_count)

        def pretrain_validator(model,
                               target_nodes,
                               target_edges,
                               target_params,
                               cond,
                               max_iters=1000):
            loss_fn = torch.nn.BCELoss()
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
            zeros = torch.zeros((BATCH_SIZE, 1),
                                dtype=torch.float32,
                                device=self.device)
            ones = torch.ones((BATCH_SIZE, 1),
                              dtype=torch.float32,
                              device=self.device)
            for _ in range(max_iters):
                # A batch of "garbage" from the generator
                optimizer.zero_grad()
                input_z = self.sample_z(BATCH_SIZE)
                z = torch.from_numpy(input_z).to(self.device).float()
                edges_hat, nodes_hat, params_hat = self._invoke_G(z, cond)
                value_bad = model(edges_hat, nodes_hat, params_hat, cond,
                                  torch.sigmoid)
                loss_bad = loss_fn(value_bad, zeros)
                loss_bad.backward()
                optimizer.step()

                # Good samples
                optimizer.zero_grad()
                value_good = model(target_edges, target_nodes,
                                   target_params, cond,
                                   torch.sigmoid)
                loss_good = loss_fn(value_good, ones)
                loss_good.backward()
                optimizer.step()

        def pretrain_generator(z,
                               target_nodes,
                               target_edges,
                               target_params,
                               cond,
                               max_iters=1000,
                               loss_eps=0.01):
            # NOTE: target_params is not part of the loss
            loss_fn = torch.nn.BCELoss()
            optimizer = torch.optim.Adam(self.G.parameters(), lr=0.001)
            # Defined up front so that the report below also works
            # when no iteration is performed
            loss_value = float('nan')
            for _ in range(max_iters):
                optimizer.zero_grad()
                edges_hat, nodes_hat, params_hat = self._invoke_G(z, cond)
                loss = loss_fn(nodes_hat, target_nodes) + \
                    loss_fn(edges_hat, target_edges)
                loss.backward()
                optimizer.step()
                loss_value = loss.detach().cpu().item()
                if loss_value < loss_eps:
                    break
            print(f'Generator loss@pretrain: {loss_value:.5f}')

        # Get a batch of real examples
        x_tensor, a_tensor, params, cond, _, z = self._next_batch('train', BATCH_SIZE)  # noqa: E501

        # Use these examples to give the validator and discriminator
        # ideas of what is good and evil
        pretrain_validator(self.V, x_tensor, a_tensor, params, cond,
                           max_iters=10)
        pretrain_validator(self.D, x_tensor, a_tensor, params, cond,
                           max_iters=10)

        # Use the noise to pretrain the generator.
        # It tries to map each point to the respective
        # sample
        pretrain_generator(z, x_tensor, a_tensor, params, cond,
                           max_iters=1000, loss_eps=0.01)

    def _invoke_G(self, z, cond):
        """Generate a batch of graphs.

        Returns postprocessed edges, postprocessed nodes and the node
        parameters exactly as produced by the generator.
        """
        edges_logits, nodes_logits, node_params = self.G(cond, z)
        # Postprocess with Gumbel softmax
        edges_hat = self.postprocess((edges_logits, ), self.post_method)[0]
        nodes_hat = self.postprocess_nodes(nodes_logits)
        return edges_hat, nodes_hat, node_params

    def _next_batch(self, mode: str, batch_size=None):
        """Retrieve next batch and load it to the device.

        Parameters
        ----------
        mode: str
            Specification of what set to use: 'train' or 'validation'.
        batch_size : int, optional
            Size of a training batch (`self.batch_size` by default).
            Not used for the 'validation' mode.

        Returns
        -------
        tuple
            One-hot encoded node types, one-hot encoded edge types,
            normalized node parameters (or None), normalized condition
            (or None), a list of structures corresponding to the
            batch, and z-noise to use as an input for the generator.

        Raises
        ------
        ValueError
            If `mode` is neither 'train' nor 'validation'.
        """
        if batch_size is None:
            batch_size = self.batch_size

        if mode == 'train':
            x, a, p, c = self.data.next_train_batch(batch_size)
        elif mode == 'validation':
            x, a, p, c = self.data.next_validation_batch()
        else:
            raise ValueError(f"Unknown mode: '{mode}'. "
                             "Only 'train' and 'validation' supported")
        # The noise batch has the same size as the data batch
        z = self.sample_z(x.shape[0])

        orgs = [Organization(x_, a_, node_features=p_, condition=c_)
                for x_, a_, p_, c_ in
                zip(x,
                    a,
                    p if p is not None else repeat(None),
                    c if c is not None else repeat(None))]

        # a is a (self.batch_size, 12, 12) numpy array - adjacency matrices (a_ij is the number of connections) # noqa: E501
        # x is a (self.batch_size, 12) numpy array - node type (categorical, 0 for no-node) # noqa: E501

        # Load the data to the device and convert it into the form
        # expected by the neural networks
        a = torch.from_numpy(a).to(self.device).long()  # Adjacency.
        x = torch.from_numpy(x).to(self.device).long()  # Nodes.
        a_tensor = self.label2onehot(a, self.b_dim)
        x_tensor = self.label2onehot(x, self.m_dim)
        z = torch.from_numpy(z).to(self.device).float()

        # If it is parametric generation, we have to normalize
        # the parameters. Otherwise, it will be None
        if self.parametric:
            p = self.node_features_normalizer.transform(p)
            p = torch.from_numpy(p).to(self.device).float()

        # If it is conditional generation, then the condition
        # must also be normalized
        if self.conditional:
            c = self.cond_normalizer.transform(c)
            c = torch.from_numpy(c).to(self.device).float()

        return x_tensor, a_tensor, p, c, orgs, z

    def _orgs_from(self, edges_hat, nodes_hat, params_hat, cond):
        """Build organization structures from the generator output.

        Edge and node types are chosen by argmax over the last
        dimension; parameters and conditions (if any) are mapped back
        to the original domain.
        """
        edges_hard = torch.max(edges_hat, -1)[1]
        nodes_hard = torch.max(nodes_hat, -1)[1]
        orgs = [self.data.matrices2graph(n_.detach().cpu().numpy(),
                                         e_.detach().cpu().numpy(),
                                         strict=True)
                for e_, n_ in zip(edges_hard, nodes_hard)]

        if params_hat is not None:
            # Back to the original domain
            params = params_hat.detach().cpu().numpy()
            params = self.node_features_normalizer.reverse_transform(params)
            # Post-process parameters
            # TODO move this code somewhere else
            # Get rid of the negative, remove for non-existing nodes
            params = np.maximum(params, np.zeros_like(params))
            params = params * (np.expand_dims(nodes_hard.detach()
                                              .cpu().numpy(), -1) > 0)
        else:
            params = repeat(None)

        if cond is not None:
            cond = cond.detach().cpu().numpy()
            cond = self.cond_normalizer.reverse_transform(cond)
        else:
            cond = repeat(None)

        return [Organization(x_, a_, node_features=p_, condition=c_)
                for (x_, a_), p_, c_ in zip(orgs, params, cond)]

    def _write_samples(self, filename: str, orgs, log: str = None) -> None:
        """Writes samples to a given file.

        Each sample is followed by the result of the parameter
        feasibility check of the organization model; `log`, if given,
        is appended after all the samples.
        """
        with open(filename, 'w') as f:
            for i, org in enumerate(orgs):
                print(f'Sample #{i}:',
                      '\nContext\n', org.condition,
                      '\nNodes:\n', org.nodes,
                      '\nStaff:\n', org.node_features,
                      '\nEdges:\n', org.edges,
                      '\nCheck results:\n',
                      self.org_model.check_paramater_feasibility(
                          org.nodes,
                          org.node_features,
                          ctx=org.condition),
                      file=f)
                print('=======', file=f)
            if log is not None:
                print(log, file=f)