package GribView;

import java.awt.Color;
import java.io.ByteArrayInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.Math;
import java.rmi.RemoteException;
import javax.swing.JFrame;

import visad.ConstantMap;
import visad.DataReferenceImpl;
import visad.Display;
import visad.FlatField;
import visad.FunctionType;
import visad.Linear2DSet;
import visad.RealType;
import visad.RealTupleType;
import visad.ScalarMap;
import visad.VisADException;
import visad.java3d.DefaultRendererJ3D;
import visad.java3d.DisplayImplJ3D;
import visad.java3d.DisplayRendererJ3D;

import net.sourceforge.jgrib.BitInputStream;
import net.sourceforge.jgrib.GribFile;
import net.sourceforge.jgrib.GribRecord;
import net.sourceforge.jgrib.GribRecordBDS;
import net.sourceforge.jgrib.GribRecordBMS;
import net.sourceforge.jgrib.GribRecordGDS;
import net.sourceforge.jgrib.GribRecordPDS;
import net.sourceforge.jgrib.GribRecordIS;
import net.sourceforge.jgrib.GribRecordLight;
import net.sourceforge.jgrib.NoValidGribException;
import net.sourceforge.jgrib.NotSupportedException;

public class GribTest {

	private JFrame jframe;
	
	private Color colour = Color.white;
	private DisplayImplJ3D display;
	private DisplayRendererJ3D displayRenderer;

	private GribFile gribFile;
	
	private float minValue = 1.0e19f;  // Store minimuim value of Grib
data
	private float maxValue = -1.0e19f; // Store maximuim value of Grib
data

	public GribTest()
		throws VisADException, RemoteException
	{
		// Specify the name of the file containing the GRIB data
                String fileName = "test.grib";
		
		GribRecord gribRecord = null;
		// Create a GribFile with this filename
	        try {
			gribFile = new GribFile(fileName);
			// Read in the first record
			gribRecord = gribFile.getRecord(1);
		} catch (FileNotFoundException e) {
			System.out.println (e + " in file " + fileName);
			e.printStackTrace();
		} catch (IOException e) {
			System.out.println (e + " in file " + fileName);
			e.printStackTrace();
		} catch (NoValidGribException e) {
			System.out.println (e + " in file " + fileName);
			e.printStackTrace();
		} catch (NotSupportedException e) {
			System.out.println (e + " in file " + fileName);
			e.printStackTrace();
		}
		
		
		// Define the domain - (latitude, longitude)
		RealType latitude  = RealType.Latitude;
		RealType longitude = RealType.Longitude;
		RealTupleType domainTupleType =
			new RealTupleType(latitude, longitude);            

		// Get the geographical range of the grid
		double[] range = getExtremePoints(gribRecord.getGDS());
		int xVals = gribRecord.getGDS().getGridNX();
		int yVals = gribRecord.getGDS().getGridNY();
		
		// Create a linear 2D set with containing the latitude/longitude
		// coverage of the grid
		Linear2DSet domainSet = new Linear2DSet(
			domainTupleType, range[1], range[0],
			yVals,
			range[2], range[3], xVals);
		
		// Define the RealType of the range - RGB
		RealType rgbType     = RealType.getRealType("RGB");

		// Define the function type - (latitude, longitude) -> RGB
		FunctionType functionType =
			new FunctionType(domainTupleType, rgbType);

		// Define a flat field with this function and domain set
		FlatField gribField = new FlatField( functionType, domainSet); 

		// Get the actual data values from the grib record
		float [][] flatSamples = getSamples(gribRecord);
		// Set the flat field to contain these values
		gribField.setSamples(flatSamples ,false );
		
		// Create a data reference with this flat field
		DataReferenceImpl dataReference =
			new DataReferenceImpl("RGBDataRef");
		dataReference.setData(gribField);

		// Set up the display
		display = new DisplayImplJ3D("Grib View");
		displayRenderer =
			(DisplayRendererJ3D)display.getDisplayRenderer();
		displayRenderer.setBoxOn(false);

		// Make latitude map to Y axis, and longitude to X axis
		ScalarMap latMap     =
			new ScalarMap( latitude,        Display.YAxis );
		ScalarMap lonMap     =
			new ScalarMap( longitude,       Display.XAxis );
		// Set the range to cover the area of the GRIB field
                latMap.setRange(range[1], range[0]);
                lonMap.setRange(range[2], range[3]);
		
		// Set the RBG scalarType to map to Display.RGB
		ScalarMap rgbMap =
			new ScalarMap( rgbType,         Display.RGB );
		// Set the range to cover the range of data
		rgbMap.setRange(minValue, maxValue);
		
		// Add the scalarmaps to the display
		display.addMap( latMap );
		display.addMap( lonMap );
		display.addMap( rgbMap );
		
		// Configure the RGB constant map
		ConstantMap[] constantMap = {
			new ConstantMap(colour.getRed()/255.0f, Display.Red),
			new ConstantMap(colour.getGreen()/255.0f,Display.Green),
			new ConstantMap(colour.getBlue()/255.0f, Display.Blue),
			new ConstantMap(0.0, Display.ZAxis)
		};
		DefaultRendererJ3D dataRenderer = new DefaultRendererJ3D();
        
		// Add the data reference to the display
                display.addReferences(dataRenderer, dataReference,
constantMap);

		// Create a JFrame to contain the display
                jframe = new JFrame ("Grib Test");
                jframe.setSize(600, 500);
                jframe.getContentPane().add(display.getComponent());

                jframe.setVisible(true);
	}
	
	/**
	 * Find the extreme north, south, east and west points
	 */
	private double[] getExtremePoints (GribRecordGDS gds) {
		
		double[] range = new double[4];
		switch (gds.getGridScanmode()) {

			case 0:
				range[0] = gds.getGridLat1(); // North
				range[1] = gds.getGridLat2(); // South
				range[2] = gds.getGridLon1(); // West
				range[3] = gds.getGridLon2(); // East
				break;
			case 64:
				range[0] = gds.getGridLat2(); // North
				range[1] = gds.getGridLat1(); // South
				range[2] = gds.getGridLon1(); // West
				range[3] = gds.getGridLon2(); // East
				break;
		}
		if (range[2] > range[3]) range[3] += 360.0;
		
		return range;
	}
	
	/**
	 * Extract the data values from the binary data section of a GRIB
	 * record. Return them as a 2-D array of floats
	 */
	private float[][] getSamples (GribRecord gribRecord) {
		int xVals = gribRecord.getGDS().getGridNX();
		int yVals = gribRecord.getGDS().getGridNY();
		int scanMode = gribRecord.getGDS().getGridScanmode();
		float decimalScale = gribRecord.getPDS().getDecimalScale();
		float mdi = gribRecord.getBDS().UNDEFINED;
		float [] values = gribRecord.getBDS().getValues();

		float [][] flatSamples = new float[1][xVals*yVals];

		float scale = (float)Math.pow(10.0f, decimalScale);
		
		float sample = 0.0f;
		// Loop over each column of target 2d array
		for (int c = 0; c < xVals; c++) {
			// Loop over each row of target 2d array
			for (int r = 0; r < yVals; r++) {
				// The order of the data within the array is
				// indicated by the scanmode, 
				// must put a seperate entry here for each scan
				// mode so that each possible order is
				// catered for
				switch (scanMode) {
					case 0:
						sample =
							values[(yVals - r - 1)*xVals + c];
						break;
					case 64:
						sample =
							values[r*xVals + c];
						break;
					default:
						System.err.println("You must add support for scanmode: " +
scanMode);
						break;
				}
				// If the values equals the missing data
				// indicator. Set to NaN
				if (sample == mdi) sample = Float.NaN;
				
				// Put the sample into the array
				flatSamples[0][ c * yVals + r ] = sample;

				// Record the range of data
				if ((sample > maxValue) && (sample < mdi)) {
					maxValue = sample;
				}
				if (sample < minValue) minValue = sample;
                        
			}
		}
        
		return(flatSamples);
	}
	
	public static void main(String[] args) throws Exception {

		new GribTest();
	}
    
}
