[Tutor] Question in regards to loops and matplotlib

Zachary Rizer protonlenny at gmail.com
Thu Aug 8 17:58:18 CEST 2013


This looks quite clean! Since posting this question I have cleaned up this
script by using multiple functions, but it still isn't this clean! My data
is a normal CSV table with column headers; it isn't in a dictionary format
like your sample data here, though I'm sure I could switch it to that format.
I'll post a sample of my data (the first 10 rows, with column headers) and
my current rendition of the code. Also, I'm not too familiar with the lambda
functions you used here, and I'm a little confused about how you split the
groups up. Anyway, here is the new version of the code and the sample data:

I attached the ten-row sample rather than pasting it, because each row has
60 columns of info and would look huge inline.

"""
Zach Rizer
UVJ plot of n>2 pop compare qui and sf disk and bulge pop
data = n>2_qui_flag.csv
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patch
from matplotlib.path import Path

# Redshift-bin parameter tuples. Fields, as indexed by the functions below:
#   [0], [1]  lower (inclusive) / upper (exclusive) pz_z bound   -> redshift
#   [2]       intercept c of the cut U-V > 0.88*(V-J) + c        -> quiescence
#   [3]       V-J of the hatched-region corner at U-V = 1.3      -> make_hash_region
#   [4]       U-V of the hatched-region corner at V-J = 1.6      -> make_hash_region
#   [5]       text appended to the plot title                    -> titles_labels
z1, z2, z3, z4 = (
    (0.6, 0.9, 0.59, 0.81, 2, "0.6<z<0.9"),
    (0.9, 1.3, 0.52, 0.89, 1.93, "0.9<z<1.3"),
    (1.3, 1.8, 0.49, 0.92, 1.9, "1.3<z<1.8"),
    (1.8, 2.5, 0.49, 0.92, 1.9, "1.8<z<2.5"),
)

def redshift(table, z_range):
    """
    Keep only the rows whose ``pz_z`` value lies inside a redshift bin.

    Parameters
    ----------
    table : iterable of rows
        Rows supporting ``row['pz_z']`` (e.g. the records of the array
        produced by ``np.genfromtxt(..., names=True)``).
    z_range : sequence
        Redshift-bin tuple; ``z_range[0]`` is the inclusive lower bound and
        ``z_range[1]`` the exclusive upper bound.

    Returns
    -------
    numpy.ndarray
        ``np.array`` of the matching rows, in their original order.
    """
    in_bin = [row for row in table
              if z_range[0] <= row['pz_z'] < z_range[1]]
    return np.array(in_bin)

def quiescence(table, z_range):
    """
    Split rows into quiescent and star-forming populations with a UVJ cut.

    A row is quiescent when all of the following hold:
    ``U-V > 1.3``, ``V-J < 1.6`` and ``U-V > 0.88*(V-J) + z_range[2]``.
    Every other row (including rows whose colours are NaN) is star-forming.

    Parameters
    ----------
    table : iterable of rows
        Rows supporting ``row['rf_UminV']`` and ``row['rf_VminJ']``.
    z_range : sequence
        Redshift-bin tuple; only ``z_range[2]`` (the diagonal intercept) is
        used here.

    Returns
    -------
    tuple of (numpy.ndarray, numpy.ndarray)
        ``(quiescent_rows, star_forming_rows)``, each built with ``np.array``
        from the selected rows in input order.
    """
    intercept = z_range[2]
    qui_data = []
    sf_data = []
    for row in table:
        u_v = row['rf_UminV']
        v_j = row['rf_VminJ']
        # Parenthesised so the multi-line condition is valid Python (the
        # emailed version wrapped the bare `and` chain onto a new line).
        if (u_v > 1.3
                and v_j < 1.6
                and u_v > 0.88 * v_j + intercept):
            qui_data.append(row)
        else:
            sf_data.append(row)
    return np.array(qui_data), np.array(sf_data)

def disk_vs_bulge(table):
    """
    Classify rows as disk-dominated or bulge-dominated from visclass scores.

    Both classes require ``vc_DSw > 0.65``. Disk-dominated rows have
    ``vc_Dv > 0.65``; bulge-dominated rows have ``vc_Dv < 0.35``. Rows that
    meet neither rule are left out of both outputs.

    Returns
    -------
    tuple of (numpy.ndarray, numpy.ndarray)
        ``(disk_rows, bulge_rows)`` built with ``np.array`` in input order.
    """
    disk_rows, bulge_rows = [], []
    for rec in table:
        d_v = rec['vc_Dv']
        if d_v > 0.65 and rec['vc_DSw'] > 0.65:
            disk_rows.append(rec)
        elif d_v < 0.35 and rec['vc_DSw'] > 0.65:
            bulge_rows.append(rec)
    return np.array(disk_rows), np.array(bulge_rows)

def make_mass_cuts(table, mass_edges=(9.7, 10, 10.5, 10.8)):
    """
    Bin rows by stellar mass and collect their UVJ colours per bin.

    Bin ``i`` holds rows with
    ``mass_edges[i] <= row['ms_lmass'] < mass_edges[i + 1]``. The last bin is
    open-ended (``row['ms_lmass'] >= mass_edges[-1]``). Rows below
    ``mass_edges[0]``, or with a NaN mass, fall into no bin and are dropped.

    Parameters
    ----------
    table : iterable of rows
        Rows supporting ``row['ms_lmass']``, ``row['rf_VminJ']`` and
        ``row['rf_UminV']``.
    mass_edges : sequence of float, optional
        Ascending lower edges of the mass bins. The default reproduces the
        four bins that were previously hard-coded.

    Returns
    -------
    list of (list, list)
        One ``(x, y)`` pair per bin, with V-J colours as ``x`` and U-V colours
        as ``y``: ``[(m1_x, m1_y), (m2_x, m2_y), (m3_x, m3_y), (m4_x, m4_y)]``
        for the default edges.
    """
    # Upper edge of each bin; None marks the open-ended top bin.
    uppers = list(mass_edges[1:]) + [None]
    bins = [([], []) for _ in mass_edges]
    for row in table:
        mass = row['ms_lmass']
        for lower, upper, (xs, ys) in zip(mass_edges, uppers, bins):
            if mass >= lower and (upper is None or mass < upper):
                xs.append(row['rf_VminJ'])
                ys.append(row['rf_UminV'])
                break
    return bins

def plots(m_q_disk, m_q_bulge, m_sf_disk, m_sf_bulge):
    """
    Draw every (population, mass bin) combination on the current axes.

    Each argument is a list as returned by ``make_mass_cuts``: four
    ``(V-J values, U-V values)`` pairs, lowest mass bin first. Styles:
    - Disk-dominated points are white circles ('wo').
    - Bulge-dominated points are red ('ro') for the quiescent population and
      blue ('bo') for the star-forming one.

    Marker size grows with the bin index (3, 6, 9, 12). Within each
    population the most massive bin is drawn first, so smaller markers end
    up on top.

    Returns
    -------
    tuple
        ``(qui_bulge, sf_bulge, sf_disk)`` line handles from the most massive
        bin, used as legend entries.
    """
    sizes = (3, 6, 9, 12)
    top = len(sizes) - 1
    biggest = {}
    populations = (
        ('qui', m_q_disk, m_q_bulge, 'ro'),
        ('sf', m_sf_disk, m_sf_bulge, 'bo'),
    )
    for key, disks, bulges, bulge_fmt in populations:
        for idx in range(top, -1, -1):
            ms = sizes[idx]
            disk_lines = plt.plot(disks[idx][0], disks[idx][1], 'wo', ms=ms)
            bulge_lines = plt.plot(bulges[idx][0], bulges[idx][1], bulge_fmt,
                                   ms=ms)
            if idx == top:
                biggest[key] = (disk_lines, bulge_lines)

    qui_bulge_label, = biggest['qui'][1]
    sf_disk_label, = biggest['sf'][0]
    sf_bulge_label, = biggest['sf'][1]
    return qui_bulge_label, sf_bulge_label, sf_disk_label

def make_hash_region(z_range):
    """
    Build the closed outline of the quiescent region, for hatching.

    The polygon traces these segments, then closes:
    - along U-V = 1.3 from V-J = -1 to V-J = z_range[3];
    - diagonally to (V-J, U-V) = (1.6, z_range[4]);
    - up to U-V = 2.5;
    - back across to V-J = -1.

    NOTE(review): for the z1..z4 tuples, z_range[3] and z_range[4] appear to
    be rounded intersections of the diagonal U-V = 0.88*(V-J) + z_range[2]
    with U-V = 1.3 and V-J = 1.6. Confirm this if the bin tuples change.

    Returns
    -------
    matplotlib.path.Path
    """
    left, bottom, right, top = -1., 1.3, 1.6, 2.5
    outline = [
        (left, bottom),
        (z_range[3], bottom),
        (right, z_range[4]),
        (right, top),
        (left, top),
        (left, bottom),
    ]
    instructions = [Path.MOVETO] + [Path.LINETO] * 4 + [Path.CLOSEPOLY]
    return Path(outline, instructions)

def titles_labels(labels, z_range):
    """
    Add the title, axis labels, axis limits and legend to the current axes.

    Parameters
    ----------
    labels : sequence
        At least three legend handles, in the order ``plots`` returns them:
        (quiescent bulge-dominated, star-forming bulge-dominated,
        star-forming disk-dominated).
    z_range : sequence
        Redshift-bin tuple; ``z_range[5]`` is appended to the title.
    """
    plt.title('UVJ Plot of n>2 population at ' + z_range[5])
    plt.xlabel('V-J')
    plt.ylabel('U-V')
    plt.xlim(0.5, 2.0)
    plt.ylim(0.5, 2.5)
    # `loc` must be passed by keyword: recent Matplotlib only accepts
    # (handles, labels) positionally, so a third positional 'best' is rejected.
    plt.legend([labels[0], labels[1], labels[2]],
               ["Qui-Bulge-Dom", "SF-Bulge-Dom", "Disk-Dom"],
               loc='best', numpoints=1, fontsize='small')

def main(csv_path=None, zbin=None):
    """
    Build and show the UVJ plot for one redshift bin.

    Parameters
    ----------
    csv_path : str, optional
        CSV catalogue with a header row, read with ``np.genfromtxt``. It must
        provide the columns the helpers read: pz_z, rf_UminV, rf_VminJ,
        vc_Dv, vc_DSw and ms_lmass. Defaults to the path previously
        hard-coded here.
    zbin : tuple, optional
        Redshift-bin tuple (z1, z2, z3 or z4). Defaults to z4, the value
        previously hard-coded here.
    """
    if csv_path is None:
        csv_path = ('/Users/ProtonLenny/Documents/Research/Catalog_Data/'
                    'Catalog_4/n>2.csv')
    if zbin is None:
        zbin = z4

    fig = plt.figure('UVJ')
    n2_subset = np.genfromtxt(csv_path, dtype=None, names=True, delimiter=",")

    z_data = redshift(n2_subset, zbin)
    # Each splitter returns a two-tuple; unpack it once instead of running
    # the same selection twice and indexing each result.
    qui_data, sf_data = quiescence(z_data, zbin)
    qui_disk, qui_bulge = disk_vs_bulge(qui_data)
    sf_disk, sf_bulge = disk_vs_bulge(sf_data)

    m_q_disk = make_mass_cuts(qui_disk)
    m_q_bulge = make_mass_cuts(qui_bulge)
    m_sf_disk = make_mass_cuts(sf_disk)
    m_sf_bulge = make_mass_cuts(sf_bulge)

    legend_labels = plots(m_q_disk, m_q_bulge, m_sf_disk, m_sf_bulge)
    titles_labels(legend_labels, zbin)

    # Add the hatched region to the SAME axes the points were drawn on.
    # fig.add_subplot(111) creates a second axes on recent Matplotlib, whose
    # opaque background would hide the scatter points underneath it.
    hash_region = patch.PathPatch(make_hash_region(zbin), facecolor='none',
                                  lw=1.5, alpha=0.5, hatch='//')
    ax = fig.gca()
    ax.add_patch(hash_region)
    plt.show()

if __name__ == '__main__':
    # Build and show the plot only when run as a script, not when imported.
    main()


On Thu, Aug 8, 2013 at 8:38 AM, Oscar Benjamin
<oscar.j.benjamin at gmail.com> wrote:

> On 1 August 2013 06:21, Zachary Rizer <protonlenny at gmail.com> wrote:
> > So I just started coding, and I'm so glad I chose Python to start me off!
> > I really enjoy the clean layout, easy syntax, and power of the language.
> > I'm doing astronomy research and I just started teaching myself
> > matplotlib along with my general python work. I feel like I'm catching on
> > quick, but I also feel that this particular plot script is a little rough.
> > The plot looks correct, but the code seems really long for what I'm trying
> > to do. So any tricks to make this more efficient would be greatly
> > appreciated!
>
> Hi Zachary,
>
> You can definitely make this shorter. I can't run your code since you
> didn't provide any data. It's good to use dummy data in examples
> posted to mailing lists so that others can run the code. Instead I'll
> show an example with my own dummy data:
>
>
> #!/usr/bin/env python
> #
> # Plot heights and weights of subjects from different groups.
>
> from matplotlib import pyplot as plt
>
> # Dummy data
> subjects = [
>     # Group, height (m), weight (kg), diagnosis
>     {'type': 'patient', 'height': 1.8, 'weight': 90 , 'diag': 'acute'  },
>     {'type': 'patient', 'height': 1.6, 'weight': 85 , 'diag': 'acute'  },
>     {'type': 'patient', 'height': 1.9, 'weight': 120, 'diag': 'chronic'},
>     {'type': 'patient', 'height': 2.0, 'weight': 110, 'diag': 'chronic'},
>     {'type': 'control', 'height': 1.7, 'weight': 60 , 'diag': 'N/A'    },
>     {'type': 'control', 'height': 2.1, 'weight': 100, 'diag': 'N/A'    },
> ]
>
> # Have just one place where we define the properties of each group
> groups = {
>     'control': {
>         'indicator': lambda s: s['type'] == 'control',
>         'color': 'blue',
>         'marker': '*',
>         'label':'Controls'
>      },
>     'chronic': {
>         'indicator': lambda s: s['diag'] == 'chronic',
>         'color': 'magenta',
>         'marker': 's',
>         'label':'Chronic patients'
>     },
>     'acute'  : {
>         'indicator': lambda s: s['diag'] == 'acute',
>         'color': 'red',
>         'marker': 'o',
>         'label':'Acute patients'
>     },
> }
>
> # Now we can reuse the same code for every subject
> for groupdata in groups.values():
>     groupdata['subjects'] = []
>
> # Distribute the subjects to each group
> for subject in subjects:
>     for groupdata in groups.values():
>         isingroup = groupdata['indicator']
>         if isingroup(subject):
>             groupdata['subjects'].append(subject)
>             break
>
> # Now create/format the figure
> fig = plt.figure(figsize=(6, 5))
> ax = fig.add_axes([0.15, 0.15, 0.70, 0.70])
> ax.set_xlabel(r'Weight ($kg$)')
> ax.set_ylabel(r'Height ($m$)')
> ax.set_title('Height vs. weight by subject group')
> ax.set_xlim([50, 150])
> ax.set_ylim([1.5, 2.2])
>
> # Plot each group separately with its own format settings
> for group, groupdata in groups.items():
>     xdata = [s['weight'] for s in groupdata['subjects']]
>     ydata = [s['height'] for s in groupdata['subjects']]
>     ax.plot(xdata, ydata, linestyle='none', color=groupdata['color'],
>             marker=groupdata['marker'], label=groupdata['label'])
>
> ax.legend(loc='upper left', numpoints=1)
>
> # Show the figure in GUI by default.
> # Save to e.g. 'plot.pdf' if filename is given.
> import sys
> if len(sys.argv) > 1:
>     imgname = sys.argv[1]
>     fig.savefig(imgname)
> else:
>     plt.show()
>
>
> Oscar
>
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://mail.python.org/pipermail/tutor/attachments/20130808/f9ab7b13/attachment-0001.html>
-------------- next part --------------
A non-text attachment was scrubbed...
Name: sample.csv
Type: text/csv
Size: 5099 bytes
Desc: not available
URL: <http://mail.python.org/pipermail/tutor/attachments/20130808/f9ab7b13/attachment-0001.csv>


More information about the Tutor mailing list