package weka.clusterers.forMetisMQI;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.Stack;

import org.apache.commons.collections15.Factory;
import org.apache.commons.collections15.Transformer;

import weka.clusterers.forMetisMQI.graph.Bisection;
import weka.clusterers.forMetisMQI.graph.Edge;
import weka.clusterers.forMetisMQI.graph.Node;
import weka.clusterers.forMetisMQI.graph.Subgraph;
import edu.uci.ics.jung.algorithms.flows.EdmondsKarpMaxFlow;
import edu.uci.ics.jung.graph.DirectedGraph;
import edu.uci.ics.jung.graph.DirectedSparseGraph;
import weka.clusterers.forMetisMQI.util.Util;

/**
 * Max-Flow Quotient-cut Improvement (MQI) refinement of a graph bisection.
 * <p>
 * Starting from the smaller side of a {@link Bisection}, {@link #mqi} repeatedly
 * builds a flow network over the current cluster, runs Edmonds-Karp max-flow on
 * it and, while the max flow stays below the threshold a * c, replaces the
 * cluster with the nodes on the sink side of the resulting cut.
 */
public class MQI {

	/**
	 * Formerly the shared counter behind the ids of the auxiliary edges created
	 * for the max-flow algorithm. Each edge factory now keeps its own counter,
	 * so separate runs cannot interfere with each other's ids. This field is
	 * retained only for source compatibility and is not read or written by
	 * this class.
	 */
	static int i = -1;

	/**
	 * Collects {@code currentNode} plus every node that can reach it by walking
	 * edges of {@code g} backwards (from destination to source), following only
	 * edges whose flow in {@code edgeFlowMap} is strictly below their capacity.
	 * <p>
	 * The traversal uses an explicit stack, so long residual paths cannot
	 * overflow the call stack.
	 *
	 * @param currentNode the node the backward search starts from
	 * @param g           the flow network
	 * @param edgeFlowMap flow assigned to each edge of {@code g}; every edge
	 *                    examined must have an entry
	 * @param marked      already-discovered nodes: nodes in it are not expanded,
	 *                    and each newly discovered node is added to it
	 * @return {@code currentNode} together with all newly discovered nodes
	 */
	static private Set<Node> DFSReversed(Node currentNode,
			DirectedGraph<Node, Edge> g, Map<Edge, Number> edgeFlowMap,
			Set<Node> marked) {
		Set<Node> result = new HashSet<Node>();
		result.add(currentNode);
		Stack<Node> pending = new Stack<Node>();
		pending.push(currentNode);
		while (!pending.empty()) {
			Node node = pending.pop();
			for (Edge edge : g.getInEdges(node)) {
				Node src = g.getSource(edge);
				if (!marked.contains(src)
						&& hasResidualCapacity(edge, edgeFlowMap)) {
					marked.add(src);
					result.add(src);
					pending.push(src);
				}
			}
		}
		return result;
	}

	/**
	 * Returns whether {@code edge} can still carry more flow, i.e. whether the
	 * flow recorded for it in {@code edgeFlowMap} is strictly below its
	 * capacity. The flow is read via {@link Number#intValue()}, so any
	 * {@link Number} subtype stored by the max-flow algorithm is accepted.
	 *
	 * @throws NullPointerException if {@code edgeFlowMap} has no entry for
	 *                              {@code edge}
	 */
	static private boolean hasResidualCapacity(Edge edge,
			Map<Edge, Number> edgeFlowMap) {
		return edgeFlowMap.get(edge).intValue() < edge.getCapacity();
	}

	/**
	 * Returns the scale factor {@code a} of the MQI flow network: the total
	 * degree of the smaller side when optimizing conductance, otherwise its
	 * vertex count. Both the network capacities and the improvement threshold
	 * are derived from this value, which keeps them consistent.
	 */
	static private int smallerSideMeasure(Bisection bisection,
			boolean forConductance) {
		Subgraph smaller = bisection.getSmallerSubgraph();
		if (forConductance)
			return smaller.totalDegree();
		return smaller.getVertexCount();
	}

	/**
	 * Returns {@code c}, the cut value used by the MQI network:
	 * {@code bisection.edgeCut() / 2} (integer division).
	 * NOTE(review): the halving presumably compensates for edgeCut() counting
	 * each cut edge once from each side — confirm against Bisection.
	 */
	static private int cutWeight(Bisection bisection) {
		return bisection.edgeCut() / 2;
	}

	/**
	 * Builds the MQI flow network for the smaller side A of {@code bisection},
	 * with {@code a = smallerSideMeasure(bisection, forConductance)} and
	 * {@code c = cutWeight(bisection)}:
	 * <ul>
	 * <li>for every u in A and every v in {@code A.getNeighbors(u)}, an edge
	 * u -&gt; v with weight {@code A.getWeight(u, v)} and capacity a;</li>
	 * <li>an edge {@code source} -&gt; v for every v in A adjacent (in the full
	 * graph) to some node of the larger side B, with capacity a times the
	 * number of such B nodes;</li>
	 * <li>an edge u -&gt; {@code sink} for every u in A, with capacity c, or
	 * c * degree(u) when {@code forConductance} is set.</li>
	 * </ul>
	 *
	 * @param bisection      the current bisection
	 * @param source         fresh node used as the flow source
	 * @param sink           fresh node used as the flow sink
	 * @param forConductance selects degree-based (conductance) rather than
	 *                       vertex-count-based (quotient cut) capacities
	 * @return the newly built directed flow network
	 */
	static private DirectedGraph<Node, Edge> prepareDirectedGraph(
			Bisection bisection, Node source, Node sink, boolean forConductance) {
		Subgraph B = bisection.getLargerSubgraph();
		Subgraph A = bisection.getSmallerSubgraph();
		int a = smallerSideMeasure(bisection, forConductance);
		int c = cutWeight(bisection);

		DirectedGraph<Node, Edge> g = new DirectedSparseGraph<Node, Edge>();
		Iterator<Node> nodes = A.iterator();
		while (nodes.hasNext()) {
			g.addVertex(nodes.next());
		}

		// Internal edges of A, one directed edge per (u, neighbour) pair.
		int id = 0;
		nodes = A.iterator();
		while (nodes.hasNext()) {
			Node u = nodes.next();
			Iterator<Node> neighbors = A.getNeighbors(u).iterator();
			while (neighbors.hasNext()) {
				Node v = neighbors.next();
				g.addEdge(new Edge(Integer.toString(id), A.getWeight(u, v), a),
						u, v);
				id++;
			}
		}

		g.addVertex(source);
		g.addVertex(sink);

		// Source edges: every node of A that was connected to a node of B gets
		// capacity a for each such connection.
		nodes = B.iterator();
		while (nodes.hasNext()) {
			Node u = nodes.next();
			Iterator<Node> neighbors = B.getGraph().getNeighbors(u).iterator();
			while (neighbors.hasNext()) {
				Node v = neighbors.next();
				if (A.contains(v)) {
					Edge e = g.findEdge(source, v);
					if (e != null) {
						e.setCapacity(e.getCapacity() + a);
					} else {
						g.addEdge(new Edge(Integer.toString(id), 1, a), source,
								v);
						id++;
					}
				}
			}
		}

		// Sink edges: one per node of A.
		nodes = A.iterator();
		while (nodes.hasNext()) {
			Node u = nodes.next();
			if (forConductance)
				g.addEdge(new Edge(Integer.toString(id), 1,
						c * bisection.getGraph().degree(u)), u, sink);
			else
				g.addEdge(new Edge(Integer.toString(id), 1, c), u, sink);
			id++;
		}
		return g;
	}

	/**
	 * Refines a bisection with the Max-Flow Quotient-cut Improvement (MQI)
	 * algorithm and returns the improved cluster.
	 * <p>
	 * The vertices of the smaller side of {@code partition} form the starting
	 * cluster. Each round builds the flow network of
	 * {@link #prepareDirectedGraph} for the current bisection and computes a
	 * max flow with Edmonds-Karp. The threshold is a * c, using the same a and
	 * c as the network. If the max flow is below that threshold, the nodes
	 * that can still reach the sink through unsaturated edges become the new
	 * cluster, and the bisection is rebuilt around it. Otherwise the loop stops.
	 * Each round prints the max flow and the threshold to standard output.
	 *
	 * @param partition      the initial bisection; its smaller side is refined
	 * @param forConductance if {@code true}, capacities and threshold are scaled
	 *                       by node degrees (conductance); otherwise by vertex
	 *                       counts (quotient cut)
	 * @return the nodes of the last cluster accepted; this is the starting
	 *         cluster if no improvement was found
	 */
	static public Set<Node> mqi(Bisection partition, boolean forConductance) {
		Bisection bisection = partition;
		Set<Node> cluster = new HashSet<Node>(partition.getSmallerSubgraph()
				.createInducedSubgraph().getVertices());
		while (true) {
			Node source = new Node("$$$$S");
			Node sink = new Node("$$$$T");
			// Build the network for the requested objective, so that it matches
			// the threshold computed below.
			DirectedGraph<Node, Edge> directedGraph = prepareDirectedGraph(
					bisection, source, sink, forConductance);
			Transformer<Edge, Number> capTransformer = new Transformer<Edge, Number>() {
				public Double transform(Edge e) {
					return (double) e.getCapacity();
				}
			};
			Map<Edge, Number> edgeFlowMap = new HashMap<Edge, Number>();
			// Produces the auxiliary edges needed by the algorithm; ids are
			// local to this factory ("$$$$0", "$$$$1", ...).
			Factory<Edge> edgeFactory = new Factory<Edge>() {
				private int nextId = 0;

				public Edge create() {
					return new Edge("$$$$" + Integer.toString(nextId++), 1, 1);
				}
			};
			EdmondsKarpMaxFlow<Node, Edge> alg = new EdmondsKarpMaxFlow<Node, Edge>(
					directedGraph, source, sink, capTransformer, edgeFlowMap,
					edgeFactory);

			// Computed in long so that a * c cannot overflow.
			long maxFlowThreshold = (long) smallerSideMeasure(bisection,
					forConductance) * cutWeight(bisection);
			alg.evaluate();
			System.out.println("MAX FLOW: " + alg.getMaxFlow() + " THRESHOLD: "
					+ maxFlowThreshold);
			if (alg.getMaxFlow() >= maxFlowThreshold) {
				return cluster;
			}
			Set<Node> dfsResult = DFSReversed(sink, directedGraph,
					edgeFlowMap, new HashSet<Node>());
			dfsResult.remove(sink);
			cluster = dfsResult;
			bisection = new Bisection(new Subgraph(bisection.getGraph(),
					cluster));
		}
	}

}
