66from matplotlib .pyplot import figure as _figure
77
88
9- def spaceplots (inputs , outputs , input_names = None , output_names = None , ** kwargs ):
9+ def spaceplots (
10+ inputs , outputs , input_names = None , output_names = None , limits = None , ** kwargs
11+ ):
1012 num_samples , num_inputs = inputs .shape
1113 if input_names is not None :
1214 if len (input_names ) != num_inputs :
@@ -23,10 +25,21 @@ def spaceplots(inputs, outputs, input_names=None, output_names=None, **kwargs):
2325 else :
2426 output_names = [None ] * num_outputs
2527
28+ if limits is not None :
29+ if limits .shape [1 ] != 2 :
30+ raise RuntimeError (
31+ "There must be a upper and lower limit for each output"
32+ )
33+ elif limits .shape [0 ] != num_outputs :
34+ raise RuntimeError ("Output data and limits don't match" )
35+ else :
36+ limits = [[None , None ]] * num_outputs
37+
2638 for out_index in range (num_outputs ):
2739 yield _subspace_plot (
2840 inputs , outputs [:, out_index ], input_names = input_names ,
29- output_name = output_names [out_index ], ** kwargs
41+ output_name = output_names [out_index ], min_output = limits [out_index ][0 ],
42+ max_output = limits [out_index ][1 ], ** kwargs
3043 )
3144
3245
@@ -79,7 +92,18 @@ def _setup_axes(
7992 return fig , axes , grid
8093
8194
82- def _subspace_plot (inputs , output , * , input_names , output_name , ** kwargs ):
95+ def _subspace_plot (
96+ inputs , output , * , input_names , output_name , scatter_args = None ,
97+ histogram_args = None , min_output = None , max_output = None
98+ ):
99+ if scatter_args is None :
100+ scatter_args = {}
101+ if histogram_args is None :
102+ histogram_args = {}
103+ if min_output is None :
104+ min_output = min (output )
105+ if max_output is None :
106+ max_output = max (output )
83107
84108 # see https://matplotlib.org/examples/pylab_examples/multi_image.html
85109 _ , num_inputs = inputs .shape
@@ -89,11 +113,13 @@ def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs):
89113 if output_name is not None :
90114 fig .suptitle (output_name )
91115
92- norm = _Normalize (min ( output ), max ( output )) # TODO: get from user if needed
116+ norm = _Normalize (min_output , max_output )
93117
94118 hist_plots = []
95119 for i in range (num_inputs ):
96- hist_plots .append (_plot_hist (inputs [:, i ], axis = axes [i ][i ]))
120+ hist_plots .append (_plot_hist (
121+ inputs [:, i ], axis = axes [i ][i ], ** histogram_args
122+ ))
97123
98124 scatter_plots = []
99125 scatter_plots_grid = []
@@ -103,7 +129,7 @@ def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs):
103129 sc_plot = _plot_scatter (
104130 x = inputs [:, x_index ], y = inputs [:, y_index ], z = output ,
105131 axis = axes [y_index ][x_index ], # check order
106- norm = norm
132+ norm = norm , ** scatter_args
107133 )
108134 scatter_plots .append (sc_plot )
109135 scatter_plots_grid [y_index ].append (sc_plot )
0 commit comments