import matplotlib.pyplot as plt
import numpy as np
import hmf as hmfcalc
import yt, run
from subs import com

# Single-panel figure; both axes are logarithmic. The x axis is halo radius and
# the y axis is the cumulative number density of halos above that radius.
fig = plt.figure(figsize=(8, 6))
fig.subplots_adjust(left=0.13, right=0.97, bottom=0.12, top=0.97, hspace=0.06)
ax = fig.add_subplot(1, 1, 1)

# Scale is set before the limits on each axis, mirroring the call order of
# the individual setters.
ax.set(
    xscale="log",
    xlim=(1.0e1, 5.0e2),
    xlabel="$r_h [{\\rm kpc}]$",
    yscale="log",
    ylim=(1.0e-7, 1.0e2),
    ylabel="$n(>r_h) [{\\rm Mpc}^{-3}]$",
)


def Plot(ax,fname,color="k",linewidth=3,linestyle="-",label=None,scale=1):
    """Draw the cumulative halo abundance n(>r_h) from one halo catalogue.

    The catalogue is opened with ``yt.load``. Each halo's
    ``('halos', 'virial_radius')`` is converted to comoving kpc and multiplied
    by ``scale``. The curve is the number of halos with radius at or above
    each value, divided by the cube of the box length.

    Parameters
    ----------
    ax : axes object whose ``plot`` method receives the curve.
    fname : path handed to ``yt.load``.
    color, linewidth, linestyle, label : forwarded unchanged to ``ax.plot``.
    scale : multiplicative factor applied to every radius before plotting.

    The redshift and box length are printed as a diagnostic. Returns None.
    """
    catalog = yt.load(fname)

    # Comoving-kpc radii, rescaled, then sorted ascending in place.
    radii = catalog.all_data()[('halos','virial_radius')].in_units("kpccm")*scale
    radii.sort()

    params = catalog.parameters
    if 'current_redshift' not in params:
        # Catalogue without 'current_redshift': derive z from the scale factor.
        # NOTE(review): assumes 'box_size' is in Mpc/h so dividing by h0 gives
        # comoving Mpc, matching the other branch. Confirm for this format.
        redshift = 1/params['scale'] - 1
        box_length = params['box_size']/params["h0"]
    else:
        redshift = params['current_redshift']
        # Uses the right edge as the box length, i.e. assumes the left edge
        # sits at the origin. TODO confirm.
        box_length = params['domain_right_edge'][0].in_units("Mpccm")
    print(redshift, box_length)

    # After the ascending sort, radii[i] has len(radii) - i halos at or above
    # it, so the counts run N, N-1, ..., 1.
    counts = np.arange(len(radii), 0, -1)
    number_density = counts/box_length**3
    ax.plot(radii, number_density, color=color, linewidth=linewidth,
            linestyle=linestyle, label=label)
    

# Compare the cumulative radius function of the HOP and Rockstar catalogues.
# The last curve is the same Rockstar catalogue with every radius scaled by 1.5.
Plot(ax,"D/A/a=0.1282/hop/hop.0.h5",color="r",label="HOP")
Plot(ax,"D/A/rs/halos_8.0.bin",color="b",label="Rockstar")
# Raw string: in an ordinary literal the "\t" of "\times" is a TAB character,
# which would break the mathtext \times command in the legend entry.
Plot(ax,"D/A/rs/halos_8.0.bin",color="purple",label=r"Rockstar$\times 1.5$",scale=1.5)


# loc=3 is matplotlib's numeric code for 'lower left'.
plt.legend(labelspacing=0.25,loc=3,framealpha=1)
plt.show()
