{-# LANGUAGE BangPatterns #-}

import           GHC.Conc (getNumProcessors)
import           Control.Concurrent.MVar
import           Control.Concurrent
import           Control.Applicative
import           Control.Monad
import           Control.Exception (evaluate)
import           Data.Function
import           Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as M
import           System.Environment

-- | Primality test by trial division with odd divisors.
--
-- Total over 'Int': every value below 2 (including negatives) is not prime,
-- and 2 is prime. Odd candidates are checked against odd divisors @x@ with
-- @x * x <= n@; the bound is written as @x <= n `quot` x@ so it neither
-- needs a floating-point square root nor overflows near 'maxBound'.
prime :: Int -> Bool
prime n
    | n < 2     = False
    | n == 2    = True
    | even n    = False
    | otherwise = all (\x -> n `rem` x /= 0) divisors
  where
    divisors = takeWhile (\x -> x <= n `quot` x) [3, 5 ..]

-- | State shared between the worker threads and the consumer threads.
data BatchPool = BatchPool
    { pool_results :: MVar (IntMap Result)
      -- ^ Finished batches, keyed by their batch id.
    , pool_binfo   :: MVar (Int, BatchInfo)
      -- ^ Id and description of the next batch a worker will claim.
    }

-- | The values produced for one batch (the output of 'filterG' on that
--   batch's generator), as stored in the pool's result map.
newtype Result = Result [Int]
    
-- | Description of one unit of work handed to a worker.
data BatchInfo = BatchInfo
    { batch_gen    :: !Gen
      -- ^ The candidates to test in this batch.
    , batch_offset :: !Int
      -- ^ Amount by which 'nextBatch' raises the upper bound for the
      --   following batch.
    }

-- | An arithmetic progression of candidates:
--   @gen_current, gen_current + gen_skip, ...@, stopping before
--   @gen_maxNumber@ (see 'filterG').
data Gen = Gen
    { gen_current   :: !Int
      -- ^ First candidate.
    , gen_skip      :: !Int
      -- ^ Step between consecutive candidates.
    , gen_maxNumber :: !Int
      -- ^ Exclusive upper bound.
    }

-- | The first batch: odd candidates starting at 3, below 3 + 50000.
--   Each later batch moves the bound up by the same 50000 (see 'nextBatch').
config :: BatchInfo
config = BatchInfo
    { batch_offset = batchSize
    , batch_gen    = Gen
        { gen_current   = firstCandidate
        , gen_skip      = 2
        , gen_maxNumber = firstCandidate + batchSize   -- 50003
        }
    }
  where
    -- Starting at an odd number with step 2 visits only odd candidates.
    firstCandidate = 3
    -- Kept even so each batch boundary stays on the odd progression.
    batchSize      = 50000

-- | The batch that follows the given one.
--
-- The upper bound rises by 'batch_offset'. The new batch starts at the first
-- member of the same progression (@current, current + skip, ...@) that the
-- previous batch did not visit, i.e. the smallest such value @>= maxNum@.
-- Restarting at @maxNum@ itself would leave the progression whenever
-- @maxNum - current@ is not a multiple of @skip@ (e.g. an odd-only generator
-- drifting onto even numbers). When it is a multiple, as with 'config', the
-- two choices coincide.
nextBatch :: BatchInfo -> BatchInfo
nextBatch (BatchInfo (Gen current skip maxNum) offset) =
    BatchInfo (Gen start skip (maxNum + offset)) offset
  where
    start
        -- Degenerate generators (no forward progress or an empty batch):
        -- keep the plain restart at the old bound.
        | skip <= 0 || maxNum <= current = maxNum
        -- current + skip * ceiling ((maxNum - current) / skip)
        | otherwise = current + skip * ((maxNum - current + skip - 1) `div` skip)

-- | The members of the generator's progression that satisfy the predicate,
--   in progression order.
--
--   The result is spine-strict. Forcing it to weak head normal form runs the
--   predicate on every candidate, so @evaluate@ on the result performs all
--   of the work. The loop is tail-recursive: hits are accumulated in reverse
--   and flipped once at the end.
filterG :: Gen -> (Int -> Bool) -> [Int]
filterG (Gen start step limit) keep = collect [] start
  where
    collect acc k
        | k >= limit = reverse acc
        | keep k     = collect (k : acc) (k + step)
        | otherwise  = collect acc (k + step)

-- | Fork @n@ worker threads (none when @n <= 0@).
--
-- Each worker loops forever:
--
--   1. claims the next batch id and description from 'pool_binfo' and
--      advances that cursor with 'nextBatch';
--   2. tests every candidate of the batch with 'prime';
--   3. publishes the fully evaluated list into 'pool_results' under the id.
workers :: Int -> BatchPool -> IO ()
workers n (BatchPool resultsM binfoM) = replicateM_ n (forkIO worker)
  where
    worker = forever $ do
        -- Claim a batch and advance the shared cursor in one step.
        -- modifyMVar puts the old value back if an exception interrupts us,
        -- so the cursor MVar can never be left empty. The new cursor is
        -- forced so thunks do not pile up inside the MVar.
        (xid, binfo) <- modifyMVar binfoM $ \cur@(i, b) ->
            let i' = i + 1
                b' = nextBatch b
            in  i' `seq` b' `seq` return ((i', b'), cur)
        -- Force the whole spine here, so the primality tests run on this
        -- worker rather than in whichever consumer later reads the list.
        let values = filterG (batch_gen binfo) prime
        _ <- evaluate (length values)
        -- Publish the batch. The map is forced ($!) so no chain of pending
        -- inserts accumulates between consumer polls.
        modifyMVar_ resultsM $ \results ->
            return $! M.insert xid (Result values) results
     
    
-- | A fresh pool: no finished batches yet, and the batch cursor at id 0
--   pointing at 'config'.
def :: IO BatchPool
def = BatchPool
    <$> newMVar M.empty
    <*> newMVar (0, config)
            
-- | Remove batch @n@ from the pool and return its values.
--   Polls every 50 ms until a worker has published that batch.
takeBatchN :: Int -> BatchPool -> IO [Int]
takeBatchN n (BatchPool resultsM _) = poll
  where
    poll = do
        -- modifyMVar restores the map if an asynchronous exception arrives
        -- mid-update, so a killed consumer cannot leave the pool MVar empty
        -- and stall every worker.
        found <- modifyMVar resultsM $ \mp ->
            case M.lookup n mp of
                Nothing         -> return (mp, Nothing)
                Just (Result v) ->
                    let rest = M.delete n mp
                    in  rest `seq` return (rest, Just v)
        case found of
            Just v  -> return v
            Nothing -> threadDelay 50000 >> poll


-- | Empty the pool and return the batch with the largest id.
--   Polls every 50 ms while the pool is empty. Every other batch that was
--   in the pool at that moment is discarded.
takeBiggestBatch :: BatchPool -> IO [Int]
takeBiggestBatch (BatchPool resultsM _) = poll
  where
    poll = do
        -- swapMVar does the take/put pair with asynchronous exceptions
        -- masked, so an interrupt cannot leave the pool MVar empty.
        drained <- swapMVar resultsM M.empty
        case M.maxView drained of
            Just (Result v, _) -> return v
            Nothing            -> threadDelay 50000 >> poll
             

-- | Benchmark consumer.
--
-- Repeatedly takes the newest batch until @die@ has been filled. It then
-- prints the last value of the final batch (if that batch is non-empty) and
-- fills @over@.
loop :: MVar () -> MVar () -> BatchPool -> IO ()
loop die over bp = go
  where
    go = do
        batch <- takeBiggestBatch bp
        stop  <- tryTakeMVar die
        case stop of
            Nothing -> go
            Just () -> do
                -- Print *before* signalling. Once @over@ is filled, the
                -- waiting thread may return and the program may exit,
                -- killing this thread before the result is written.
                case batch of
                    [] -> return ()   -- 'last' would crash and @over@ would never be filled
                    _  -> print (last batch)
                putMVar over ()

-- | Print the values of batch 0, 1, 2, ... in id order, forever.
--   Each step waits (via 'takeBatchN') until that batch has been published.
--
-- NOTE(review): with the visible 'config' the candidates start at 3, so 2
-- is never printed.
incrementalloop :: BatchPool -> IO ()
incrementalloop bp =
    forM_ [0 :: Int ..] $ \batchId ->
        takeBatchN batchId bp >>= mapM_ print

-- | Run the benchmark consumer ('loop') for about ten seconds.
--   It then fills @die@ to stop the consumer and blocks until the consumer
--   fills @over@.
benchmark :: BatchPool -> IO ()
benchmark bpool = do
    die  <- newEmptyMVar
    over <- newEmptyMVar
    _ <- forkIO (loop die over bpool)
    -- Ten one-second sleeps.
    replicateM_ 10 (threadDelay (1000 * 1000))
    putMVar die ()
    takeMVar over

-- | Command-line entry point. Accepts exactly one argument: @help@,
--   @benchmark@ or @primes@.
main :: IO ()
main = do
    opt <- getArgs
    case opt of
        ["help"] -> mapM_ putStrLn
            [ "available commands:"
            , "  benchmark"
            , "  help"
            , "  primes"
            ]
        ["benchmark"] -> startPool >>= benchmark
        ["primes"]    -> startPool >>= incrementalloop
        _             -> putStrLn "invalid parameter"
  where
    -- Create the pool and fork the workers. This runs only for commands
    -- that consume results, so @help@ and bad arguments start no threads.
    -- Uses one worker per capability minus one (at least one), presumably
    -- to leave a capability for the consumer thread.
    startPool = do
        bpool <- def
        n <- getNumProcessors
        setNumCapabilities n
        workers (max 1 (n - 1)) bpool
        return bpool