From 076938f18a43f521f24303dd1ce81a17c737e417 Mon Sep 17 00:00:00 2001
From: Fabian Pedregosa <fabian.pedregosa@inria.fr>
Date: Tue, 5 Jan 2010 13:38:54 +0000
Subject: [PATCH] Flatten data in Iris dataset

From: cdavid <cdavid@cb17146a-f446-4be1-a4f7-bd7c5bb65646>

git-svn-id: https://scikit-learn.svn.sourceforge.net/svnroot/scikit-learn/trunk@13 22fbfee3-77ab-4535-9bad-27d1bd3bc7d8
---
 scikits/learn/datasets/iris/data.py        | 48 ++++++++++-----
 scikits/learn/datasets/iris/iris.py        | 71 ++++++++++++++++++++--
 scikits/learn/datasets/iris/src/convert.py | 34 ++++++-----
 3 files changed, 118 insertions(+), 35 deletions(-)

diff --git a/scikits/learn/datasets/iris/data.py b/scikits/learn/datasets/iris/data.py
index 4e0b0fbbd9..ce470280ed 100644
--- a/scikits/learn/datasets/iris/data.py
+++ b/scikits/learn/datasets/iris/data.py
@@ -1,6 +1,6 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
-# Last Change: Sun Jul 01 08:00 PM 2007 J
+# Last Change: Tue Jul 17 04:00 PM 2007 J
 
 # The code and descriptive text is copyrighted and offered under the terms of
 # the BSD License from the authors; see below. However, the actual dataset may
@@ -98,25 +98,45 @@ def load():
     """load the iris data and returns them.
     
     :returns:
-        data: recordarray
-            a record array of the data.
+        d : dict
+            contains the following values:
+            - 'data' : a record array with the actual data
+            - 'label' : label[i] = label index of data[i]
+            - 'class' : class[i] is the string corresponding to label index i.
+
+    Example
+    -------
+    
+    Let's say you are interested in the samples 10, 25, and 50, and want to
+    know their class name.
+
+    >>>> d = load()
+    >>>> ind = [10, 25, 50]
+    >>>> lind = d['label'][ind] # returns the label index of each sample
+    >>>> d['class'][lind] # returns the class name of each sample
+
     """
     import numpy
-    from iris import SL, SW, PL, PW, CLI
+    from iris import SL, SW, PL, PW, LABELS, LI2LN
     PW = numpy.array(PW).astype(numpy.float)
     PL = numpy.array(PL).astype(numpy.float)
     SW = numpy.array(SW).astype(numpy.float)
     SL = numpy.array(SL).astype(numpy.float)
     data    = {}
-    for i in CLI.items():
-        name = i[0][5:]
-        data[name] = numpy.empty(len(i[1]), [('petal width', numpy.int),\
-                        ('petal length', numpy.int),
-                        ('sepal width', numpy.int),
-                        ('sepal length', numpy.int)])
-        data[name]['petal width'] = numpy.round(PW[i[1]] * 10)
-        data[name]['petal length'] = numpy.round(PL[i[1]] * 10)
-        data[name]['sepal width'] = numpy.round(SW[i[1]] * 10)
-        data[name]['sepal length'] = numpy.round(SL[i[1]] * 10)
+    data['data'] = numpy.empty(len(PW), 
+                               [('petal width', numpy.int),
+                                ('petal length', numpy.int),
+                                ('sepal width', numpy.int),
+                                ('sepal length', numpy.int)])
+
+    data['data']['petal width'] = numpy.round(PW * 10)
+    data['data']['petal length'] = numpy.round(PL * 10)
+    data['data']['sepal width'] = numpy.round(SW * 10)
+    data['data']['sepal length'] = numpy.round(SL * 10)
+    data['label'] = numpy.array(LABELS).astype(numpy.int)
+    data['class'] = numpy.empty(len(LI2LN), 
+                                'S%d' % numpy.max([len(i) for i in LI2LN.values()]))
+    for i,c in LI2LN.items():
+        data['class'][i] = c
     
     return data
diff --git a/scikits/learn/datasets/iris/iris.py b/scikits/learn/datasets/iris/iris.py
index 7348f910bd..94da7409f7 100644
--- a/scikits/learn/datasets/iris/iris.py
+++ b/scikits/learn/datasets/iris/iris.py
@@ -1,12 +1,71 @@
-# Autogenerated by convert.py at Sun, 01 Jul 2007 10:36:56 +0000
+# Autogenerated by convert.py at Tue, 17 Jul 2007 06:47:17 +0000
 
-SL = ['5.1', '4.9', '4.7', '4.6', '5.0', '5.4', '4.6', '5.0', '4.4', '4.9', '5.4', '4.8', '4.8', '4.3', '5.8', '5.7', '5.4', '5.1', '5.7', '5.1', '5.4', '5.1', '4.6', '5.1', '4.8', '5.0', '5.0', '5.2', '5.2', '4.7', '4.8', '5.4', '5.2', '5.5', '4.9', '5.0', '5.5', '4.9', '4.4', '5.1', '5.0', '4.5', '4.4', '5.0', '5.1', '4.8', '5.1', '4.6', '5.3', '5.0', '7.0', '6.4', '6.9', '5.5', '6.5', '5.7', '6.3', '4.9', '6.6', '5.2', '5.0', '5.9', '6.0', '6.1', '5.6', '6.7', '5.6', '5.8', '6.2', '5.6', '5.9', '6.1', '6.3', '6.1', '6.4', '6.6', '6.8', '6.7', '6.0', '5.7', '5.5', '5.5', '5.8', '6.0', '5.4', '6.0', '6.7', '6.3', '5.6', '5.5', '5.5', '6.1', '5.8', '5.0', '5.6', '5.7', '5.7', '6.2', '5.1', '5.7', '6.3', '5.8', '7.1', '6.3', '6.5', '7.6', '4.9', '7.3', '6.7', '7.2', '6.5', '6.4', '6.8', '5.7', '5.8', '6.4', '6.5', '7.7', '7.7', '6.0', '6.9', '5.6', '7.7', '6.3', '6.7', '7.2', '6.2', '6.1', '6.4', '7.2', '7.4', '7.9', '6.4', '6.3', '6.1', '7.7', '6.3', '6.4', '6.0', '6.9', '6.7', '6.9', '5.8', '6.8', '6.7', '6.7', '6.3', '6.5', '6.2', '5.9']
+SL = ['5.1', '4.9', '4.7', '4.6', '5.0', '5.4', '4.6', '5.0', '4.4', '4.9',
+'5.4', '4.8', '4.8', '4.3', '5.8', '5.7', '5.4', '5.1', '5.7', '5.1', '5.4',
+'5.1', '4.6', '5.1', '4.8', '5.0', '5.0', '5.2', '5.2', '4.7', '4.8', '5.4',
+'5.2', '5.5', '4.9', '5.0', '5.5', '4.9', '4.4', '5.1', '5.0', '4.5', '4.4',
+'5.0', '5.1', '4.8', '5.1', '4.6', '5.3', '5.0', '7.0', '6.4', '6.9', '5.5',
+'6.5', '5.7', '6.3', '4.9', '6.6', '5.2', '5.0', '5.9', '6.0', '6.1', '5.6',
+'6.7', '5.6', '5.8', '6.2', '5.6', '5.9', '6.1', '6.3', '6.1', '6.4', '6.6',
+'6.8', '6.7', '6.0', '5.7', '5.5', '5.5', '5.8', '6.0', '5.4', '6.0', '6.7',
+'6.3', '5.6', '5.5', '5.5', '6.1', '5.8', '5.0', '5.6', '5.7', '5.7', '6.2',
+'5.1', '5.7', '6.3', '5.8', '7.1', '6.3', '6.5', '7.6', '4.9', '7.3', '6.7',
+'7.2', '6.5', '6.4', '6.8', '5.7', '5.8', '6.4', '6.5', '7.7', '7.7', '6.0',
+'6.9', '5.6', '7.7', '6.3', '6.7', '7.2', '6.2', '6.1', '6.4', '7.2', '7.4',
+'7.9', '6.4', '6.3', '6.1', '7.7', '6.3', '6.4', '6.0', '6.9', '6.7', '6.9',
+'5.8', '6.8', '6.7', '6.7', '6.3', '6.5', '6.2', '5.9']
 
-SW = ['3.5', '3.0', '3.2', '3.1', '3.6', '3.9', '3.4', '3.4', '2.9', '3.1', '3.7', '3.4', '3.0', '3.0', '4.0', '4.4', '3.9', '3.5', '3.8', '3.8', '3.4', '3.7', '3.6', '3.3', '3.4', '3.0', '3.4', '3.5', '3.4', '3.2', '3.1', '3.4', '4.1', '4.2', '3.1', '3.2', '3.5', '3.1', '3.0', '3.4', '3.5', '2.3', '3.2', '3.5', '3.8', '3.0', '3.8', '3.2', '3.7', '3.3', '3.2', '3.2', '3.1', '2.3', '2.8', '2.8', '3.3', '2.4', '2.9', '2.7', '2.0', '3.0', '2.2', '2.9', '2.9', '3.1', '3.0', '2.7', '2.2', '2.5', '3.2', '2.8', '2.5', '2.8', '2.9', '3.0', '2.8', '3.0', '2.9', '2.6', '2.4', '2.4', '2.7', '2.7', '3.0', '3.4', '3.1', '2.3', '3.0', '2.5', '2.6', '3.0', '2.6', '2.3', '2.7', '3.0', '2.9', '2.9', '2.5', '2.8', '3.3', '2.7', '3.0', '2.9', '3.0', '3.0', '2.5', '2.9', '2.5', '3.6', '3.2', '2.7', '3.0', '2.5', '2.8', '3.2', '3.0', '3.8', '2.6', '2.2', '3.2', '2.8', '2.8', '2.7', '3.3', '3.2', '2.8', '3.0', '2.8', '3.0', '2.8', '3.8', '2.8', '2.8', '2.6', '3.0', '3.4', '3.1', '3.0', '3.1', '3.1', '3.1', '2.7', '3.2', '3.3', '3.0', '2.5', '3.0', '3.4', '3.0']
+SW = ['3.5', '3.0', '3.2', '3.1', '3.6', '3.9', '3.4', '3.4', '2.9', '3.1',
+'3.7', '3.4', '3.0', '3.0', '4.0', '4.4', '3.9', '3.5', '3.8', '3.8', '3.4',
+'3.7', '3.6', '3.3', '3.4', '3.0', '3.4', '3.5', '3.4', '3.2', '3.1', '3.4',
+'4.1', '4.2', '3.1', '3.2', '3.5', '3.1', '3.0', '3.4', '3.5', '2.3', '3.2',
+'3.5', '3.8', '3.0', '3.8', '3.2', '3.7', '3.3', '3.2', '3.2', '3.1', '2.3',
+'2.8', '2.8', '3.3', '2.4', '2.9', '2.7', '2.0', '3.0', '2.2', '2.9', '2.9',
+'3.1', '3.0', '2.7', '2.2', '2.5', '3.2', '2.8', '2.5', '2.8', '2.9', '3.0',
+'2.8', '3.0', '2.9', '2.6', '2.4', '2.4', '2.7', '2.7', '3.0', '3.4', '3.1',
+'2.3', '3.0', '2.5', '2.6', '3.0', '2.6', '2.3', '2.7', '3.0', '2.9', '2.9',
+'2.5', '2.8', '3.3', '2.7', '3.0', '2.9', '3.0', '3.0', '2.5', '2.9', '2.5',
+'3.6', '3.2', '2.7', '3.0', '2.5', '2.8', '3.2', '3.0', '3.8', '2.6', '2.2',
+'3.2', '2.8', '2.8', '2.7', '3.3', '3.2', '2.8', '3.0', '2.8', '3.0', '2.8',
+'3.8', '2.8', '2.8', '2.6', '3.0', '3.4', '3.1', '3.0', '3.1', '3.1', '3.1',
+'2.7', '3.2', '3.3', '3.0', '2.5', '3.0', '3.4', '3.0']
 
-PL = ['1.4', '1.4', '1.3', '1.5', '1.4', '1.7', '1.4', '1.5', '1.4', '1.5', '1.5', '1.6', '1.4', '1.1', '1.2', '1.5', '1.3', '1.4', '1.7', '1.5', '1.7', '1.5', '1.0', '1.7', '1.9', '1.6', '1.6', '1.5', '1.4', '1.6', '1.6', '1.5', '1.5', '1.4', '1.5', '1.2', '1.3', '1.5', '1.3', '1.5', '1.3', '1.3', '1.3', '1.6', '1.9', '1.4', '1.6', '1.4', '1.5', '1.4', '4.7', '4.5', '4.9', '4.0', '4.6', '4.5', '4.7', '3.3', '4.6', '3.9', '3.5', '4.2', '4.0', '4.7', '3.6', '4.4', '4.5', '4.1', '4.5', '3.9', '4.8', '4.0', '4.9', '4.7', '4.3', '4.4', '4.8', '5.0', '4.5', '3.5', '3.8', '3.7', '3.9', '5.1', '4.5', '4.5', '4.7', '4.4', '4.1', '4.0', '4.4', '4.6', '4.0', '3.3', '4.2', '4.2', '4.2', '4.3', '3.0', '4.1', '6.0', '5.1', '5.9', '5.6', '5.8', '6.6', '4.5', '6.3', '5.8', '6.1', '5.1', '5.3', '5.5', '5.0', '5.1', '5.3', '5.5', '6.7', '6.9', '5.0', '5.7', '4.9', '6.7', '4.9', '5.7', '6.0', '4.8', '4.9', '5.6', '5.8', '6.1', '6.4', '5.6', '5.1', '5.6', '6.1', '5.6', '5.5', '4.8', '5.4', '5.6', '5.1', '5.1', '5.9', '5.7', '5.2', '5.0', '5.2', '5.4', '5.1']
+PL = ['1.4', '1.4', '1.3', '1.5', '1.4', '1.7', '1.4', '1.5', '1.4', '1.5',
+'1.5', '1.6', '1.4', '1.1', '1.2', '1.5', '1.3', '1.4', '1.7', '1.5', '1.7',
+'1.5', '1.0', '1.7', '1.9', '1.6', '1.6', '1.5', '1.4', '1.6', '1.6', '1.5',
+'1.5', '1.4', '1.5', '1.2', '1.3', '1.5', '1.3', '1.5', '1.3', '1.3', '1.3',
+'1.6', '1.9', '1.4', '1.6', '1.4', '1.5', '1.4', '4.7', '4.5', '4.9', '4.0',
+'4.6', '4.5', '4.7', '3.3', '4.6', '3.9', '3.5', '4.2', '4.0', '4.7', '3.6',
+'4.4', '4.5', '4.1', '4.5', '3.9', '4.8', '4.0', '4.9', '4.7', '4.3', '4.4',
+'4.8', '5.0', '4.5', '3.5', '3.8', '3.7', '3.9', '5.1', '4.5', '4.5', '4.7',
+'4.4', '4.1', '4.0', '4.4', '4.6', '4.0', '3.3', '4.2', '4.2', '4.2', '4.3',
+'3.0', '4.1', '6.0', '5.1', '5.9', '5.6', '5.8', '6.6', '4.5', '6.3', '5.8',
+'6.1', '5.1', '5.3', '5.5', '5.0', '5.1', '5.3', '5.5', '6.7', '6.9', '5.0',
+'5.7', '4.9', '6.7', '4.9', '5.7', '6.0', '4.8', '4.9', '5.6', '5.8', '6.1',
+'6.4', '5.6', '5.1', '5.6', '6.1', '5.6', '5.5', '4.8', '5.4', '5.6', '5.1',
+'5.1', '5.9', '5.7', '5.2', '5.0', '5.2', '5.4', '5.1']
 
-PW = ['0.2', '0.2', '0.2', '0.2', '0.2', '0.4', '0.3', '0.2', '0.2', '0.1', '0.2', '0.2', '0.1', '0.1', '0.2', '0.4', '0.4', '0.3', '0.3', '0.3', '0.2', '0.4', '0.2', '0.5', '0.2', '0.2', '0.4', '0.2', '0.2', '0.2', '0.2', '0.4', '0.1', '0.2', '0.1', '0.2', '0.2', '0.1', '0.2', '0.2', '0.3', '0.3', '0.2', '0.6', '0.4', '0.3', '0.2', '0.2', '0.2', '0.2', '1.4', '1.5', '1.5', '1.3', '1.5', '1.3', '1.6', '1.0', '1.3', '1.4', '1.0', '1.5', '1.0', '1.4', '1.3', '1.4', '1.5', '1.0', '1.5', '1.1', '1.8', '1.3', '1.5', '1.2', '1.3', '1.4', '1.4', '1.7', '1.5', '1.0', '1.1', '1.0', '1.2', '1.6', '1.5', '1.6', '1.5', '1.3', '1.3', '1.3', '1.2', '1.4', '1.2', '1.0', '1.3', '1.2', '1.3', '1.3', '1.1', '1.3', '2.5', '1.9', '2.1', '1.8', '2.2', '2.1', '1.7', '1.8', '1.8', '2.5', '2.0', '1.9', '2.1', '2.0', '2.4', '2.3', '1.8', '2.2', '2.3', '1.5', '2.3', '2.0', '2.0', '1.8', '2.1', '1.8', '1.8', '1.8', '2.1', '1.6', '1.9', '2.0', '2.2', '1.5', '1.4', '2.3', '2.4', '1.8', '1.8', '2.1', '2.4', '2.3', '1.9', '2.3', '2.5', '2.3', '1.9', '2.0', '2.3', '1.8']
+PW = ['0.2', '0.2', '0.2', '0.2', '0.2', '0.4', '0.3', '0.2', '0.2', '0.1',
+'0.2', '0.2', '0.1', '0.1', '0.2', '0.4', '0.4', '0.3', '0.3', '0.3', '0.2',
+'0.4', '0.2', '0.5', '0.2', '0.2', '0.4', '0.2', '0.2', '0.2', '0.2', '0.4',
+'0.1', '0.2', '0.1', '0.2', '0.2', '0.1', '0.2', '0.2', '0.3', '0.3', '0.2',
+'0.6', '0.4', '0.3', '0.2', '0.2', '0.2', '0.2', '1.4', '1.5', '1.5', '1.3',
+'1.5', '1.3', '1.6', '1.0', '1.3', '1.4', '1.0', '1.5', '1.0', '1.4', '1.3',
+'1.4', '1.5', '1.0', '1.5', '1.1', '1.8', '1.3', '1.5', '1.2', '1.3', '1.4',
+'1.4', '1.7', '1.5', '1.0', '1.1', '1.0', '1.2', '1.6', '1.5', '1.6', '1.5',
+'1.3', '1.3', '1.3', '1.2', '1.4', '1.2', '1.0', '1.3', '1.2', '1.3', '1.3',
+'1.1', '1.3', '2.5', '1.9', '2.1', '1.8', '2.2', '2.1', '1.7', '1.8', '1.8',
+'2.5', '2.0', '1.9', '2.1', '2.0', '2.4', '2.3', '1.8', '2.2', '2.3', '1.5',
+'2.3', '2.0', '2.0', '1.8', '2.1', '1.8', '1.8', '1.8', '2.1', '1.6', '1.9',
+'2.0', '2.2', '1.5', '1.4', '2.3', '2.4', '1.8', '1.8', '2.1', '2.4', '2.3',
+'1.9', '2.3', '2.5', '2.3', '1.9', '2.0', '2.3', '1.8']
 
-CLI = {'Iris-virginica': [100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149], 'Iris-setosa': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], 'Iris-versicolor': [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]}
+LABELS = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
+2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
+
+LI2LN = {0: 'Iris-setosa', 1: 'Iris-versicolor', 2: 'Iris-virginica'}
 
diff --git a/scikits/learn/datasets/iris/src/convert.py b/scikits/learn/datasets/iris/src/convert.py
index 0f1fb7b7b4..86111fd77c 100755
--- a/scikits/learn/datasets/iris/src/convert.py
+++ b/scikits/learn/datasets/iris/src/convert.py
@@ -1,10 +1,19 @@
 #! /usr/bin/env python
-# Last Change: Sun Jul 01 07:00 PM 2007 J
+# Last Change: Tue Jul 17 03:00 PM 2007 J
 
 # This script generates a python file from the txt data
 import time
 import csv
 
+from scikits.learn.datasets.misc import dumpvar
+
+# array for equivalence label index <> label name
+ln2li = {'Iris-setosa' : 0, 'Iris-versicolor': 1, 'Iris-virginica' :2}
+li2ln = {}
+for c,i in ln2li.items():
+    li2ln[i] = c
+
+# Load the data
 dataname = 'iris.data'
 f = open(dataname, 'r')
 a = csv.reader(f)
@@ -19,22 +28,17 @@ pl = [i[2] for i in el]
 pw = [i[3] for i in el]
 cl = [i[4] for i in el]
 
-dcl = dict([(i, []) for i in cl])
-for i in range(len(cl)):
-    dcl[cl[i]].append(i)
+# dcl[i] = label index of data[i]
+dcl = [ln2li[i] for i in cl]
 
 # Write the data in oldfaitful.py
-a = open("iris.py", "w")
+a = open("../iris.py", "w")
 a.write('# Autogenerated by convert.py at %s\n\n' % 
         time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime()))
 
-def dump_var(var, varname):
-    a.write(varname + " = ")
-    a.write(str(var))
-    a.write("\n\n")
-
-dump_var(sl, 'SL')
-dump_var(sw, 'SW')
-dump_var(pl, 'PL')
-dump_var(pw, 'PW')
-dump_var(dcl, 'CLI')
+a.writelines(dumpvar(sl, 'SL'))
+a.writelines(dumpvar(sw, 'SW'))
+a.writelines(dumpvar(pl, 'PL'))
+a.writelines(dumpvar(pw, 'PW'))
+a.writelines(dumpvar(dcl, 'LABELS'))
+a.writelines(dumpvar(li2ln, 'LI2LN'))
-- 
GitLab