diff --git a/medical_data_visualizer.py b/medical_data_visualizer.py index ea7bbd1..710e725 100644 --- a/medical_data_visualizer.py +++ b/medical_data_visualizer.py @@ -20,11 +20,16 @@ def draw_cat_plot(): # Group and reformat the data to split it by 'cardio'. Show the counts of each feature. You will have to rename one of the columns for the catplot to work correctly. df_cat = None + # Draw the catplot with 'sns.catplot()' + # Get the figure for the output + fig = None + + # Do not modify the next two lines fig.savefig('catplot.png') return fig diff --git a/test_module.py b/test_module.py index 090af1c..3f75c75 100644 --- a/test_module.py +++ b/test_module.py @@ -38,7 +38,7 @@ class HeatMapTestCase(unittest.TestCase): for label in self.ax.get_xticklabels(): actual.append(label.get_text()) expected = ['id', 'age', 'sex', 'height', 'weight', 'ap_hi', 'ap_lo', 'cholesterol', 'gluc', 'smoke', 'alco', 'active', 'cardio', 'overweight'] - self.assertEqual(actual, expected, "Expected bar plot legend labels to be months of the year.") + self.assertEqual(actual, expected, "Expected heat map labels to be 'id', 'age', 'sex', 'height', 'weight', 'ap_hi', 'ap_lo', 'cholesterol', 'gluc', 'smoke', 'alco', 'active', 'cardio', 'overweight'.") def test_heat_map_values(self): actual = [text.get_text() for text in self.ax.get_default_bbox_extra_artists() if isinstance(text, mpl.text.Text)]