Skip to content
Snippets Groups Projects
Commit f8926229 authored by Olivier Grisel's avatar Olivier Grisel
Browse files

FIX Python 3 support for datasets.species_distributions

parent bd0b93e9
No related branches found
No related tags found
No related merge requests found
......@@ -64,20 +64,17 @@ except ImportError:
print(__doc__)
def create_species_bunch(species_name,
train, test,
coverages, xgrid, ygrid):
"""
create a bunch with information about a particular organism
def create_species_bunch(species_name, train, test, coverages, xgrid, ygrid):
"""Create a bunch with information about a particular organism
This will use the test/train record arrays to extract the
data specific to the given species name.
"""
bunch = Bunch(name=' '.join(species_name.split("_")[:2]))
species_name = species_name.encode('ascii')
points = dict(test=test, train=train)
for label, pts in points.iteritems():
for label, pts in points.items():
# choose points associated with the desired species
pts = pts[pts['species'] == species_name]
bunch['pts_%s' % label] = pts
......
......@@ -59,8 +59,8 @@ species_names = ['Bradypus Variegatus', 'Microryzomys Minutus']
Xtrain = np.vstack([data['train']['dd lat'],
data['train']['dd long']]).T
ytrain = np.array([d.startswith('micro') for d in data['train']['species']],
dtype='int')
ytrain = np.array([d.decode('ascii').startswith('micro')
for d in data['train']['species']], dtype='int')
Xtrain *= np.pi / 180. # Convert lat/long to radians
# Set up the data grid for the contour plot
......
......@@ -43,9 +43,11 @@ from os.path import exists
# Resolve ``urlopen`` in a version-agnostic way: the ``urllib2`` module
# only exists on Python 2, so a failed import means we run on Python 3.
# ``PY2`` is kept as a module-level flag for helpers that must branch on
# str-vs-bytes semantics (e.g. when feeding numpy record arrays).
try:
    # Python 2
    from urllib2 import urlopen
    PY2 = True
except ImportError:
    # Python 3
    from urllib.request import urlopen
    PY2 = False

import numpy as np
......@@ -60,24 +62,19 @@ COVERAGES_URL = join(DIRECTORY_URL, "coverages.zip")
DATA_ARCHIVE_NAME = "species_coverage.pkz"
def _load_coverage(F, header_length=6,
dtype=np.int16):
"""
load a coverage file.
def _load_coverage(F, header_length=6, dtype=np.int16):
"""Load a coverage file from an open file object.
This will return a numpy array of the given dtype
"""
try:
header = [F.readline() for i in range(header_length)]
except:
F = open(F)
header = [F.readline() for i in range(header_length)]
make_tuple = lambda t: (t.split()[0], float(t.split()[1]))
header = dict([make_tuple(line) for line in header])
M = np.loadtxt(F, dtype=dtype)
nodata = header['NODATA_value']
nodata = header[b'NODATA_value']
if nodata != -9999:
print(nodata)
M[nodata] = -9999
return M
......@@ -87,24 +84,21 @@ def _load_csv(F):
Parameters
----------
F : string or file object
file object or name of file
F : file object
CSV file open in byte mode.
Returns
-------
rec : np.ndarray
record array representing the data
"""
try:
names = F.readline().strip().split(',')
except:
F = open(F)
if PY2:
# Numpy recarray wants Python 2 str but not unicode
names = F.readline().strip().split(',')
rec = np.loadtxt(F, skiprows=1, delimiter=',',
dtype='a22,f4,f4')
else:
# Numpy recarray wants Python 3 str but not bytes...
names = F.readline().decode('ascii').strip().split(',')
rec = np.loadtxt(F, skiprows=0, delimiter=',', dtype='a22,f4,f4')
rec.dtype.names = names
return rec
......@@ -243,8 +237,7 @@ def fetch_species_distributions(data_home=None,
fhandle = BytesIO(X[f])
print(' - converting', f)
coverages.append(_load_coverage(fhandle))
coverages = np.asarray(coverages,
dtype=dtype)
coverages = np.asarray(coverages, dtype=dtype)
bunch = Bunch(coverages=coverages,
test=test,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment