# Axes-level keyword arguments, forwarded through `subplot_kw` to every
# subplot created below, so each axes starts with an azure background.
sub_kw = {'facecolor': 'azure'}
# Figure-level keyword arguments, unpacked into plt.subplots(), which passes
# them on to the figure it creates, giving the figure a silver background.
fig_kw = {'facecolor': 'silver'}

# Build a 2x2 grid of axes; the 2-D array of axes is unpacked row by row.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
    2, 2,
    subplot_kw=sub_kw,
    **fig_kw,
)

# NOTE(review): `plt`, `x` and `y` are assumed to be defined earlier in the
# file (not visible in this chunk).
ax1.plot(x, y)      # top-left: line plot
ax2.scatter(x, y)   # top-right: scatter plot
# ax3 (bottom-left) receives no data here and keeps the azure background.
# Bottom-right: override the azure background from subplot_kw with pink.
ax4.set_facecolor('pink')