In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from scipy.stats import poisson, ncx2
In [2]:
class Utils:
    # Convert a string parameter into a torch activation function
    @staticmethod
    def get_activation_function(activation_function_str, negative_slope):
        if activation_function_str == "Lrelu":
            activation = torch.nn.LeakyReLU(negative_slope=negative_slope)
        elif activation_function_str == "relu":
            activation = torch.nn.ReLU()
        elif activation_function_str == "sigmoid":
            activation = torch.nn.Sigmoid()
        else:
            raise ValueError("Activation function not recognized: {}".format(activation_function_str))
        return activation

    # Convert a string parameter into the generator's final activation function
    @staticmethod
    def get_final_activation_function(final_activation):
        if final_activation == "Sigmoid":
            final_activation = lambda x: torch.sigmoid(x)
        elif final_activation == "Clamping":
            final_activation = lambda x: torch.clamp(x, min=0., max=1.)
        elif final_activation == "Linear":
            final_activation = lambda x: x
        else:
            raise ValueError("final_activation must be Sigmoid, Clamping or Linear")
        return final_activation
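The helpers above are plain lookups, so they can be smoke-tested directly. A minimal sketch (not part of the original notebook):
In [ ]:
# Hypothetical smoke test for the Utils helpers
act = Utils.get_activation_function("Lrelu", negative_slope=0.1)
final = Utils.get_final_activation_function("Sigmoid")
x = torch.randn(4, 3)
print(act(x).shape)         # torch.Size([4, 3])
print(final(x).min() >= 0)  # tensor(True): sigmoid output lies in (0, 1)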
In [3]:
class Conditional_Chi_Generator(nn.Module):
    def __init__(self,
                 noise_size: int = 1,
                 min_d: float = 1.,
                 max_d: float = 1.,
                 min_lda: float = 1.,
                 max_lda: float = 1.,
                 hidden_dim: int = 100,
                 activation_function: str = "Lrelu",
                 final_activation: str = "Sigmoid",
                 label_size: int = 10,
                 negative_slope: float = 0.1,
                 nb_hidden_layers_G: int = 1,
                 eps: float = 0.,
                 **kwargs):
        super(Conditional_Chi_Generator, self).__init__()
        self.min_d = min_d
        self.max_d = max_d
        self.min_lda = min_lda
        self.max_lda = max_lda
        self.input_dim = noise_size
        self.nb_hidden_layers = nb_hidden_layers_G
        self.hidden_dim = hidden_dim
        self.activation = Utils.get_activation_function(activation_function, negative_slope)
        self.input_noise = nn.Linear(self.input_dim, self.hidden_dim)
        self.label_size = label_size
        # The two conditioning labels (d, lambda) are embedded jointly
        self.input_label = nn.Linear(2, self.label_size)
        self.fcs = nn.ModuleList()
        self.eps = eps
        for i in range(self.nb_hidden_layers):
            if i == 0:
                # First hidden layer sees the noise embedding concatenated with the label embedding
                self.fcs.append(nn.Linear(self.hidden_dim + self.label_size, self.hidden_dim))
            else:
                self.fcs.append(nn.Linear(self.hidden_dim, self.hidden_dim))
        self.output = nn.Linear(self.hidden_dim, 1)
        self.final_activation = Utils.get_final_activation_function(final_activation)

    def forward(self, x, d, lda):
        # Normalize d into [0, 1] using min_d and max_d
        if self.max_d == self.min_d:
            d = torch.ones_like(d)  # degenerate range: every d maps to 1
        else:
            d = (d - self.min_d) / (self.max_d - self.min_d)
        # Normalize lambda into [0, 1] using min_lda and max_lda
        if self.max_lda == self.min_lda:
            lda = torch.ones_like(lda)
        else:
            lda = (lda - self.min_lda) / (self.max_lda - self.min_lda)
        return self.feed_forward(x, d, lda)

    def feed_forward(self, x, d, lda):
        x = self.activation(self.input_noise(x))
        params = self.activation(self.input_label(torch.cat((d, lda), 1)))
        x = torch.cat((x, params), 1)
        for i in range(self.nb_hidden_layers):
            x = self.activation(self.fcs[i](x))
        x = self.final_activation(self.output(x))
        return x + self.eps
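As a quick sanity check of the architecture (a sketch with assumed hyperparameters, not part of the original notebook), the generator maps a batch of noise vectors plus one (d, lambda) pair per row to one draw per row:
In [ ]:
# Hypothetical shape check: 8 samples, noise_size = 5
G = Conditional_Chi_Generator(noise_size=5, min_d=0.3, max_d=10., min_lda=0., max_lda=100.)
z = torch.rand(8, 5)            # uniform noise, one row per sample
d = torch.full((8, 1), 2.0)     # degrees of freedom, one per sample
lda = torch.full((8, 1), 10.0)  # noncentrality, one per sample
print(G(z, d, lda).shape)       # torch.Size([8, 1]); values in (0, 1) for the Sigmoid head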
In [9]:
params = {}
# tested_d: d values on which tests (Kolmogorov-Smirnov statistics) are performed
params["tested_d"] = [0.5, 1, 2, 10, 20]
params["tested_lda"] = [0, 10, 50, 100]
# displayed_d (not set here) would list the d values whose histograms are displayed;
# it has to be a subset of tested_d
params["min_d"] = 0.3
params["max_d"] = 10.
params["min_lda"] = 0.
params["max_lda"] = 100.
# alpha: ncx2.ppf(alpha, max_d, max_lda) is the network's maximum return value.
# Since network outputs are bounded in [0, 1], the output 1 must correspond to an
# actual draw, taken to be this alpha-quantile (see the sketch after this cell).
params["alpha"] = 0.999
# Network parameters:
# Number of uniform draws fed to G
params["noise_size"] = 5
# Embedding size for the (d, lambda) label inputs (leaving it at 10 is satisfactory)
params["label_size"] = 50
# "Lrelu", "relu" or "sigmoid" (default: Lrelu with negative slope 0.1)
params["activation_function"] = "Lrelu"
params["negative_slope"] = 0.1
# Final activation for the generator: Sigmoid, Linear or Clamping
params["final_activation"] = "Sigmoid"
# Hidden layer size
params["hidden_dim"] = 100
# Number of hidden layers for both D and G after the input layer (default 1):
# the total number of hidden layers is nb_hidden_layers + 1
params["nb_hidden_layers_G"] = 2
params["nb_hidden_layers_D"] = 2
# load_dir_model is the path of the checkpoint to load
params["load_dir_model"] = "checkpoints"
Load the best model (PATH should look like "models_poisson/poisson_mu-0.1-20_finalactivation-linear/checkpoints"). Chi2_generator needs to be built with the same hyperparameters as the saved model.
In [10]:
Chi2_checkpoint = torch.load(params["load_dir_model"], map_location='cpu')
print(Chi2_checkpoint['epoch'])
Chi2_generator = Conditional_Chi_Generator(**params).to('cpu')
Chi2_generator.load_state_dict(Chi2_checkpoint['generator_state_dict'])
989
Out[10]:
<All keys matched successfully>
Call the best model.
max_d, max_lda, min_d, min_lda and alpha depend on the training run (use the same values). (M, m) depends on max_d, max_lda, min_d, min_lda and alpha (see training).
In [11]:
def CHI_generator(d, lda, size, max_d, max_lda, min_d, min_lda, alpha, **kwargs):
    # Bounds that the generator's [0, 1] output is rescaled to
    M = ncx2.ppf(alpha, max_d, max_lda)
    m = ncx2.ppf(1 - alpha, min_d, min_lda)
    # Outside the training range, sample at the boundary and moment-match afterwards
    if d > max_d or lda > max_lda:
        _d = max_d
        _lda = max_lda
    else:
        _d = d
        _lda = lda
    partial_uniforms = torch.rand(size, kwargs.get("noise_size", 5))
    d_input = torch.ones(size, 1) * _d
    lda_input = torch.ones(size, 1) * _lda
    simulated_chi_draws = Chi2_generator(partial_uniforms, d_input, lda_input).detach().numpy() * (M - m) + m
    if d > max_d or lda > max_lda:
        # Rescale the boundary draws to the target mean d + lda and variance 2(d + 2*lda)
        mean_1 = _d + _lda
        var_1 = 2 * (_d + 2 * _lda)
        mean_2 = d + lda
        var_2 = 2 * (d + 2 * lda)
        simulated_chi_draws = (simulated_chi_draws - mean_1) / np.sqrt(var_1) * np.sqrt(var_2) + mean_2
    return simulated_chi_draws.reshape(1, -1)[0]
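The out-of-range branch is a moment-matching extrapolation: if X has mean mean_1 and variance var_1, then Y = (X - mean_1) / sqrt(var_1) * sqrt(var_2) + mean_2 has mean mean_2 and variance var_2, here the exact moments d + lda and 2(d + 2*lda) of the target law. A minimal numerical check (a sketch, not part of the original notebook):
In [ ]:
# Hypothetical check of the moment-matching transform for d = 20 > max_d = 10
x = ncx2.rvs(10., 100., size=100000)               # boundary law (max_d, max_lda)
mean_1, var_1 = 10. + 100., 2 * (10. + 2 * 100.)   # boundary moments
mean_2, var_2 = 20. + 100., 2 * (20. + 2 * 100.)   # target moments
y = (x - mean_1) / np.sqrt(var_1) * np.sqrt(var_2) + mean_2
print(y.mean(), y.var())  # approximately 120 and 440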
Plot histograms on the tested_d × tested_lda grid.
In [12]:
size = 10000
for d in params['tested_d']:
    for lda in params['tested_lda']:
        true_Chi = ncx2.rvs(d, lda, size=size)
        GAN_Chi = CHI_generator(d, lda, size, **params)
        print(GAN_Chi)
        plt.hist(GAN_Chi, bins=200, alpha=0.75, color='red', density=True, label="Simulated distribution")
        plt.hist(true_Chi, bins=100, alpha=0.75, density=True, color='blue', label="True distribution")
        plt.title("d = {}, lambda = {}".format(d, lda))
        plt.legend()
        plt.show()
[Output: for each (d, lambda) pair the cell prints the array of simulated draws and displays the overlaid histograms of the generated (red) and true (blue) noncentral chi-squared samples; the plot images are not reproduced here.]
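The tested_d × tested_lda grid was introduced above for Kolmogorov-Smirnov tests; that check is not shown in the cells here, but a minimal sketch would be:
In [ ]:
# Hypothetical KS check of the generated samples against the true ncx2 law
from scipy.stats import kstest
for d in params['tested_d']:
    for lda in params['tested_lda']:
        GAN_Chi = CHI_generator(d, lda, size, **params)
        stat, pvalue = kstest(GAN_Chi, lambda x: ncx2.cdf(x, d, lda))
        print("d = {}, lambda = {}: KS statistic = {:.4f}".format(d, lda, stat))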