#!/usr/bin/env python

# Ensure `requests` is installed before it is imported below. pip has no
# supported in-process API (pip.main() was removed in pip 10), so run it as a
# subprocess of the current interpreter, as pip's documentation recommends.
# The exit status is deliberately not checked (matching the previous
# best-effort behavior): `import requests` fails loudly if the install failed.
import subprocess
import sys
subprocess.call([sys.executable, '-m', 'pip', 'install', 'requests'])

import os
import sys
import requests
from glob import glob
from datetime import datetime
import base64
import numpy as np

import ccd

# Make the PySpark sources and the bundled py4j zip importable before the
# `from pyspark import ...` below. Both entries are prepended so they win over
# any other installation; resulting order is [py4j zip, pyspark sources, ...].
spark_home = '/opt/spark/spark-2.0.2-bin-hadoop2.7'
sys.path[0:0] = [
    os.path.join(spark_home, 'python/lib/py4j-0.10.3-src.zip'),
    os.path.join(spark_home, 'python'),
]

from pyspark import SparkContext, SparkConf

# Sensor group -> {spectral name passed to ccd.detect: band ubid suffix}.
# 'oli' is selected when the request URL mentions OLI_TIRS; every other request
# (including the script's LANDSAT_7/ETM URL) uses 'tm'. 'qas' is the cfmask
# quality band in both groups.
ubid_band_dict = dict(
    tm=dict(red='band3', blue='band1', green='band2', nirs='band4',
            swirs1='band5', swirs2='band7', thermals='band6', qas='cfmask'),
    oli=dict(red='band4', blue='band2', green='band3', nirs='band5',
             swirs1='band6', swirs2='band7', thermals='band10', qas='cfmask'),
)


def sort_band_data(band, field):
    """Return a new list of the records in ``band`` ordered by ``record[field]``.

    The sort is stable (records with equal keys keep their input order) and
    the input iterable is not modified.
    """
    ordered = list(band)
    ordered.sort(key=lambda record: record[field])
    return ordered

def b64_to_bytearray(data):
    """Decode a base64 payload into a 1-D numpy array of int16 values.

    Despite the name, the result is a numpy ndarray (not a ``bytearray``).
    It reinterprets the decoded bytes as native-byte-order np.int16 and is
    read-only, because np.frombuffer shares the immutable decoded buffer.
    """
    raw = base64.b64decode(data)
    return np.frombuffer(raw, dtype=np.int16)

def dtstr_to_ordinal(dtstr):
    """Convert a timestamp like '2000-01-01T00:00:00Z' to a day ordinal.

    The string must match '%Y-%m-%dT%H:%M:%SZ' exactly (literal 'T' and
    trailing 'Z'); the time of day is discarded. Returns the proleptic
    Gregorian ordinal from ``datetime.toordinal()``. Raises ValueError if
    the string does not match the format.
    """
    return datetime.strptime(dtstr, '%Y-%m-%dT%H:%M:%SZ').toordinal()

def collect_data(band_group, json_data, cells=None):
    """Bucket tile records by band and split them into per-pixel time series.

    Only acquisitions whose 'source' appears in both the band2 and cfmask
    records are kept. Each band's records are then sorted by 'acquired', and
    their base64 'data' payloads are decoded to int16 arrays.

    Args:
        band_group: key into ``ubid_band_dict``. 'tm' expects band1-band7 plus
            cfmask; any other value expects band2-band7, band10 plus cfmask.
        json_data: iterable of record dicts with 'ubid', 'source',
            'acquired' ('%Y-%m-%dT%H:%M:%SZ') and base64 'data' keys.
        cells: number of pixels to emit. Defaults to the number of values
            per decoded record (10000 for a 100x100 tile, the value this
            function previously hard-coded).

    Returns:
        dict mapping pixel index -> {'dates': list of ordinal dates, plus
        'red', 'green', 'blue', 'nirs', 'swirs1', 'swirs2', 'thermals' and
        'qas', each a (rows, 1) array slice for that pixel}. Empty dict when
        no acquisition has both band2 and cfmask data.

    Raises:
        KeyError: if a record's ubid suffix is not one of the expected bands,
            or ``band_group`` is not a key of ``ubid_band_dict``.
    """
    spectra_names = ('red', 'green', 'blue', 'nirs', 'swirs1', 'swirs2',
                     'thermals', 'qas')

    # Compare with '==': 'is' on strings depends on interning and is unreliable.
    first_band = 'band1' if band_group == 'tm' else 'band10'
    band_list = [first_band, 'band2', 'band3', 'band4', 'band5', 'band6',
                 'band7', 'cfmask']

    # Bucket records per band in an explicit dict. Writing new names into
    # vars()/locals() inside a function is unsupported and breaks on newer
    # Pythons, where locals() returns a fresh snapshot on every call.
    buckets = {band: [] for band in band_list}
    ubids_seen = set()  # debug aid: remove once all bands are reporting
    for item in json_data:
        # '.../sr_band1' -> 'band1', '.../sr_band10' -> 'band10',
        # '.../cfmask' -> 'cfmask'
        which_band = item['ubid'][-6:].replace("_", "")
        ubids_seen.add(which_band)
        buckets[which_band].append(item)
    print("ubids returned: {}".format(ubids_seen))

    # Keep only acquisitions that also have cfmask (QA) data.
    valid_sources = ({i['source'] for i in buckets['band2']} &
                     {i['source'] for i in buckets['cfmask']})
    sorted_bands = {}
    for band in band_list:
        clean = [item for item in buckets[band]
                 if item['source'] in valid_sources]
        sorted_bands[band] = sort_band_data(clean, 'acquired')

    dates = [dtstr_to_ordinal(i['acquired']) for i in sorted_bands['band2']]
    rows = len(dates)
    if not rows:
        # Nothing to split; slicing the empty arrays below would raise.
        print("returning output from collect_output...")
        return {}

    # Stack each spectral band's decoded records into a (records, pixels) array.
    mapping = ubid_band_dict[band_group]
    arrays = {}
    for name in spectra_names:
        records = sorted_bands[mapping[name]]
        arrays[name] = np.array([b64_to_bytearray(item['data'])
                                 for item in records])

    if cells is None:
        cells = arrays['red'].shape[1]

    output = {}
    for pixel in range(cells):
        series = {'dates': dates}
        for name in spectra_names:
            # Keep 2-D (rows, 1) slices: run_pyccd reads row[0] from each row.
            series[name] = arrays[name][0:rows, pixel:pixel + 1]
        output[pixel] = series

    print("returning output from collect_output...")
    return output

def run_pyccd(data):
    """Run ccd.detect() on a single pixel's time series.

    Args:
        data: one per-pixel dict as built by collect_data(). It has 'dates'
            plus the keys 'red', 'green', 'blue', 'nirs', 'swirs1', 'swirs2',
            'thermals' and 'qas'. Each of those is presumably a 2-D (rows, 1)
            array; only element [0] of each row is used.

    Returns:
        Whatever ccd.detect() returns for this pixel.
    """
    def first_column(series):
        # Flatten a column of single-element rows into a plain list.
        return [row[0] for row in series]

    print('running ccd.detect()')
    dates = data['dates']
    spectra = [first_column(data[key])
               for key in ('red', 'green', 'blue', 'nirs', 'swirs1',
                           'swirs2', 'thermals', 'qas')]
    return ccd.detect(dates, *spectra)


# Tile endpoint to fetch. Pass a URL as the first command-line argument to
# override it; otherwise this Landsat 7 ETM tile query is used.
default_url = """http://lcmap-test.cr.usgs.gov/landsat/tiles?ubid=LANDSAT_7/ETM/sr_band1&ubid=LANDSAT_7/ETM/sr_band2&ubid=LANDSAT_7/ETM/sr_band4&ubid=LANDSAT_7/ETM/sr_band5&ubid=LANDSAT_7/ETM/sr_band7&ubid=LANDSAT_7/ETM/cfmask&ubid=LANDSAT_7/ETM/sr_band3&ubid=LANDSAT_7/ETM/toa_band6&x=-2013585&y=3095805&acquired=1982-01-01/2017-01-01"""
url = sys.argv[1] if len(sys.argv) > 1 else default_url
print("url is: {}".format(url))

tile_resp = requests.get(url)
if tile_resp.status_code == 200:
    print("got a 200!")
elif tile_resp.status_code == 500:
    # A 500 is retried once (presumably a transient server error).
    print("500.... trying again... ")
    tile_resp = requests.get(url)
    if tile_resp.status_code == 500:
        raise Exception("another 500 response!")
# Fail fast on any other 4xx/5xx instead of parsing an error body as tile JSON.
tile_resp.raise_for_status()

band_group = 'oli' if 'OLI_TIRS' in url else 'tm'

output = collect_data(band_group, tile_resp.json())

conf = (SparkConf().setAppName("beta-{}".format(datetime.now().strftime('%Y-%m-%d-%I:%M'))))
sc = SparkContext(conf=conf)
try:
    # Distribute (pixel, series) pairs, one partition per pixel. Passing the
    # dict itself would distribute only its keys (the pixel indices).
    pixels = list(output.items())
    rdd = sc.parallelize(pixels, max(1, len(pixels)))
    # RDD.map is lazy and returns a NEW RDD; collect that one so ccd.detect()
    # actually runs on the workers.
    results = rdd.map(lambda pixel_data: (pixel_data[0], run_pyccd(pixel_data[1])))
    x = results.collect()
    print("got it! collect len:{},\ncollect[0] {}".format(len(x), x[:100]))
finally:
    sc.stop()



