I don't usually write up my technical work here, mostly because I spend enough hours as is doing technical writing. But a co-author, Jon Barker, recently wrote a post on the NVIDIA Parallel For All blog about one of our papers on neural networks for detecting malware, so I thought I'd link to it here. (You can read the paper itself, "Malware Detection by Eating a Whole EXE" here.) Plus it was on the front page of Hacker News earlier this week, which is not something I thought would ever happen to my work.
Rather than rehashing everything in Jon's Parallel for All post about our work, I want to highlight some of the lessons we learned from doing this about ML/neural nets/deep learning.
By way of background, I'll lift a few paragraphs from Jon's introduction:
The paper introduces an artificial neural network trained to differentiate between benign and malicious Windows executable files with only the raw byte sequence of the executable as input. This approach has several practical advantages:
- No hand-crafted features or knowledge of the compiler used are required. This means the trained model is generalizable and robust to natural variations in malware.
- The computational complexity is linearly dependent on the sequence length (binary size), which means inference is fast and scalable to very large files.
- Important sub-regions of the binary can be identified for forensic analysis.
- This approach is also adaptable to new file formats, compilers and instruction set architectures—all we need is training data.
We also hope this paper demonstrates that malware detection from raw byte sequences has unique and challenging properties that make it a fruitful research area for the larger machine learning community.
One of the big issues we were confronting with our approach, MalConv, is that executables are often millions of bytes in length. That's orders of magnitude more time steps than most sequence processing networks deal with. Big data usually refers to lots and lots of small data points, but for us each individual sample was big. Saying this was a non-trivial problem is a serious understatement.
Here are three lessons we learned, not about malware or cybersecurity, but about the process of building neural networks on such unusual data.
1. Deep learning != image processing
The large majority of the work in deep learning has been done in the image domain. Of the remainder, the large majority has been in either text or speech. Many of the lessons, best practices, rules of thumb, etc., that we think apply to deep learning may actually be specific to these domains.
For instance, the community has settled on narrow convolutional filters stacked with a lot of depth as generally the best way to go. And for images, narrow-and-deep absolutely seems to be the correct choice. But in order to get a network that processes two million time steps to fit in memory at all (on beefy 16GB cards, no less), we were forced to go wide-and-shallow.
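To make the wide-and-shallow idea concrete, here's a minimal sketch in PyTorch (illustrative sizes and layer choices, not our exact code or the paper's hyperparameters): a single very wide, strided conv over an embedded byte sequence, followed by a global max pool, so the activation map stays small even over millions of time steps.

```python
import torch
import torch.nn as nn

class WideShallowByteNet(nn.Module):
    """One wide, strided conv over raw bytes -- illustrative sizes only."""
    def __init__(self, n_classes=2, embed_dim=8, n_filters=128, kernel=512, stride=512):
        super().__init__()
        # 256 byte values plus a padding index for variable-length files
        self.embed = nn.Embedding(257, embed_dim, padding_idx=256)
        # A single very wide conv instead of a deep stack: with stride == kernel,
        # a 2M-byte file produces only ~4k activations per filter
        self.conv = nn.Conv1d(embed_dim, n_filters, kernel_size=kernel, stride=stride)
        self.fc = nn.Linear(n_filters, n_classes)

    def forward(self, x):                 # x: (batch, seq_len) of byte values
        z = self.embed(x)                 # (batch, seq_len, embed_dim)
        z = z.transpose(1, 2)             # (batch, embed_dim, seq_len) for Conv1d
        z = torch.relu(self.conv(z))
        z = torch.max(z, dim=2).values    # global max pool over the whole file
        return self.fc(z)

# Example: a batch of two "files", two million bytes each
x = torch.randint(0, 256, (2, 2_000_000))
logits = WideShallowByteNet()(x)          # shape (2, 2)
```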
With images, a pixel value is always a pixel value: 0x20 in a grayscale image is always darkish gray, no matter what. In an executable, a byte value is ridiculously polysemous: 0x20 may be part of an instruction, a string, a bit array, a compressed or encrypted value, an address, etc. You can't interpolate between values at all, so you can't resize or crop the way you would with images to make your data set smaller or to do data augmentation. Binaries also play havoc with locality, since you can re-arrange functions in any order, among other things. You can't rely on any Tobler's Law ((Everything is related, but near things are more related than far things.)) relationship the way you can in images, text, or speech.
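A toy illustration of the interpolation point (nothing from the paper, just x86 trivia): averaging two pixel values gives you a sensible in-between shade, while averaging two opcode bytes gives you a completely unrelated instruction.

```python
# Averaging two grayscale pixels gives a plausible in-between gray...
p1, p2 = 0x40, 0x80
mid_pixel = (p1 + p2) // 2           # 0x60 -- a shade between the two inputs

# ...but averaging two x86 opcode bytes gives an unrelated instruction.
push_eax, nop = 0x50, 0x90           # PUSH EAX and NOP
mid_byte = (push_eax + nop) // 2     # 0x70 -- JO, a conditional jump, nothing like either input
print(hex(mid_pixel), hex(mid_byte))
```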
2. BatchNorm isn't pixie dust
Batch Normalization has this bippity-boppity-boo magic quality. Just sprinkle it on top of your network architecture, and things that didn't converge before now do, and things that did converge now converge faster. It's worked like that every time I've tried it — on images. When we tried it on binaries it actually had the opposite effect: networks that converged slowly now didn't at all, no matter what variety of architecture we tried. It's also had no effect at all on some other esoteric data sets that I've worked on.
We discuss this at more length in the paper (§5.3), but here's the relevant figure:
This shows the pre-BN activations from MalConv (blue) and from ResNet (red and orange) and Inception-v4 (green). The purpose of BatchNorm is to output values that follow a standard normal distribution, and it implicitly expects inputs that are already relatively close to one. What we suspect is happening is that the input values from other networks aren't Gaussian, but they're close-ish. ((I'd love to be able to quantify that closeness, but every test for normality I'm aware of doesn't apply when you have this many samples. If anyone knows of a more robust test please let me know.)) The input values for MalConv display huge asperity, and aren't even unimodal. If BatchNorm is being wonky for you, I'd suggest plotting the pre-BN activations and checking that they're relatively smooth and unimodal.
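If you want to run that check yourself and you're in PyTorch, one way (a sketch, assuming your model uses standard nn.BatchNorm layers) is to hook the BN layers and histogram whatever feeds into them:

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

def plot_pre_bn_activations(model, batch):
    """Histogram the inputs feeding each BatchNorm layer for one forward pass."""
    captured = {}

    def make_hook(name):
        def hook(module, inputs, output):
            captured[name] = inputs[0].detach().flatten().cpu()
        return hook

    handles = [m.register_forward_hook(make_hook(n))
               for n, m in model.named_modules()
               if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d))]
    with torch.no_grad():
        model(batch)
    for h in handles:
        h.remove()

    for name, acts in captured.items():
        plt.hist(acts.numpy(), bins=200, histtype='step', label=name)
    plt.legend()
    plt.xlabel('pre-BN activation value')
    plt.show()
```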
3. The Lump of Regularization Fallacy
If you're overfitting, you probably need more regularization. Simple advice, and easily executed. Every time I see this brought up, though, people treat regularization as if it's this monolithic thing. Implicitly, people are talking as if you have some pile of regularization, and if you need to fight overfitting then you just shovel more regularization on top. It doesn't matter what kind, just add more.
We ran into overfitting problems and tried every method we could think of: weight decay, dropout, regional dropout, gradient noise, activation noise, and on and on. The only one that had any impact was DeCov, which penalizes activations in the penultimate layer that are highly correlated with each other. I have no idea what will work on your data — especially if it's not images/speech/text — so try different types. Don't just treat regularization as a single knob that you crank up or down.
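For reference, the DeCov penalty (Cogswell et al., 2016) punishes off-diagonal covariance between features in a batch of activations. A rough sketch of it in PyTorch (my paraphrase of the published loss, not our training code) looks like:

```python
import torch

def decov_penalty(h):
    """DeCov-style penalty on penultimate-layer activations h: (batch, features).

    Penalizes off-diagonal covariance so features decorrelate, roughly
    0.5 * (||C||_F^2 - ||diag(C)||^2) from Cogswell et al. (2016).
    """
    centered = h - h.mean(dim=0, keepdim=True)
    cov = centered.t() @ centered / h.shape[0]   # (features, features) covariance
    frob_sq = (cov ** 2).sum()
    diag_sq = (torch.diagonal(cov) ** 2).sum()
    return 0.5 * (frob_sq - diag_sq)

# Added to the task loss with a small weight, e.g.:
# loss = criterion(logits, labels) + 0.1 * decov_penalty(penultimate_activations)
```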
I hope some of these lessons are helpful to you if you're into cybersecurity, or pushing machine learning into new domains in general. We'll be presenting the paper this is all based on at the Artificial Intelligence for Cyber Security (AICS) workshop at AAAI in February, so if you're at AAAI then stop by and talk.