SSH server library: a fork of the Hackage release, maintained in the hope of getting its patches merged upstream.
Clone
HTTPS:
darcs clone https://vervis.peers.community/repos/6r4Ao
SSH:
darcs clone USERNAME@vervis.peers.community:6r4Ao
Tags
TODO
test
/
test.hs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 | {-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Main where
import Test.Tasty
(TestTree, defaultMain, testGroup, withResource
)
import Test.Tasty.HUnit (testCase)
import Test.Tasty.QuickCheck (testProperty)
import Test.HUnit (assertBool)
import Test.QuickCheck
(Arbitrary(..), elements, forAll, choose, vectorOf
)
import Control.Applicative ((<$>))
import Control.Concurrent (forkIO, killThread)
import Control.Concurrent.MVar (newEmptyMVar, takeMVar, putMVar)
import Control.Exception (bracket, try, catchJust, ErrorCall(..), evaluate)
import Control.Monad (when)
import Data.ByteString.Char8 (pack)
import qualified Data.ByteString.Lazy as LBS
import Data.List (isSuffixOf)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Word (Word8)
import System.Directory (createDirectoryIfMissing, removeFile)
import System.FilePath ((<.>))
import System.IO (hPutStr, openTempFile, hClose)
import System.IO.Unsafe (unsafePerformIO)
import Network.SSH.Client.LibSSH2
import Network.SSH.Client.LibSSH2.Errors
import Network.SSH.Client.LibSSH2.Foreign
import qualified Network.SSH as SSH
import Network.SSH.Channel
import qualified Network.SSH.Internal.Crypto as Crypto
import Network.SSH.Internal.Crypto hiding (sign, verify)
import Network.SSH.Session
import EmbedTree
-- | Top-level entries of the @keys@ fixture directory, keyed by entry name.
--
-- The tree comes from a Template Haskell splice over @keys@, so it is
-- produced at compile time rather than read from disk when the tests run.
keysDirectory :: Map String Entry
keysDirectory = getDirectory embeddedKeys
  where
    embeddedKeys = $(embedTree "keys")
-- | TCP port shared by the test server and the test client.
--
-- Kept polymorphic because call sites need it both as an 'Int' and as a
-- 'PortNumber'.
sshPort :: Num a => a
sshPort = fromInteger 5032
-- | Run @test@ while an SSH server is up for it.
--
-- The server listens on 'sshPort' and uses @hostKp@ as its host key. It
-- accepts a single login: user @testuser@ presenting @acceptedKey@. The
-- only channel request it serves is @Execute "check"@, which it answers
-- with the message @checked@. Every other request gets a
-- "not supported" channel error.
--
-- 'withResource' starts the server once for the wrapped tree. The acquire
-- step forks it and waits for its ready action. The release step kills
-- the server thread.
withOneUserServer :: KeyPair -> PublicKey -> TestTree -> TestTree
withOneUserServer hostKp acceptedKey test =
  withResource startServer killThread (const test)
  where
    -- Fork the server, then block until its 'SSH.cReadyAction' fills the
    -- MVar, so no test starts before the server reports itself ready.
    startServer = do
      startedSignal <- newEmptyMVar
      tid <- forkIO $ SSH.startConfig (config startedSignal)
      takeMVar startedSignal
      return tid

    config startedSignal =
      SSH.Config
        { SSH.cSession = session
        , SSH.cChannel = channel
        , SSH.cPort = sshPort
        , SSH.cReadyAction = putMVar startedSignal ()
        }

    session =
      SessionConfig
        { scAuthMethods = ["publickey", "password"]
        , scAuthorize = sshAuthorize
        , scKeyPair = hostKp
        , scRunBaseMonad = id
        }

    channel =
      ChannelConfig
        { ccRequestHandler = channelRequest
        , ccRunBaseMonad = id
        }

    -- Only "testuser" with exactly the accepted public key gets in.
    sshAuthorize (PublicKey "testuser" k) =
      return $ if k == acceptedKey then AuthSuccess () else AuthFail
    sshAuthorize _ = return AuthFail

    -- @wr@ controls whether a success/failure reply is sent
    -- (presumably the SSH "want reply" flag; TODO confirm).
    channelRequest wr (Execute "check") = do
      channelMessage "checked"
      when wr channelSuccess
      channelDone
    channelRequest wr cmd = do
      channelError $ "<" ++ show cmd ++ "> not supported"
      when wr channelFail
-- | Write @contents@ to a fresh file in the @temp@ directory and run
-- @action@ with that file's path.
--
-- @nameTemplate@ is passed to 'openTempFile', which picks a unique name
-- based on it. The @temp@ directory is created if it is missing (its
-- parent must already exist). The file is removed when @action@ finishes,
-- whether it returns normally or throws.
--
-- Only opening the file happens in the 'bracket' acquire step. The write
-- happens inside the guarded body. This way a failure while writing still
-- closes the handle and deletes the file. Previously, the write ran during
-- acquisition, and a failure there leaked both.
withTextInTempFile :: String -> String -> (FilePath -> IO a) -> IO a
withTextInTempFile nameTemplate contents action = do
  let tempFolder = "temp"
  createDirectoryIfMissing False tempFolder
  bracket
    (openTempFile tempFolder nameTemplate)
    -- hClose is a no-op on an already-closed handle, so closing again here
    -- is safe on the success path and needed on the failure path.
    (\(path, h) -> hClose h >> removeFile path)
    (\(path, h) -> do
       hPutStr h contents
       -- Close before handing the path on, so the file is fully flushed
       -- and the consumer can open it independently.
       hClose h
       action path)
-- | Outcome of a client login attempt made by 'authWith'.
--
-- Referred to as @Main.AuthResult@ elsewhere because the server library
-- also exports an @AuthResult@ type.
data AuthResult
  = OK
    -- ^ authentication succeeded and the @check@ command round-tripped
  | Error ErrorCode
    -- ^ the libssh2 error raised by public-key authentication
  deriving (Show, Eq)
-- | Log in to the local test server as @testuser@ and run its @check@
-- command.
--
-- The public-key text and the printed private key pair are written to
-- temporary files, because libssh2's file-based public-key authentication
-- reads them from disk.
--
-- Returns 'Error' with the libssh2 error code if authentication is
-- rejected. Returns 'OK' once the server's reply to @check@ equals
-- @"checked\\r\\n"@. Calls 'fail' if the reply is anything else.
authWith :: String -> KeyPair -> IO Main.AuthResult
authWith publicKeyText privateKeyPair =
  withTextInTempFile "private" (printKeyPair privateKeyPair) $ \privFile ->
    withTextInTempFile "public" publicKeyText $ \pubFile ->
      withSession "localhost" sshPort $ \sess -> do
        outcome <- try (publicKeyAuthFile sess "testuser" pubFile privFile "")
        either (return . Error) (const (runCheck sess)) outcome
  where
    -- Open a session channel, run "check" and compare the reply; at most
    -- 20 bytes of the reply are read.
    runCheck sess = do
      chan <- openChannelSession sess
      channelExecute chan "check"
      reply <- readChannel chan 20
      when (reply /= pack "checked\r\n") $ fail "incorrect check result"
      return OK
-- | Corrupt a key pair's private components so that signing with it no
-- longer authenticates.
--
-- For RSA, each private number is shifted down by 2. This leaves enough
-- information (e.g. the primes) to reconstruct the real private key, but in
-- practice it is enough to make authentication fail. Changing the numbers
-- much more than this can cause segfaults or out-of-range signatures.
-- For DSA, the private exponent is replaced with 1.
breakPrivateKey :: KeyPair -> KeyPair
breakPrivateKey kp@RSAKeyPair {} =
  kp
    { rprivD = nudge (rprivD kp)
    , rprivPrime1 = nudge (rprivPrime1 kp)
    , rprivPrime2 = nudge (rprivPrime2 kp)
    , rprivExponent1 = nudge (rprivExponent1 kp)
    , rprivExponent2 = nudge (rprivExponent2 kp)
    , rprivCoefficient = nudge (rprivCoefficient kp)
    }
  where
    nudge x = x - 2
breakPrivateKey kp@DSAKeyPair {} = kp { dprivX = 1 }
-- | The public half stored inside an RSA or DSA key pair.
publicKey :: KeyPair -> PublicKey
publicKey kp =
  case kp of
    RSAKeyPair {} -> rprivPub kp
    DSAKeyPair {} -> dprivPub kp
-- | The server's host key pair, parsed from the @host@ entry of the
-- embedded keys directory.
hostKeyPair :: KeyPair
hostKeyPair = parseKeyPair $ getFile $ getEntry "host" keysDirectory
-- | Entries of the @client@ subdirectory of the embedded keys directory.
clientKeysDirectory :: Map String Entry
clientKeysDirectory = getDirectory (getEntry "client" keysDirectory)
-- | Raw text of the client public-key file @\<keyName\>.pub@.
getClientPublicKeyFileText :: String -> String
getClientPublicKeyFileText keyName = getFile publicEntry
  where
    publicEntry = getEntry (keyName <.> "pub") clientKeysDirectory
-- | Parse the client private key file named @keyName@ into a key pair.
getClientPrivateKeyPair :: String -> KeyPair
getClientPrivateKeyPair keyName =
  parseKeyPair (getFile (getEntry keyName clientKeysDirectory))
-- | Names of the client private-key fixtures: every client entry except
-- the @.pub@ public-key files.
--
-- Each name @k@ is later paired with the entry @k.pub@ (see
-- 'getClientPublicKeyFileText').
privateKeyPairFiles :: [String]
privateKeyPairFiles = filter (not . isPublicKeyFile) (Map.keys clientKeysDirectory)
  where
    -- Match the whole ".pub" extension produced by (<.> "pub"), not just
    -- a trailing "pub". A private key whose name happens to end in "pub"
    -- must not be silently dropped from the test matrix.
    isPublicKeyFile name = ".pub" `isSuffixOf` name
-- | For each client key fixture, start a server that accepts exactly that
-- key. Then check that logging in works with the real private key and
-- fails with 'PUBLICKEY_UNVERIFIED' when the private key is corrupted by
-- 'breakPrivateKey'.
singleKeyAuthTests :: TestTree
singleKeyAuthTests =
  testGroup "Single key auth tests" (map keyTests privateKeyPairFiles)
  where
    keyTests keyFile =
      withOneUserServer hostKeyPair (publicKey keyPair) $
        testGroup ("Check auth with " ++ keyFile)
          [ testCase "Works" $
              authWith pubText keyPair
                >>= assertBool "should auth with correct private key" . (== OK)
          , testCase "Fails with broken private key" $
              authWith pubText (breakPrivateKey keyPair)
                >>= assertBool "shouldn't auth with broken private key"
                      . (== Error PUBLICKEY_UNVERIFIED)
          ]
      where
        pubText = getClientPublicKeyFileText keyFile
        keyPair = getClientPrivateKeyPair keyFile
-- | The server accepts only @id_rsa_test@. Logging in with the unrelated
-- @id_rsa_test2@ pair must fail with 'AUTHENTICATION_FAILED'.
wrongKeyAuthTest :: TestTree
wrongKeyAuthTest =
  withOneUserServer hostKeyPair acceptedPub $
    testCase "Check auth failure with wrong key" $ do
      result <- authWith otherPubText otherKeyPair
      assertBool "shouldn't auth with wrong private key"
        (result == Error AUTHENTICATION_FAILED)
  where
    acceptedPub = publicKey (getClientPrivateKeyPair "id_rsa_test")
    otherKeyPair = getClientPrivateKeyPair "id_rsa_test2"
    otherPubText = getClientPublicKeyFileText "id_rsa_test2"
-- | Arbitrary lazy byte strings, generated as arbitrary byte lists.
--
-- Shrinking goes through the byte list, so a failing signature property
-- reports a smaller message instead of the original random one.
-- This is an orphan instance; orphan warnings are disabled for this file.
instance Arbitrary LBS.ByteString where
  arbitrary = LBS.pack <$> arbitrary
  shrink = map LBS.pack . shrink . LBS.unpack

-- | One of the client key-pair fixtures (see 'privateKeyPairFiles').
instance Arbitrary KeyPair where
  arbitrary = elements $ map getClientPrivateKeyPair privateKeyPairFiles

-- | The public half of an arbitrary fixture key pair.
instance Arbitrary PublicKey where
  arbitrary = publicKey <$> arbitrary
-- | Pure wrapper around 'Crypto.sign' for use in QuickCheck properties.
--
-- 'Crypto.sign' is in IO because the DSA operations are in IO, so this
-- goes through 'unsafePerformIO'. The hope is that signing has only benign
-- side effects, if any.
sign :: KeyPair -> LBS.ByteString -> LBS.ByteString
sign kp message = unsafePerformIO (Crypto.sign kp message)
-- | Pure wrapper around 'Crypto.verify' for use in QuickCheck properties,
-- run via 'unsafePerformIO' because 'Crypto.verify' is in IO.
--
-- If verification fails with the 'ErrorCall'
-- "signature representative out of range", the signature is treated as
-- invalid ('False'). Any other exception propagates.
verify :: PublicKey -> LBS.ByteString -> LBS.ByteString -> Bool
verify key message sig =
  unsafePerformIO $ catchJust outOfRange checkSig (const (return False))
  where
    -- 'evaluate' forces the Bool inside the handler's scope, so an error
    -- raised lazily during verification is caught here instead of escaping
    -- later.
    checkSig = Crypto.verify key message sig >>= evaluate
    outOfRange (ErrorCall msg)
      | msg == "signature representative out of range" = Just ()
      | otherwise = Nothing
-- | Property: a signature produced by 'sign' verifies under the matching
-- public key.
signThenVerifyTest :: TestTree
signThenVerifyTest =
  testProperty "signatures from sign work with verify" roundTrips
  where
    roundTrips kp message = verify (publicKey kp) message (sign kp message)
-- | Property: changing any single byte within the trailing
-- 'actualSignatureLength' bytes of a valid signature makes it fail
-- verification.
--
-- The byte is changed by adding a non-zero 'Word8' (wrapping), so the
-- mutated byte always differs from the original.
signThenMutatedVerifyTest :: TestTree
signThenMutatedVerifyTest =
  testProperty "mutated signatures from sign fail with verify" prop
  where
    prop kp message =
      forAll (choose (sigLen - payloadLen, sigLen - 1)) $ \offset ->
        forAll (choose (1, 255 :: Word8)) $ \delta ->
          not $ verify pub message (corrupt offset delta)
      where
        pub = publicKey kp
        sig = sign kp message
        sigLen = LBS.length sig
        payloadLen = fromIntegral (actualSignatureLength pub)
        -- Rebuild the signature with only the byte at @offset@ changed.
        corrupt offset delta =
          LBS.concat
            [ LBS.take offset sig
            , LBS.singleton (LBS.index sig offset + delta)
            , LBS.drop (offset + 1) sig
            ]
-- | Property: random bytes of the expected signature length do not verify
-- as a signature of an arbitrary message.
--
-- Other lengths might be worth testing too, but the verification code only
-- looks at the last n bytes anyway. It is also not obvious what range of
-- lengths would be a good choice.
randomVerifyTest :: TestTree
randomVerifyTest =
  testProperty "random signatures fail with verify" $ \key message ->
    forAll (randomSigBytes key) (not . verify key message . LBS.pack)
  where
    randomSigBytes key = vectorOf (actualSignatureLength key) arbitrary
-- | The whole suite. It has two groups: tests that need a running SSH
-- server, and pure signature properties.
allTests :: TestTree
allTests = testGroup "Tests" [serverTests, signatureTests]
  where
    serverTests =
      testGroup "With server" [singleKeyAuthTests, wrongKeyAuthTest]
    signatureTests =
      testGroup "Signatures"
        [signThenVerifyTest, signThenMutatedVerifyTest, randomVerifyTest]

-- | Run 'allTests' with tasty's standard command-line driver.
main :: IO ()
main = defaultMain allTests
|