关注并星标🌟 一起学安全❤️
作者:coleak
首发于公号:渗透测试安全攻防
字数:98321
声明:仅供学习参考,请勿用作违法用途
目录
-
Deepwall
-
FFI
-
WFP
-
mitmproxy
-
Pytorch
-
mysql
-
用户态代码汇总
-
后记
-
reference
Deepwall
简介:
一个简易的 AI 防火墙。项目共花费一周时间,仅用于复习知识点;该框架仅供参考,不具备实战价值。
该项目所用技术栈如下:
- FFI:Python 调用 C 相关 API,完成内核态和用户态的通信
- WFP:过滤来自 80 和 443 端口恶意 IP 的连接,驱动内维持 IP 黑名单表
- PyTorch:FNN 机器学习,并保存训练好的模型
- mitmproxy:在用户态实时抓取访问的 URL 并获取对应的 IP
- MySQL:记录恶意的 IP 到 ip_table 表
- named_pipe:负责用户态通信,将 URL 送入 PyTorch 模型判定
FFI
import ctypes
import time
from ctypes import wintypes
# Win32 constants defined by hand (values from the Windows SDK headers).
GENERIC_READ = 0x80000000
GENERIC_WRITE = 0x40000000
OPEN_EXISTING = 3
FILE_ATTRIBUTE_SYSTEM = 0x00000004
# CreateFileW is declared with restype HANDLE (a c_void_p), so a failed call
# comes back as the unsigned pointer-sized value of (HANDLE)-1, never as the
# Python int -1.  Derive the constant from the same ctypes type so that the
# `== INVALID_HANDLE_VALUE` checks below can actually match.
INVALID_HANDLE_VALUE = wintypes.HANDLE(-1).value
METHOD_BUFFERED = 0
FILE_ANY_ACCESS = 0
FILE_DEVICE_UNKNOWN = 0x00000022  # device type: unknown
# Build an IOCTL control code (same bit layout as the CTL_CODE macro).
def CTL_CODE(DeviceType, Function, Method, Access):
    """Pack device type, access, function number and method into one code."""
    code = DeviceType << 16
    code |= Access << 14
    code |= Function << 2
    code |= Method
    return code
IOCTL_SEND_MESSAGE = CTL_CODE(FILE_DEVICE_UNKNOWN, 0x800, METHOD_BUFFERED, FILE_ANY_ACCESS)

# Bind kernel32 and declare the prototypes of the three APIs used below so
# that ctypes converts arguments and return values with the right widths.
kernel32 = ctypes.WinDLL('kernel32')

# CreateFileW(lpFileName, dwDesiredAccess, dwShareMode, lpSecurityAttributes,
#             dwCreationDisposition, dwFlagsAndAttributes, hTemplateFile)
kernel32.CreateFileW.restype = wintypes.HANDLE
kernel32.CreateFileW.argtypes = [
    wintypes.LPCWSTR, wintypes.DWORD, wintypes.DWORD, wintypes.LPVOID,
    wintypes.DWORD, wintypes.DWORD, wintypes.HANDLE,
]

# DeviceIoControl(hDevice, dwIoControlCode, lpInBuffer, nInBufferSize,
#                 lpOutBuffer, nOutBufferSize, lpBytesReturned, lpOverlapped)
kernel32.DeviceIoControl.restype = wintypes.BOOL
kernel32.DeviceIoControl.argtypes = [
    wintypes.HANDLE, wintypes.DWORD,
    wintypes.LPVOID, wintypes.DWORD,
    wintypes.LPVOID, wintypes.DWORD,
    ctypes.POINTER(wintypes.DWORD), wintypes.LPVOID,
]

# CloseHandle(hObject)
kernel32.CloseHandle.restype = wintypes.BOOL
kernel32.CloseHandle.argtypes = [wintypes.HANDLE]
# CTL_CODE helper (re-definition; identical result to the one above).
def CTL_CODE(DeviceType, Function, Method, Access):
    """Return the IOCTL code assembled from its four bit fields."""
    fields = (
        DeviceType << 16,
        Access << 14,
        Function << 2,
        Method,
    )
    packed = 0
    for field in fields:
        packed |= field
    return packed
def main():
    """Open the driver's control device and send a sample IP to it 20 times.

    Prints an error and returns early when the device cannot be opened; the
    handle is always closed once it has been opened.
    """
    # Win32 device namespace path for the driver's "cc" symbolic link.
    device_name = r'\\.\cc'
    hDevice = kernel32.CreateFileW(
        device_name,
        GENERIC_READ | GENERIC_WRITE,
        0,
        None,
        OPEN_EXISTING,
        FILE_ATTRIBUTE_SYSTEM,
        None,
    )
    # A HANDLE restype yields None for NULL and an unsigned value for
    # (HANDLE)-1, so compare against both spellings of "invalid".
    if hDevice is None or hDevice in (INVALID_HANDLE_VALUE, wintypes.HANDLE(-1).value):
        print(f"Failed to open device. Error: {ctypes.GetLastError()}")
        return
    try:
        for _ in range(20):
            mess = b'192.168.10.1'
            send(mess, hDevice)
            time.sleep(2)
    finally:
        kernel32.CloseHandle(hDevice)
def send(message, hDevice):
    """Send one NUL-terminated message to the driver and print its reply.

    Args:
        message: payload as ``bytes``; a ``str`` is encoded as UTF-8 first so
            the driver always receives a narrow (char) string.
        hDevice: handle returned by ``CreateFileW`` for the control device.
    """
    if isinstance(message, str):
        # Passing a str through an LPVOID parameter would marshal it as a
        # wide string; the driver compares narrow strings.
        message = message.encode('utf-8')
    # create_string_buffer appends the terminating NUL, so the buffer length
    # is len(message) + 1, the size the driver expects.
    in_buffer = ctypes.create_string_buffer(message)
    response = ctypes.create_string_buffer(100)
    bytes_returned = wintypes.DWORD()
    # Issue the IOCTL request.
    result = kernel32.DeviceIoControl(
        hDevice,
        IOCTL_SEND_MESSAGE,
        in_buffer,
        len(in_buffer),
        response,
        len(response),
        ctypes.byref(bytes_returned),
        None,
    )
    if result:
        print(f"Response from driver: {response.value.decode()}")
    else:
        print(f"DeviceIoControl failed. Error: {ctypes.GetLastError()}")
if __name__ == "__main__":
main()
#include "ntddk.h"
#include "ntstrsafe.h"
// Symbolic link exposed to user mode (opened there as \\.\cc).  Backslashes
// must be escaped inside the C literal.
#define SYMBOLLINK L"\\??\\cc"
// Private device-control code; function numbers 0x000-0x7FF are reserved by
// Microsoft, so custom codes start at 0x800.
#define SENDSTR CTL_CODE(FILE_DEVICE_UNKNOWN, 0x800, METHOD_BUFFERED, FILE_ANY_ACCESS)
#define MAX_IP_COUNT 100
#define IP_STRING_LENGTH 100
// Fixed-size table of blacklisted IPs stored as NUL-terminated strings.
char ipBlacklist[MAX_IP_COUNT][IP_STRING_LENGTH] = {0};
// Number of valid entries at the front of ipBlacklist.
int ipCount = 0;
PDEVICE_OBJECT dev = NULL; // control device
NTSTATUS AddIpToBlacklist(const char* ipAddress) {
if (ipCount >= MAX_IP_COUNT) {
return STATUS_INSUFFICIENT_RESOURCES; // 黑名单已满
}
// 确保 IP 地址不会超出缓冲区长度
RtlStringCchCopyA(ipBlacklist[ipCount], IP_STRING_LENGTH, ipAddress);
ipCount++;
return STATUS_SUCCESS;
}
// Linear search of the blacklist; TRUE when ipAddress matches an entry.
BOOLEAN IsIpInBlacklist(const char* ipAddress) {
    int idx = 0;

    while (idx < ipCount) {
        if (strcmp(ipBlacklist[idx], ipAddress) == 0) {
            return TRUE;  // address is blacklisted
        }
        idx++;
    }
    return FALSE;  // address is not blacklisted
}
NTSTATUS RemoveIpFromBlacklist(const char* ipAddress) {
for (int i = 0; i < ipCount; i++) {
if (strcmp(ipBlacklist[i], ipAddress) == 0) {
// 将数组中的最后一个 IP 移动到当前位置,覆盖被删除的 IP
RtlStringCchCopyA(ipBlacklist[i], IP_STRING_LENGTH, ipBlacklist[ipCount - 1]);
ipCount--;
return STATUS_SUCCESS;
}
}
return STATUS_NOT_FOUND; // IP 地址不在黑名单中
}
// Unload routine: removes the symbolic link and the control device.
VOID DriverUnload(PDRIVER_OBJECT DriverObject)
{
    UNICODE_STRING linkName; // symbolic link name

    if (DriverObject == NULL)
    {
        return;
    }

    RtlInitUnicodeString(&linkName, SYMBOLLINK);
    IoDeleteSymbolicLink(&linkName); // delete the symbolic link
    if (dev != NULL)
    {
        IoDeleteDevice(dev);
    }
    DbgPrint("删除设备和符号链接成功");
}
// Create the control device \Device\cc and its user-visible symbolic link.
// On success the device object is stored in the global `dev`.  On failure
// nothing is left behind and `dev` is NULL, so the unload routine cannot
// delete the device a second time.
NTSTATUS CreateDevice(PDRIVER_OBJECT DriverObject) {
    NTSTATUS Status;            // result returned to the caller
    UNICODE_STRING DeviceName;  // device name
    UNICODE_STRING SymbolName;  // symbolic link name

    // Backslashes must be escaped inside the literal.
    RtlInitUnicodeString(&DeviceName, L"\\Device\\cc");
    Status = IoCreateDevice(
        DriverObject,
        0,
        &DeviceName,
        FILE_DEVICE_UNKNOWN,
        0,
        TRUE, // exclusive device: only one handle may be open at a time
        &dev
    );
    if (!NT_SUCCESS(Status)) {
        if (Status == STATUS_OBJECT_NAME_COLLISION)
        {
            DbgPrint("设备名称冲突");
        }
        DbgPrint("创建失败");
        dev = NULL;
        return Status;
    }

    // The device name is not visible to applications, so expose a symbolic
    // link for user mode.
    RtlInitUnicodeString(&SymbolName, SYMBOLLINK);
    Status = IoCreateSymbolicLink(&SymbolName, &DeviceName);
    if (!NT_SUCCESS(Status)) {
        IoDeleteDevice(dev); // roll back the device
        dev = NULL;          // do not leave a dangling pointer behind
        DbgPrint("删除设备成功");
        return Status;
    }
    DbgPrint("创建符号链接成功");
    return Status;
}
// Dispatch routine installed for every major function.
// For IRP_MJ_DEVICE_CONTROL / SENDSTR the input buffer must hold a
// NUL-terminated IP string; it is added to the blacklist and the reply
// "coleak" is written back through the same (METHOD_BUFFERED) system buffer.
// Any other control code fails with STATUS_INVALID_PARAMETER.  All other
// major functions complete with STATUS_SUCCESS.
NTSTATUS fDispatch(PDEVICE_OBJECT pdev, PIRP irp) {
    UNREFERENCED_PARAMETER(pdev);
    NTSTATUS Status = STATUS_SUCCESS; // completion status
    ULONG len = 0;                    // bytes returned to the caller
    PIO_STACK_LOCATION stack = IoGetCurrentIrpStackLocation(irp);

    if (stack->MajorFunction == IRP_MJ_DEVICE_CONTROL)
    {
        ULONG inBufferLength = stack->Parameters.DeviceIoControl.InputBufferLength;
        ULONG outBufferLength = stack->Parameters.DeviceIoControl.OutputBufferLength;
        PCHAR buffer = (PCHAR)irp->AssociatedIrp.SystemBuffer;

        switch (stack->Parameters.DeviceIoControl.IoControlCode)
        {
        case SENDSTR:
        {
            static const char response[] = "coleak";
            BOOLEAN terminated = FALSE;
            ULONG i;

            if (inBufferLength == 0 || buffer == NULL) {
                Status = STATUS_INVALID_PARAMETER;
                break;
            }
            // The buffer comes from user mode: make sure a terminator lies
            // inside it before treating it as a C string.
            for (i = 0; i < inBufferLength; i++) {
                if (buffer[i] == '\0') {
                    terminated = TRUE;
                    break;
                }
            }
            if (!terminated) {
                Status = STATUS_INVALID_PARAMETER;
                break;
            }
            DbgPrint("Received message from user: %s\n", buffer);
            Status = AddIpToBlacklist(buffer);
            if (!NT_SUCCESS(Status)) {
                break;
            }
            if (outBufferLength >= sizeof(response)) {
                RtlZeroMemory(buffer, outBufferLength);
                RtlCopyMemory(buffer, response, sizeof(response));
                len = (ULONG)sizeof(response);
            }
            else {
                Status = STATUS_BUFFER_TOO_SMALL;
            }
            break;
        }
        default:
            // Unsupported control code.
            Status = STATUS_INVALID_PARAMETER;
            break;
        }
    }
    irp->IoStatus.Information = len;
    irp->IoStatus.Status = Status;
    IoCompleteRequest(irp, IO_NO_INCREMENT);
    return Status;
}
// Driver entry point: creates the control device, installs fDispatch for
// every major function and registers the unload routine.  Fails the load
// when the control device cannot be created.
NTSTATUS DriverEntry(PDRIVER_OBJECT DriverObject, PUNICODE_STRING RegistryPath)
{
    //KdBreakPoint();
    if (RegistryPath != NULL)
    {
        DbgPrint("[%ws]所在注册表位置:%wZ\n", __FUNCTIONW__, RegistryPath);
    }
    if (DriverObject != NULL)
    {
        NTSTATUS Status;
        ULONG i;

        DbgPrint("[%ws]驱动对象地址:%p\n", __FUNCTIONW__, DriverObject);
        // Create the control device; without it the driver is unreachable.
        Status = CreateDevice(DriverObject);
        if (!NT_SUCCESS(Status))
        {
            return Status;
        }
        // MajorFunction has IRP_MJ_MAXIMUM_FUNCTION + 1 slots, so the bound
        // is inclusive.
        for (i = 0; i <= IRP_MJ_MAXIMUM_FUNCTION; i++)
        {
            DriverObject->MajorFunction[i] = fDispatch;
        }
        DriverObject->DriverUnload = DriverUnload;
        DbgPrint("驱动加载成功");
    }
    return STATUS_SUCCESS;
}
WFP
驱动关键代码如下,创建一个ipBlacklist字符串数组维持黑名单ip
#define NDIS_SUPPORT_NDIS6 1
// Backslashes must be escaped inside C string literals.
#define DEV_NAME L"\\Device\\MY_WFP_DEV_NAME"
#define SYM_NAME L"\\??\\cc"
#include <ntifs.h>
#include <fwpsk.h>
#include <fwpmk.h>
#include <stdio.h>
#include "ntddk.h"
#include "ntstrsafe.h"
// Symbolic link exposed to user mode (opened there as \\.\cc).
#define SYMBOLLINK L"\\??\\cc"
// Private device-control code used by the user-mode client.
#define SENDSTR CTL_CODE(FILE_DEVICE_UNKNOWN, 0x800, METHOD_BUFFERED, FILE_ANY_ACCESS)
#define MAX_IP_COUNT 100
#define IP_STRING_LENGTH 100
// Fixed-size table of blacklisted IPs stored as NUL-terminated strings.
char ipBlacklist[MAX_IP_COUNT][IP_STRING_LENGTH] = { 0 };
// Number of valid entries at the front of ipBlacklist.
int ipCount = 0;
PDEVICE_OBJECT dev = NULL; // control device
NTSTATUS AddIpToBlacklist(const char* ipAddress) {
if (ipCount >= MAX_IP_COUNT) {
return STATUS_INSUFFICIENT_RESOURCES; // 黑名单已满
}
// 确保 IP 地址不会超出缓冲区长度
RtlStringCchCopyA(ipBlacklist[ipCount], IP_STRING_LENGTH, ipAddress);
ipCount++;
return STATUS_SUCCESS;
}
// Linear search of the blacklist; TRUE when ipAddress matches an entry.
BOOLEAN IsIpInBlacklist(const char* ipAddress) {
    int idx = 0;

    while (idx < ipCount) {
        if (strcmp(ipBlacklist[idx], ipAddress) == 0) {
            return TRUE;  // address is blacklisted
        }
        idx++;
    }
    return FALSE;  // address is not blacklisted
}
NTSTATUS RemoveIpFromBlacklist(const char* ipAddress) {
for (int i = 0; i < ipCount; i++) {
if (strcmp(ipBlacklist[i], ipAddress) == 0) {
// 将数组中的最后一个 IP 移动到当前位置,覆盖被删除的 IP
RtlStringCchCopyA(ipBlacklist[i], IP_STRING_LENGTH, ipBlacklist[ipCount - 1]);
ipCount--;
return STATUS_SUCCESS;
}
}
return STATUS_NOT_FOUND; // IP 地址不在黑名单中
}
// IRP handler with the same contract as DriverDefaultHandle below.
// For IRP_MJ_DEVICE_CONTROL / SENDSTR the input buffer must hold a
// NUL-terminated IP string; it is added to the blacklist and the reply
// "coleak" is written back through the same (METHOD_BUFFERED) system buffer.
// Any other control code fails with STATUS_INVALID_PARAMETER.
NTSTATUS fDispatch(PDEVICE_OBJECT pdev, PIRP irp) {
    UNREFERENCED_PARAMETER(pdev);
    NTSTATUS Status = STATUS_SUCCESS; // completion status
    ULONG len = 0;                    // bytes returned to the caller
    PIO_STACK_LOCATION stack = IoGetCurrentIrpStackLocation(irp);

    if (stack->MajorFunction == IRP_MJ_DEVICE_CONTROL)
    {
        ULONG inBufferLength = stack->Parameters.DeviceIoControl.InputBufferLength;
        ULONG outBufferLength = stack->Parameters.DeviceIoControl.OutputBufferLength;
        PCHAR buffer = (PCHAR)irp->AssociatedIrp.SystemBuffer;

        switch (stack->Parameters.DeviceIoControl.IoControlCode)
        {
        case SENDSTR:
        {
            static const char response[] = "coleak";
            BOOLEAN terminated = FALSE;
            ULONG i;

            if (inBufferLength == 0 || buffer == NULL) {
                Status = STATUS_INVALID_PARAMETER;
                break;
            }
            // The buffer comes from user mode: make sure a terminator lies
            // inside it before treating it as a C string.
            for (i = 0; i < inBufferLength; i++) {
                if (buffer[i] == '\0') {
                    terminated = TRUE;
                    break;
                }
            }
            if (!terminated) {
                Status = STATUS_INVALID_PARAMETER;
                break;
            }
            DbgPrint("Received message from user: %s\n", buffer);
            Status = AddIpToBlacklist(buffer);
            if (!NT_SUCCESS(Status)) {
                break;
            }
            if (outBufferLength >= sizeof(response)) {
                RtlZeroMemory(buffer, outBufferLength);
                RtlCopyMemory(buffer, response, sizeof(response));
                len = (ULONG)sizeof(response);
            }
            else {
                Status = STATUS_BUFFER_TOO_SMALL;
            }
            break;
        }
        default:
            // Unsupported control code.
            Status = STATUS_INVALID_PARAMETER;
            break;
        }
    }
    irp->IoStatus.Information = len;
    irp->IoStatus.Status = Status;
    IoCompleteRequest(irp, IO_NO_INCREMENT);
    return Status;
}
// Callout classify function (invoked before the connection is authorised).
// Permits traffic by default and blocks connections to remote ports 80/443
// whose remote address is in the blacklist; logs every classification.
// NOTE(review): the direction field uses an ALE_FLOW_ESTABLISHED index while
// every other field uses ALE_AUTH_CONNECT indices - confirm which layer the
// callout is registered at.
// NOTE(review): ipBlacklist is read here without synchronisation - confirm.
VOID NTAPI classifyFn(_In_ const FWPS_INCOMING_VALUES0* inFixedValues, _In_ const FWPS_INCOMING_METADATA_VALUES0* inMetaValues, _Inout_opt_ void* layerData, _In_opt_ const void* classifyContext, _In_ const FWPS_FILTER2* filter, _In_ UINT64 flowContext, _Inout_ FWPS_CLASSIFY_OUT0* classifyOut)
{
    // Packet direction: FWP_DIRECTION_INBOUND = 1 or FWP_DIRECTION_OUTBOUND = 0
    WORD wDirection = inFixedValues->incomingValue[FWPS_FIELD_ALE_FLOW_ESTABLISHED_V4_DIRECTION].value.int8;
    // Local address and port
    ULONG ulLocalIp = inFixedValues->incomingValue[FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_LOCAL_ADDRESS].value.uint32;
    UINT16 uLocalPort = inFixedValues->incomingValue[FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_LOCAL_PORT].value.uint16;
    // Remote address and port
    ULONG ulRemoteIp = inFixedValues->incomingValue[FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_REMOTE_ADDRESS].value.uint32;
    UINT16 uRemotePort = inFixedValues->incomingValue[FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_REMOTE_PORT].value.uint16;
    // IRQL this callout is running at
    KIRQL kCurrentIrql = KeGetCurrentIrql();
    // Process ID reported by the filter engine
    ULONG64 processId = inMetaValues->processId;
    // Same 256 bytes as before, typed as the wide string it actually holds.
    WCHAR szProcessPath[128] = { 0 };
    CHAR szProtocalName[256] = { 0 };
    char szRemoteAddress[256] = { 0 };

    UNREFERENCED_PARAMETER(layerData);
    UNREFERENCED_PARAMETER(classifyContext);
    UNREFERENCED_PARAMETER(filter);
    UNREFERENCED_PARAMETER(flowContext);

    // The process path is optional metadata and can be longer than the local
    // buffer: copy only when present, and never more than fits while keeping
    // the final WCHAR as terminator.
    if (FWPS_IS_METADATA_FIELD_PRESENT(inMetaValues, FWPS_METADATA_FIELD_PROCESS_PATH) &&
        inMetaValues->processPath != NULL &&
        inMetaValues->processPath->data != NULL)
    {
        SIZE_T cbPath = inMetaValues->processPath->size;
        if (cbPath > sizeof(szProcessPath) - sizeof(WCHAR))
        {
            cbPath = sizeof(szProcessPath) - sizeof(WCHAR);
        }
        RtlCopyMemory(szProcessPath, inMetaValues->processPath->data, cbPath);
    }

    // Resolve the protocol name
    ProtocalIdToName(inFixedValues->incomingValue[FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_PROTOCOL].value.uint16, szProtocalName);

    // Default rule: permit, but only when this callout may write the action.
    if (classifyOut->rights & FWPS_RIGHT_ACTION_WRITE)
    {
        classifyOut->actionType = FWP_ACTION_PERMIT;
    }

    // Dotted-quad text form of the remote address, as stored in the blacklist.
    RtlStringCbPrintfA(szRemoteAddress, sizeof(szRemoteAddress), "%u.%u.%u.%u",
        (ulRemoteIp >> 24) & 0xFF, (ulRemoteIp >> 16) & 0xFF, (ulRemoteIp >> 8) & 0xFF, (ulRemoteIp) & 0xFF);

    if (uRemotePort == 80 || uRemotePort == 443)
    {
        // Braces matter: only blacklisted peers are blocked, not every
        // connection to a web port.
        if (IsIpInBlacklist(szRemoteAddress))
        {
            DbgPrint("拦截网站访问请求 --> %s : %u \n", szRemoteAddress, uRemotePort);
            // Deny the connection and stop lower-weight filters from
            // overriding the decision.
            classifyOut->actionType = FWP_ACTION_BLOCK;
            classifyOut->rights = classifyOut->rights & (~FWPS_RIGHT_ACTION_WRITE);
            classifyOut->flags = classifyOut->flags | FWPS_CLASSIFY_OUT_FLAG_ABSORB;
        }
    }

    // Log the classification.
    DbgPrint("[LyShark.com] 方向: %d -> 协议类型: %s -> 本端地址: %u.%u.%u.%u:%d -> 对端地址: %u.%u.%u.%u:%d -> IRQL: %d -> 进程ID: %I64d \n",
        wDirection,
        szProtocalName,
        (ulLocalIp >> 24) & 0xFF,
        (ulLocalIp >> 16) & 0xFF,
        (ulLocalIp >> 8) & 0xFF,
        (ulLocalIp) & 0xFF,
        uLocalPort,
        (ulRemoteIp >> 24) & 0xFF,
        (ulRemoteIp >> 16) & 0xFF,
        (ulRemoteIp >> 8) & 0xFF,
        (ulRemoteIp) & 0xFF,
        uRemotePort,
        kCurrentIrql,
        processId);
    // Unicode format codes may only be used with DbgPrint at PASSIVE_LEVEL,
    // so the wide path is printed separately and only when that holds.
    if (kCurrentIrql == PASSIVE_LEVEL)
    {
        DbgPrint("[LyShark.com] 路径: %S \n", szProcessPath);
    }
}
// Default dispatch routine installed for every major function.
// For IRP_MJ_DEVICE_CONTROL / SENDSTR the input buffer must hold a
// NUL-terminated IP string; it is added to the blacklist and the reply
// "coleak" is written back through the same (METHOD_BUFFERED) system buffer.
// Any other control code fails with STATUS_INVALID_PARAMETER.  All other
// major functions complete with STATUS_SUCCESS.
NTSTATUS DriverDefaultHandle(PDEVICE_OBJECT pdev, PIRP irp)
{
    UNREFERENCED_PARAMETER(pdev);
    NTSTATUS Status = STATUS_SUCCESS; // completion status
    ULONG len = 0;                    // bytes returned to the caller
    PIO_STACK_LOCATION stack = IoGetCurrentIrpStackLocation(irp);

    if (stack->MajorFunction == IRP_MJ_DEVICE_CONTROL)
    {
        ULONG inBufferLength = stack->Parameters.DeviceIoControl.InputBufferLength;
        ULONG outBufferLength = stack->Parameters.DeviceIoControl.OutputBufferLength;
        PCHAR buffer = (PCHAR)irp->AssociatedIrp.SystemBuffer;

        switch (stack->Parameters.DeviceIoControl.IoControlCode)
        {
        case SENDSTR:
        {
            static const char response[] = "coleak";
            BOOLEAN terminated = FALSE;
            ULONG i;

            if (inBufferLength == 0 || buffer == NULL) {
                Status = STATUS_INVALID_PARAMETER;
                break;
            }
            // The buffer comes from user mode: make sure a terminator lies
            // inside it before treating it as a C string.
            for (i = 0; i < inBufferLength; i++) {
                if (buffer[i] == '\0') {
                    terminated = TRUE;
                    break;
                }
            }
            if (!terminated) {
                Status = STATUS_INVALID_PARAMETER;
                break;
            }
            DbgPrint("Received message from user: %s\n", buffer);
            Status = AddIpToBlacklist(buffer);
            if (!NT_SUCCESS(Status)) {
                break;
            }
            if (outBufferLength >= sizeof(response)) {
                RtlZeroMemory(buffer, outBufferLength);
                RtlCopyMemory(buffer, response, sizeof(response));
                len = (ULONG)sizeof(response);
            }
            else {
                Status = STATUS_BUFFER_TOO_SMALL;
            }
            break;
        }
        default:
            // Unsupported control code.
            Status = STATUS_INVALID_PARAMETER;
            break;
        }
    }
    irp->IoStatus.Information = len;
    irp->IoStatus.Status = Status;
    IoCompleteRequest(irp, IO_NO_INCREMENT);
    return Status;
}
// Create the control device \Device\cc and its user-visible symbolic link.
// On success the device object is stored in the global `dev`.  On failure
// nothing is left behind and `dev` is NULL.
NTSTATUS CreateDevice(PDRIVER_OBJECT DriverObject) {
    NTSTATUS Status;            // result returned to the caller
    UNICODE_STRING DeviceName;  // device name
    UNICODE_STRING SymbolName;  // symbolic link name

    // Backslashes must be escaped inside the literal.
    RtlInitUnicodeString(&DeviceName, L"\\Device\\cc");
    Status = IoCreateDevice(
        DriverObject,
        0,
        &DeviceName,
        FILE_DEVICE_UNKNOWN,
        0,
        TRUE, // exclusive device: only one handle may be open at a time
        &dev
    );
    if (!NT_SUCCESS(Status)) {
        if (Status == STATUS_OBJECT_NAME_COLLISION)
        {
            DbgPrint("设备名称冲突");
        }
        DbgPrint("创建失败");
        dev = NULL;
        return Status;
    }

    // The device name is not visible to applications, so expose a symbolic
    // link for user mode.
    RtlInitUnicodeString(&SymbolName, SYMBOLLINK);
    Status = IoCreateSymbolicLink(&SymbolName, &DeviceName);
    if (!NT_SUCCESS(Status)) {
        IoDeleteDevice(dev); // roll back the device
        dev = NULL;          // do not leave a dangling pointer behind
        DbgPrint("删除设备成功");
        return Status;
    }
    DbgPrint("创建符号链接成功");
    return Status;
}
// Driver unload routine.
VOID UnDriver(PDRIVER_OBJECT driver)
{
    UNICODE_STRING symLink;
    PDEVICE_OBJECT controlDevice;

    // Tear down WFP first (callouts, filters, engine handle).
    WfpUnload();

    RtlInitUnicodeString(&symLink, SYM_NAME);
    IoDeleteSymbolicLink(&symLink);

    controlDevice = driver->DeviceObject;
    if (controlDevice != NULL)
    {
        IoDeleteDevice(controlDevice);
    }
}
// Driver entry point: installs the dispatch routine, creates the control
// device and starts WFP.  Fails the load when the control device cannot be
// created, because WfpLoad needs a valid device object.
NTSTATUS DriverEntry(IN PDRIVER_OBJECT Driver, PUNICODE_STRING RegistryPath)
{
    NTSTATUS status;
    ULONG i;

    UNREFERENCED_PARAMETER(RegistryPath);

    Driver->DriverUnload = UnDriver;
    // MajorFunction has IRP_MJ_MAXIMUM_FUNCTION + 1 slots, so the bound is
    // inclusive.
    for (i = 0; i <= IRP_MJ_MAXIMUM_FUNCTION; i++)
    {
        Driver->MajorFunction[i] = DriverDefaultHandle;
    }
    // Create the control device
    status = CreateDevice(Driver);
    if (!NT_SUCCESS(status))
    {
        return status;
    }
    // Start WFP
    // NOTE(review): WfpLoad's return value is not visible here - if it
    // returns an NTSTATUS it should be checked as well.
    WfpLoad(Driver->DeviceObject);
    return STATUS_SUCCESS;
}
效果展示
mitmproxy
import socket
import win32file
from mitmproxy import ctx
# Named pipes live under \\.\pipe\<name>; the server side must use the same
# name.
PIPE_NAME = r'\\.\pipe\MyNamedPipe'
# Open the client end once at import time; the addon below writes to it.
pipe = win32file.CreateFile(
    PIPE_NAME,
    win32file.GENERIC_READ | win32file.GENERIC_WRITE,
    0,
    None,
    win32file.OPEN_EXISTING,
    0,
    None
)
class CaptureRequests:
    """mitmproxy addon that reports each request's URL and resolved IP."""

    def request(self, flow):
        """Write ``b"<url>,<ip>"`` to the named pipe for HTTP/HTTPS requests.

        Requests with another scheme, or whose host cannot be resolved, are
        ignored.  Failures while writing are printed, not raised.
        """
        req = flow.request
        if req.scheme not in ("http", "https"):
            return
        # Resolve the remote IP address from the request host.
        try:
            remote_ip = socket.gethostbyname(req.host)
        except socket.gaierror:
            return
        url = req.url
        try:
            # Hand the pair over to the pipe server.
            payload = b",".join((url.encode(), remote_ip.encode()))
            win32file.WriteFile(pipe, payload)
        except Exception as e:
            print(f"Client exception: {e}")


addons = [
    CaptureRequests()
]
import win32pipe
import win32file
import win32api
import time
# Named pipes live under \\.\pipe\<name>; the client must open the same name.
PIPE_NAME = r'\\.\pipe\MyNamedPipe'
def server():
    """Run a single-client named-pipe server that prints received messages.

    Creates the pipe, waits for one client and prints every message read
    until a read fails.  The pipe handle is closed on every exit path.
    """
    # Create the named pipe.
    pipe = win32pipe.CreateNamedPipe(
        PIPE_NAME,
        win32pipe.PIPE_ACCESS_DUPLEX,
        win32pipe.PIPE_TYPE_MESSAGE | win32pipe.PIPE_READMODE_MESSAGE | win32pipe.PIPE_WAIT,
        1, 65536, 65536, 0, None
    )
    try:
        print("Server: Waiting for a client to connect...")
        # Block until a client connects.
        win32pipe.ConnectNamedPipe(pipe, None)
        print("Server: Client connected.")
        while True:
            try:
                # Read one message from the pipe.
                hr, data = win32file.ReadFile(pipe, 4096)
                time.sleep(5)
                if hr == 0:
                    print(f"Server received: {data.decode()}")
            except Exception as e:
                print(f"Server exception: {e}")
                break
        win32pipe.DisconnectNamedPipe(pipe)
        print("Server: Client disconnected.")
    finally:
        # Release the handle even when connecting or disconnecting raised.
        win32api.CloseHandle(pipe)
if __name__ == '__main__':
server()
Pytorch
URL特征提取
import urllib
import pandas as pd
# Check whether character x is an ASCII digit.
def is_digit(x):
    """Return True when ``x`` sorts between '0' and '9' inclusive."""
    return x >= '0' and x <= '9'
# Check whether character x is an ASCII letter.
def is_letter(x):
    """Return True when ``x`` lies in 'a'..'z' or 'A'..'Z'."""
    if 'a' <= x <= 'z':
        return True
    return 'A' <= x <= 'Z'
# Count occurrences of suspicious words in the URL.
def count_badwords(url, badwords):
    """Return the total number of occurrences of every bad word in ``url``.

    Words that appear more than once in ``badwords`` are counted once per
    listing.  Empty entries (e.g. blank lines in the word file) are skipped:
    ``str.count('')`` would otherwise add ``len(url) + 1`` to the total.
    """
    return sum(url.count(word) for word in badwords if word)
def longest_path_segment_length(url):
    """Return the length of the longest '/'-separated segment of the URL path.

    Everything before the first '/' of the parsed path is ignored; 0 is
    returned when there is no segment after it.
    """
    # `import urllib` alone does not guarantee that the `urllib.parse`
    # submodule is loaded, so import the function explicitly.
    from urllib.parse import urlparse

    segments = urlparse(url).path.split('/')[1:]
    return max((len(segment) for segment in segments), default=0)
def extract_features(url, badwords):
    """Turn a URL into the 12 numeric features used by the classifier.

    Returns, in order: URL length, digit ratio, special-character count,
    depth ('/' count), '.' count, '@' flag, bad-word count, '.exe'/'.php'
    count, 'http'/'www' count past the first character, '&' count,
    query-string length and longest path-segment length.  The first three
    are measured on the lower-cased URL, the rest after removing "://".
    """
    url = url.lower()
    url_len = len(url)
    # An empty URL has no digits; avoid dividing by zero.
    dig_ratio = sum(1 for c in url if is_digit(c)) / url_len if url_len else 0.0
    special_char_count = sum(1 for c in url if not (is_digit(c) or is_letter(c)))
    # Drop the scheme separator so its slashes and colon are not counted below.
    url = url.replace("://", '')
    url_depth = url.count('/')
    dot_count = url.count('.')
    at_symbol = 1 if '@' in url else 0
    badword_count = count_badwords(url, badwords)
    exe_or_php_count = url.count('.exe') + url.count('.php')
    http_www_count = sum(url[1:].count(word) for word in ['http', 'www'])
    params_count = url.count('&')
    search_length = len(url.split('?')[1]) if '?' in url else 0
    longest_path_len = longest_path_segment_length(url)
    return [
        url_len, dig_ratio, special_char_count, url_depth, dot_count,
        at_symbol, badword_count, exe_or_php_count, http_www_count,
        params_count, search_length, longest_path_len
    ]
def main(input_csv, output_csv, badwords_file):
    """Append the extracted URL features to a CSV and write the result.

    Args:
        input_csv: CSV file containing a ``URL`` column.
        output_csv: destination for the input columns plus feature columns.
        badwords_file: text file with one suspicious word per line.
    """
    # Load the bad-word list, one entry per line.
    with open(badwords_file, 'r') as fh:
        badwords = fh.read().splitlines()

    df = pd.read_csv(input_csv)

    # Column headers for the feature columns, in extract_features order.
    feature_names = [
        'URL长度', '数字比例', '特殊字符个数', 'URL深度(/)', '出现点的次数(.)',
        '是否存在@符号', '出现恶意词的次数', '出现.php或者.exe的次数',
        '在除了开头位置出现http,www的次数', '参数个数', 'search长度','最长path长度'
    ]
    rows = [extract_features(url, badwords) for url in df['URL']]
    features_df = pd.DataFrame(rows, columns=feature_names)

    # Original columns first, feature columns after them.
    result_df = pd.concat([df, features_df], axis=1)
    result_df.to_csv(output_csv, index=False)
    print(f"Features saved to {output_csv}")
# 调用main函数,输入文件名,输出文件名和badwords文件
main('data.csv', 'output.csv', 'badwords.txt')
badwords.txt
000webhostapp
1tempurl
1033
13inboxlight
16mb
2017
2018
acc
account
action
access
action
access
admin
afurlonges
alfacomercial
alibaba
alert
ameli
america
amp
angelklchen
app
appmanager
apple
appleid
apps
apostile
asb
asd
assist
assure
aspx
altervista
automatic
automaticsfdsent
auth
bank
bankofamerica
bay
beget
bid
bin
biz
bonus
bookmark
brobabil
business
cbcxt
cgi
chase
chasebank
chaseonline
center
centralserver
card
classes
check
cgi
chase
chasebank
chaseonline
center
centralserver
cloud
cmd
com1
components
compte
confirm
connexion
contact
content
country
crm
customer
creatory
daten
date
ddns
deu
dev
dns
dnset
document
down
download
eadu
eby
ebay
ebayisapi
electric
email
esy
excel
facebook
file
for
form
free
frost
gdn
genaro
global
goodsteel
google
group
haga
halkbank
havven
help
hit
hol
home
hosting
hotis
htm
icloud
identity
images
index
includes
inet
information
informations
intl
inc
iso
itunes
known
kolp
konto
kostumernaya
kunden
less
link
lionelbrown
live
locale
log
login
loginlink
logon
lucky
luk
lmportant
mail
microsoft
manage
masterweb
message
mkt
mobile
mobi
modules
moncompte
mypdfadobedocs
myutilitydomain
myaccount
myfw
nab
name
n0tice
neibottling
netflix
new
notice
notlce
our
orange
otis
page
pages
paypal
pdf
payment
pay
peraltek
personal
php
pki
plugins
portal
portailas
rand
recovery
read
recover
request
respond
report
review
saqibsiddiqui
secure
securelogin
secured
sicherheit
sicherheitssystem
signin
signin
sig
sign
site
sites
supply
site
sites
snsc
somtc
space
sanginbenz
scottwoolbright
secure
securelogin
secured
sicherheit
sicherheitssystem
signin
sig
sign
site
snsc
somtc
space
styles
store
standard
stpaul
stpaulsmathura
submit
support
system
task
team
tech
terms
the
thonyes
tmp
tourstogo
trade
tracelog
true
truememberent
update
upgrade
usa
usaa
user
validate
validation
validierung
valid
verify
verifications
verification
vip
warnlng
web
webapps
webcindario
webscr
websrc
website
well
win
writer
xyz
your
ziraat
ziraatbank
完整框架
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
class MaliciousURLClassifier(nn.Module):
    """Feed-forward classifier with two hidden layers.

    Each hidden layer is Linear -> ReLU -> BatchNorm1d -> Dropout(0.5); the
    output layer returns raw logits.
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super().__init__()
        # Creation order is kept so parameter initialisation and state-dict
        # key order are unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim1)
        self.batch_norm2 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x):
        """Return the logits for a batch of feature vectors."""
        hidden = self.dropout(self.batch_norm1(self.relu(self.fc1(x))))
        hidden = self.dropout(self.batch_norm2(self.relu(self.fc2(hidden))))
        return self.fc3(hidden)
class URLDataset(Dataset):
    """Dataset of (features, label) pairs read from the feature CSV.

    The file is read as 14 positional columns: column 0 is dropped (the URL
    in the feature script's output), column 1 is the label ('good' -> 0,
    'bad' -> 1) and columns 2-13 are the numeric features.
    """

    def __init__(self, csv_file):
        data = pd.read_csv(csv_file, names=list(range(14)))
        label_codes = {'good': 0, 'bad': 1}
        # The feature script writes a header row, and rows may carry unknown
        # labels; neither can be converted to tensors, so keep only rows
        # whose label is known.
        data = data[data[1].isin(list(label_codes))]
        labels = data[1].map(label_codes)
        features = data.drop(columns=[0, 1]).astype('float32')
        self.features = torch.from_numpy(np.array(features, dtype='float32'))
        # 1-D int64 labels, the shape CrossEntropyLoss expects.
        self.label = torch.from_numpy(np.array(labels, dtype='int64'))
        self.data_num = len(labels)

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        return self.features[idx], self.label[idx]
# Path of the feature CSV.
csv_file = "output.csv"
# Build the dataset.
url_dataset = URLDataset(csv_file)
# 60 % train / 20 % validation / remainder test.
n_samples = len(url_dataset)
train_size = int(0.6 * n_samples)
val_size = int(0.2 * n_samples)
test_size = n_samples - (train_size + val_size)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    url_dataset, [train_size, val_size, test_size]
)
# Loaders: only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(part, batch_size=32, shuffle=False)
    for part in (val_dataset, test_dataset)
)
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the accuracy seen during it.

    Accuracy is measured on each batch before that batch's optimizer step.
    """
    model.train()
    correct = torch.zeros(1).to(device)
    seen = 0
    for batch in dataloader:
        inputs, labels = batch
        seen += inputs.shape[0]
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        preds = torch.max(logits, 1)[1]
        correct += torch.eq(preds, labels).sum().item()

        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()
    return (correct / seen).item()
def evaluate(model, dataloader, criterion, device):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    Args:
        model: network returning per-class scores of shape (batch, classes).
        dataloader: iterable of ``(inputs, labels)`` batches.
        criterion: accepted for signature compatibility; not used, because
            only accuracy is reported.
        device: device the batches are moved to.

    Returns:
        Accuracy as a float; 0.0 when the loader yields no samples (a split
        can be empty for very small datasets).
    """
    model.eval()
    corrects = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            corrects += int((preds == labels).sum().item())
            total += labels.size(0)
    if total == 0:
        return 0.0
    return corrects / total
# Device selection: use the first GPU when present, otherwise the CPU, so
# the script also runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Model hyper-parameters.
input_dim = 12  # number of features
hidden_dim1 = 16
hidden_dim2 = 12
output_dim = 2  # number of classes
# Build the model.
model = MaliciousURLClassifier(input_dim, hidden_dim1, hidden_dim2, output_dim)
model.to(device)
# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
# Train.
num_epochs = 2
for epoch in range(num_epochs):
    train_acc = train(model, train_loader, criterion, optimizer, device)
    val_acc = evaluate(model, val_loader, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs} | , Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
# Evaluate on the held-out test split.
test_acc = evaluate(model, test_loader, criterion, device)
print(f"Test Acc: {test_acc:.4f}")
# Save the whole model object (pickled); loading it requires the
# MaliciousURLClassifier class to be importable.
torch.save(model, "malicious_url_classifier.pth")
def test():
    """Load the saved classifier and print its prediction for one sample."""
    path = "malicious_url_classifier.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The file is a full pickled model, so weights_only must be disabled on
    # PyTorch versions where it defaults to True.  Unpickling can execute
    # arbitrary code: only load files produced by this project.
    model = torch.load(path, map_location=device, weights_only=False)
    model.eval()
    # One feature vector in extract_features order.  It is fed as-is:
    # training uses the raw features too, so normalising the single sample
    # by its own mean/std would give the model inputs it never saw.
    sample = [53, 0.0, 8, 3, 4, 9, 15, 8, 5, 0, 0, 15]
    data_point = torch.from_numpy(np.array(sample, dtype='float32')).unsqueeze(0)
    # Make the prediction.
    with torch.no_grad():
        output = model(data_point.to(device))
    print(output)
    predicted_class = torch.max(output, dim=1)[1].item()
    print("Predicted class:", predicted_class)
test()
框架简介
激活函数:ReLU:f(x)=max(0,x) 非线性 缓解梯度消失问题
神经网络:全连接的前馈神经网络(FNN),包含两个隐藏层;每个隐藏层在 ReLU 之后依次经过 BatchNorm 归一化和 Dropout(p=0.5),用于防止过拟合
损失函数:交叉熵损失函数 (CrossEntropyLoss),适用于分类任务。
优化器:Adam
训练循环:模型训练了 20 个周期,每个周期包括以下步骤:1、模型在训练数据集上进行前向传播和反向传播,并更新参数。2、在验证数据集上进行评估,记录损失和准确率。
mysql
use ip;
CREATE TABLE ip_table (
ip VARCHAR(60) NOT NULL,
PRIMARY KEY (ip)
);
import mysql.connector
from mysql.connector import errorcode
config = {
'user': 'root',
'password': '123456',
'host': 'localhost',
'database': 'ip',
'raise_on_warnings': True
}
def insert_data(cursor, ip):
    """Insert ``ip`` into ip_table; MySQL errors are printed, not raised."""
    sql = "INSERT INTO ip_table (ip) VALUES (%s)"
    try:
        cursor.execute(sql, (ip,))
    except mysql.connector.Error as err:
        # A duplicate primary key gets a friendlier message.
        is_duplicate = err.errno == errorcode.ER_DUP_ENTRY
        print(f"Error: Duplicate entry for IP '{ip}'" if is_duplicate else err)
def sqlinit(cursor):
    """Print every IP currently stored in ip_table."""
    cursor.execute("SELECT ip FROM ip_table")
    for row in cursor:
        (ip,) = row
        print(f"IP: {ip}")
# Connect, list the stored IPs, and always release what was opened.  A
# failed connection is reported without touching the undefined cursor.
cnx = None
cursor = None
try:
    cnx = mysql.connector.connect(**config)
    cursor = cnx.cursor()
    sqlinit(cursor)
except mysql.connector.Error as err:
    print(err)
finally:
    if cursor is not None:
        cursor.close()
    if cnx is not None:
        cnx.close()
用户态代码汇总
import urllib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import win32pipe
import win32file
import win32api
import mysql.connector
from mysql.connector import errorcode
import ctypes
from ctypes import wintypes
config = {
'user': 'root',
'password': '123456',
'host': 'localhost',
'database': 'ip',
'raise_on_warnings': True
}
# Named pipes live under \\.\pipe\<name>; the mitmproxy client must open the
# same name.
PIPE_NAME = r'\\.\pipe\MyNamedPipe'
# Win32 constants defined by hand (values from the Windows SDK headers).
GENERIC_READ = 0x80000000
GENERIC_WRITE = 0x40000000
OPEN_EXISTING = 3
FILE_ATTRIBUTE_SYSTEM = 0x00000004
# CreateFileW is declared with restype HANDLE (a c_void_p), so a failed call
# comes back as the unsigned pointer-sized value of (HANDLE)-1, never as the
# Python int -1.
INVALID_HANDLE_VALUE = wintypes.HANDLE(-1).value
METHOD_BUFFERED = 0
FILE_ANY_ACCESS = 0
FILE_DEVICE_UNKNOWN = 0x00000022  # device type: unknown
# Handle of the driver's control device; 0 until indriver() opens it.
hDevice = 0
# Build an IOCTL control code (same bit layout as the CTL_CODE macro).
def CTL_CODE(DeviceType, Function, Method, Access):
    """Pack device type, access, function number and method into one code."""
    code = DeviceType << 16
    code |= Access << 14
    code |= Function << 2
    code |= Method
    return code
IOCTL_SEND_MESSAGE = CTL_CODE(FILE_DEVICE_UNKNOWN, 0x800, METHOD_BUFFERED, FILE_ANY_ACCESS)

# Bind kernel32 and declare the prototypes of the three APIs used below so
# that ctypes converts arguments and return values with the right widths.
kernel32 = ctypes.WinDLL('kernel32')

# CreateFileW(lpFileName, dwDesiredAccess, dwShareMode, lpSecurityAttributes,
#             dwCreationDisposition, dwFlagsAndAttributes, hTemplateFile)
kernel32.CreateFileW.restype = wintypes.HANDLE
kernel32.CreateFileW.argtypes = [
    wintypes.LPCWSTR, wintypes.DWORD, wintypes.DWORD, wintypes.LPVOID,
    wintypes.DWORD, wintypes.DWORD, wintypes.HANDLE,
]

# DeviceIoControl(hDevice, dwIoControlCode, lpInBuffer, nInBufferSize,
#                 lpOutBuffer, nOutBufferSize, lpBytesReturned, lpOverlapped)
kernel32.DeviceIoControl.restype = wintypes.BOOL
kernel32.DeviceIoControl.argtypes = [
    wintypes.HANDLE, wintypes.DWORD,
    wintypes.LPVOID, wintypes.DWORD,
    wintypes.LPVOID, wintypes.DWORD,
    ctypes.POINTER(wintypes.DWORD), wintypes.LPVOID,
]

# CloseHandle(hObject)
kernel32.CloseHandle.restype = wintypes.BOOL
kernel32.CloseHandle.argtypes = [wintypes.HANDLE]
# CTL_CODE helper (re-definition; identical result to the one above).
def CTL_CODE(DeviceType, Function, Method, Access):
    """Return the IOCTL code assembled from its four bit fields."""
    fields = (
        DeviceType << 16,
        Access << 14,
        Function << 2,
        Method,
    )
    packed = 0
    for field in fields:
        packed |= field
    return packed
def indriver():
    """Open the driver's control device and store the handle in ``hDevice``.

    On failure an error is printed and the module-level ``hDevice`` keeps
    its previous value.
    """
    # Without this declaration the handle would be bound to a local name and
    # send() / server() would keep using the initial module-level value.
    global hDevice
    # Win32 device namespace path for the driver's "cc" symbolic link.
    device_name = r'\\.\cc'
    handle = kernel32.CreateFileW(
        device_name,
        GENERIC_READ | GENERIC_WRITE,
        0,
        None,
        OPEN_EXISTING,
        FILE_ATTRIBUTE_SYSTEM,
        None,
    )
    # A HANDLE restype yields None for NULL and an unsigned value for
    # (HANDLE)-1, so compare against both spellings of "invalid".
    if handle is None or handle in (INVALID_HANDLE_VALUE, wintypes.HANDLE(-1).value):
        print(f"Failed to open device. Error: {ctypes.GetLastError()}")
        return
    hDevice = handle
def send(message):
    """Send one NUL-terminated message to the driver and print its reply.

    Args:
        message: payload as ``bytes`` or ``str``.  Callers in this file pass
            ``str`` values; they are encoded as UTF-8 so the driver receives
            a narrow (char) string instead of a wide one.
    """
    if isinstance(message, str):
        message = message.encode('utf-8')
    # create_string_buffer appends the terminating NUL, so the buffer length
    # is len(message) + 1, the size the driver expects.
    in_buffer = ctypes.create_string_buffer(message)
    response = ctypes.create_string_buffer(100)
    bytes_returned = wintypes.DWORD()
    # Issue the IOCTL request on the module-level device handle.
    result = kernel32.DeviceIoControl(
        hDevice,
        IOCTL_SEND_MESSAGE,
        in_buffer,
        len(in_buffer),
        response,
        len(response),
        ctypes.byref(bytes_returned),
        None,
    )
    if result:
        print(f"Response from driver: {response.value.decode()}")
    else:
        print(f"DeviceIoControl failed. Error: {ctypes.GetLastError()}")
def insert_data(cursor, ip):
    """Insert ``ip`` into ip_table; MySQL errors are printed, not raised."""
    sql = "INSERT INTO ip_table (ip) VALUES (%s)"
    try:
        cursor.execute(sql, (ip,))
    except mysql.connector.Error as err:
        # A duplicate primary key gets a friendlier message.
        is_duplicate = err.errno == errorcode.ER_DUP_ENTRY
        print(f"Error: Duplicate entry for IP '{ip}'" if is_duplicate else err)
def sqlinit(cursor):
    """Print every stored IP and push each one to the driver via send()."""
    cursor.execute("SELECT ip FROM ip_table")
    for row in cursor:
        (ip,) = row
        print(f"IP: {ip}")
        send(ip)
def server():
    """Run the user-mode side of the firewall.

    Connects to MySQL, opens the driver, replays the stored blacklist to it,
    then serves one named-pipe client.  Each message is ``"<url>,<ip>"``;
    when the classifier flags the URL, the IP is stored and sent to the
    driver.  All resources are released on every exit path.
    """
    try:
        # Connect to the database; nothing else can work without it.
        cnx = mysql.connector.connect(**config)
    except mysql.connector.Error as err:
        print(err)
        return
    cursor = cnx.cursor()
    pipe = None
    try:
        indriver()
        sqlinit(cursor)
        # Create the named pipe.
        pipe = win32pipe.CreateNamedPipe(
            PIPE_NAME,
            win32pipe.PIPE_ACCESS_DUPLEX,
            win32pipe.PIPE_TYPE_MESSAGE | win32pipe.PIPE_READMODE_MESSAGE | win32pipe.PIPE_WAIT,
            1, 65536, 65536, 0, None
        )
        print("Server: Waiting for a client to connect...")
        # Block until a client connects.
        win32pipe.ConnectNamedPipe(pipe, None)
        print("Server: Client connected.")
        while True:
            try:
                # Read one message from the pipe.
                hr, data = win32file.ReadFile(pipe, 4096)
                if hr != 0:
                    continue
                data = data.decode()
                print(f"Server received: {data}")
                # URLs may contain commas, so split on the LAST one: the IP
                # appended by the client never does.
                url, sep, ip = data.rpartition(',')
                if not sep:
                    continue
                if test(url):
                    insert_data(cursor, ip)
                    send(ip)
                    cnx.commit()
            except Exception as e:
                print(f"Server exception: {e}")
                break
        win32pipe.DisconnectNamedPipe(pipe)
        print("Server: Client disconnected.")
    finally:
        cursor.close()
        cnx.close()
        if pipe is not None:
            win32api.CloseHandle(pipe)
        kernel32.CloseHandle(hDevice)
class MaliciousURLClassifier(nn.Module):
    """Feed-forward classifier with two hidden layers.

    Each hidden layer is Linear -> ReLU -> BatchNorm1d -> Dropout(0.5); the
    output layer returns raw logits.
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super().__init__()
        # Creation order is kept so parameter initialisation and state-dict
        # key order are unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim1)
        self.batch_norm2 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x):
        """Return the logits for a batch of feature vectors."""
        hidden = self.dropout(self.batch_norm1(self.relu(self.fc1(x))))
        hidden = self.dropout(self.batch_norm2(self.relu(self.fc2(hidden))))
        return self.fc3(hidden)
class URLDataset(Dataset):
    """Dataset of (features, label) pairs read from the feature CSV.

    The file is read as 14 positional columns: column 0 is dropped, column 1
    is the label ('good' -> 0, 'bad' -> 1) and columns 2-13 are the numeric
    features.
    """

    def __init__(self, csv_file):
        data = pd.read_csv(csv_file, names=list(range(14)))
        label_codes = {'good': 0, 'bad': 1}
        # A header row or rows with unknown labels cannot be converted to
        # tensors, so keep only rows whose label is known.
        data = data[data[1].isin(list(label_codes))]
        labels = data[1].map(label_codes)
        features = data.drop(columns=[0, 1]).astype('float32')
        self.features = torch.from_numpy(np.array(features, dtype='float32'))
        # 1-D int64 labels, the shape CrossEntropyLoss expects.
        self.label = torch.from_numpy(np.array(labels, dtype='int64'))
        self.data_num = len(labels)  # number of samples

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        return self.features[idx], self.label[idx]
badwords=['000webhostapp', '1tempurl', '1033', '13inboxlight', '16mb', '2017', '2018', 'acc', 'account', 'action', 'access', 'action', 'access', 'admin', 'afurlonges', 'alfacomercial', 'alibaba', 'alert', 'ameli', 'america', 'amp', 'angelklchen', 'app', 'appmanager', 'apple', 'appleid', 'apps', 'apostile', 'asb', 'asd', 'assist', 'assure', 'aspx', 'altervista', 'automatic', 'automaticsfdsent', 'auth', 'bank', 'bankofamerica', 'bay', 'beget', 'bid', 'bin', 'biz', 'bonus', 'bookmark', 'brobabil', 'business', 'cbcxt', 'cgi', 'chase', 'chasebank', 'chaseonline', 'center', 'centralserver', 'card', 'classes', 'check', 'cgi', 'chase', 'chasebank', 'chaseonline', 'center', 'centralserver', 'cloud', 'cmd', 'com1', 'components', 'compte', 'confirm', 'connexion', 'contact', 'content', 'country', 'crm', 'customer', 'creatory', 'daten', 'date', 'ddns', 'deu', 'dev', 'dns', 'dnset', 'document', 'down', 'download', 'eadu', 'eby', 'ebay', 'ebayisapi', 'electric', 'email', 'esy', 'excel', 'facebook', 'file', 'for', 'form', 'free', 'frost', 'gdn', 'genaro', 'global', 'goodsteel', 'google', 'group', 'haga', 'halkbank', 'havven', 'help', 'hit', 'hol', 'home', 'hosting', 'hotis', 'htm', 'icloud', 'identity', 'images', 'index', 'includes', 'inet', 'information', 'informations', 'intl', 'inc', 'iso', 'itunes', 'known', 'kolp', 'konto', 'kostumernaya', 'kunden', 'less', 'link', 'lionelbrown', 'live', 'locale', 'log', 'login', 'loginlink', 'logon', 'lucky', 'luk', 'lmportant', 'mail', 'microsoft', 'manage', 'masterweb', 'message', 'mkt', 'mobile', 'mobi', 'modules', 'moncompte', 'mypdfadobedocs', 'myutilitydomain', 'myaccount', 'myfw', 'nab', 'name', 'n0tice', 'neibottling', 'netflix', 'new', 'notice', 'notlce', 'our', 'orange', 'otis', 'page', 'pages', 'paypal', 'pdf', 'payment', 'pay', 'peraltek', 'personal', 'php', 'pki', 'plugins', 'portal', 'portailas', 'rand', 'recovery', 'read', 'recover', 'request', 'respond', 'report', 'review', 'saqibsiddiqui', 'secure', 'securelogin', 'secured', 
'sicherheit', 'sicherheitssystem', 'signin', 'signin', 'sig', 'sign', 'site', 'sites', 'supply', 'site', 'sites', 'snsc', 'somtc', 'space', 'sanginbenz', 'scottwoolbright', 'secure', 'securelogin', 'secured', 'sicherheit', 'sicherheitssystem', 'signin', 'sig', 'sign', 'site', 'snsc', 'somtc', 'space', 'styles', 'store', 'standard', 'stpaul', 'stpaulsmathura', 'submit', 'support', 'system', 'task', 'team', 'tech', 'terms', 'the', 'thonyes', 'tmp', 'tourstogo', 'trade', 'tracelog', 'true', 'truememberent', 'update', 'upgrade', 'usa', 'usaa', 'user', 'validate', 'validation', 'validierung', 'valid', 'verify', 'verifications', 'verification', 'vip', 'warnlng', 'web', 'webapps', 'webcindario', 'webscr', 'websrc', 'website', 'well', 'win', 'writer', 'xyz', 'your', 'ziraat', 'ziraatbank']
# Check whether character x is an ASCII digit.
def is_digit(x):
    """Return True when ``x`` sorts between '0' and '9' inclusive."""
    return x >= '0' and x <= '9'
def is_letter(x):
    """Return True when the character *x* is an ASCII letter (a-z or A-Z)."""
    if 'a' <= x <= 'z':
        return True
    return 'A' <= x <= 'Z'
def count_badwords(url, badwords):
    """Return the total number of occurrences in *url* of the words in *badwords*.

    Each entry of *badwords* contributes ``url.count(entry)``; an entry that
    appears several times in the list is therefore counted several times.
    """
    total = 0
    for word in badwords:
        total += url.count(word)
    return total
def longest_path_segment_length(url):
    """Return the length of the longest '/'-separated segment of the URL path.

    The path is split on '/' and the piece before the first '/' is discarded;
    0 is returned when no segment remains.
    """
    segments = urllib.parse.urlparse(url).path.split('/')[1:]
    return max(map(len, segments), default=0)
def extract_features(url):
    """Turn *url* into the 12-element numeric feature list.

    The URL is lower-cased first. The returned list contains, in order:
    url_len, dig_ratio, special_char_count, url_depth, dot_count, at_symbol,
    badword_count, exe_or_php_count, http_www_count, params_count,
    search_length and longest_path_len.

    url_len, dig_ratio and special_char_count are computed on the full
    lower-cased URL; every later feature is computed after the "://"
    separator has been removed.
    """
    url = url.lower()
    url_len = len(url)
    # An empty URL used to raise ZeroDivisionError here; report a ratio of 0.
    dig_ratio = sum(1 for c in url if is_digit(c)) / url_len if url_len else 0.0
    special_char_count = sum(1 for c in url if not (is_digit(c) or is_letter(c)))
    url = url.replace("://", '')
    url_depth = url.count('/')
    dot_count = url.count('.')
    at_symbol = 1 if '@' in url else 0
    badword_count = count_badwords(url, badwords)
    exe_or_php_count = url.count('.exe') + url.count('.php')
    # url[1:] drops the first character, so an 'http'/'www' that starts at
    # position 0 is not counted; only later occurrences are.
    http_www_count = sum(url[1:].count(word) for word in ['http', 'www'])
    params_count = url.count('&')
    # Length of the text between the first and the second '?' (or end of URL).
    search_length = len(url.split('?')[1]) if '?' in url else 0
    longest_path_len = longest_path_segment_length(url)
    return [
        url_len, dig_ratio, special_char_count, url_depth, dot_count,
        at_symbol, badword_count, exe_or_php_count, http_www_count,
        params_count, search_length, longest_path_len
    ]
def test(url):
    """Classify *url* with the saved model and return the predicted class index.

    The model file is loaded from "malicious_url_classifier.pth" on every
    call, the URL is converted with extract_features(), and the index of the
    largest model output is returned as an int. The raw output and the
    predicted class are also printed.
    """
    # Fall back to the CPU when CUDA is unavailable instead of failing on
    # machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    path = "malicious_url_classifier.pth"
    # NOTE(review): torch.load unpickles the stored object; only load model
    # files from a trusted source.
    model = torch.load(path, map_location=device)
    model.to(device)
    model.eval()
    features = extract_features(url)
    data_point = np.array(features, dtype='float32')
    # unsqueeze(0) adds the leading batch dimension.
    data_point = torch.from_numpy(data_point).float().unsqueeze(0)
    # Make prediction
    with torch.no_grad():
        output = model(data_point.to(device))
        print(output)
        predicted_class = torch.max(output, dim=1)[1].item()
    print("Predicted class:", predicted_class)
    return int(predicted_class)
# Script entry: call server(), which is defined or imported earlier in this script.
server()
import socket
import win32file
from mitmproxy import ctx
# Win32 named pipes must be addressed as \\.\pipe\<name>; the previous value
# had lost its backslashes and was not a valid pipe path.
# NOTE(review): the process that creates the pipe must use the same name.
PIPE_NAME = r'\\.\pipe\MyNamedPipe'
# Open the existing pipe for reading and writing when this script is loaded.
pipe = win32file.CreateFile(
    PIPE_NAME,
    win32file.GENERIC_READ | win32file.GENERIC_WRITE,  # desired access
    0,  # share mode
    None,  # security attributes
    win32file.OPEN_EXISTING,  # creation disposition: the pipe must already exist
    0,  # flags and attributes
    None  # template file
)
class CaptureRequests:
    """Addon whose request hook forwards each request's URL and resolved IP to the pipe."""

    def request(self, flow):
        """Write ``<url>,<ip>`` for an http/https *flow* to the named pipe.

        Flows with any other scheme, and hosts that cannot be resolved, are
        ignored. Failures while encoding or writing are printed, not raised.
        """
        # Only plain HTTP and HTTPS requests are of interest.
        if flow.request.scheme not in ("http", "https"):
            return
        # Resolve the requested host name to an IP address.
        try:
            host = flow.request.host
            remote_ip = socket.gethostbyname(host)
        except socket.gaierror:
            return
        url = flow.request.url
        # Send "<url>,<ip>" through the pipe.
        try:
            message = b",".join((url.encode(), remote_ip.encode()))
            win32file.WriteFile(pipe, message)
        except Exception as e:
            print(f"Client exception: {e}")
# Addon instances exposed by this script through the module-level `addons` list.
addons = [CaptureRequests()]
后记
环境安装
Anaconda Prompt
conda create -n pytorch_gpu
conda info -e
conda activate pytorch_gpu
conda install pandas
import torch
torch.cuda.is_available()
#返回True,说明GPU版本已可用
conda deactivate
然后将conda环境放入pycharm,并选择虚拟环境pytorch_gpu即可
过程记录
epoch和batch影响
较大的 batch size 可以更快地完成一个 epoch,但可能需要更多的内存;较小的 batch size 则相反。
较小的 batch size 通常会导致梯度估计的方差较大,可能会使训练过程更加不稳定
预处理
data = (data - np.mean(data)) / np.std(data)
#将数据的每个值减去均值使数据中心化,然后将中心化后的数据除以标准差完成标准化。此时均值为0,标准差为1
常用的处理链:pd.read_csv->np.array->torch.from_numpy,此时为tensor数据结构(PyTorch张量)
读取数据
dataload继承自Dataset,需要重写三个函数:__init__、__getitem__、__len__
初始化神经网络
继承自nn.Module
初始化层数和定义forward前向传播函数
class Neuralnetwork(nn.Module):
    """Fully connected network: two ReLU-activated hidden layers and a linear output layer."""

    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        """Create the three linear layers in_dim -> n_hidden_1 -> n_hidden_2 -> out_dim."""
        super().__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        """Return the output of layer3; no activation is applied to the last layer."""
        hidden = torch.relu(self.layer1(x))
        hidden = torch.relu(self.layer2(hidden))
        return self.layer3(hidden)
训练环境
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
划分加载数据集
torch.utils.data.random_split
DataLoader
参数初始化
model = Neuralnetwork(4, 12, 6, 3).to(device) # instantiate the model (in_dim=4, n_hidden_1=12, n_hidden_2=6, out_dim=3) on the device
loss_function = nn.CrossEntropyLoss() # loss function: cross-entropy loss
pg = [p for p in model.parameters() if p.requires_grad] # model parameters that require gradients
optimizer = optim.Adam(pg, lr=0.005) # Adam optimizer over those parameters, learning rate 0.005
由于神经网络自定义为三层,n_hidden_1、n_hidden_2可以自行设定;in_dim、out_dim则分别由输入特征数和分类类别数决定,不可随意更改
训练过程
# Put the model in training mode, then run one pass over train_loader.
model.train()
for data, label in train_loader:
    optimizer.zero_grad()  # reset the gradients
    loss = loss_function(model(data.to(device)), label.to(device))  # forward pass and loss
    loss.backward()  # backpropagate
    optimizer.step()  # update the parameters
# Two ways of saving are shown; both write to the same `path`, so when both
# lines run the second file overwrites the first.
torch.save(model.state_dict(), path)  # save only the parameters (state dict)
torch.save(model, path)  # save the whole trained model object
结果验证
def infer(model, dataset, device):
    """Return the fraction of correctly classified samples in *dataset*.

    *dataset* is iterated as ``(inputs, labels)`` batches. Accuracy is the
    number of correct predictions divided by the number of labels seen.
    The previous version divided by ``len(dataset)``, which is the number of
    batches, so the result was wrong for any batch size above 1.

    Returns 0.0 when *dataset* yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for datas, labels in dataset:
            outputs = model(datas.to(device))
            # Index of the largest output along dim 1 is the predicted class.
            predict_y = torch.max(outputs, dim=1)[1]
            correct += torch.eq(predict_y, labels.to(device)).sum().item()
            total += labels.numel()
    return correct / total if total else 0.0
模型加载
import torch
import torch.nn as nn
import numpy as np
class Neuralnetwork(nn.Module):
    """Fully connected network: two ReLU-activated hidden layers and a linear output layer."""

    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        """Create the three linear layers in_dim -> n_hidden_1 -> n_hidden_2 -> out_dim."""
        super().__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        """Return the output of layer3; no activation is applied to the last layer."""
        hidden = torch.relu(self.layer1(x))
        hidden = torch.relu(self.layer2(hidden))
        return self.layer3(hidden)
# Load the model
path = "a.bin"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles the stored object; only load trusted files.
model = torch.load(path, map_location=device)
model.eval()
# Read one comma-separated sample per line until the input ends.
while True:
    try:
        values = input().split(',')
    except EOFError:
        # End of input: leave the loop instead of crashing with a traceback.
        break
    # Prepare a single data point
    # Assuming the input data format is similar to the one used during training
    data_point = np.array(values, dtype='float32')
    # Standardize with this sample's own mean and standard deviation.
    # NOTE(review): the std is 0 when all values are equal, which yields NaN.
    data_point = (data_point - np.mean(data_point)) / np.std(data_point)
    # Add the batch dimension; the tensor is moved to the device below.
    data_point = torch.from_numpy(data_point).float().unsqueeze(0)
    # Make prediction
    with torch.no_grad():
        output = model(data_point.to(device))
        predicted_class = torch.max(output, dim=1)[1].item()
    print("Predicted class:", predicted_class)
查看输出(模型直接输出的是未归一化的logits,经Softmax后才是概率)
# Run the loaded model on one sample without tracking gradients, then print
# the raw output tensor.
with torch.no_grad():
    output = loaded_model(single_data.to(device))
print(output)
前向/反向传播
前向传播 (Forward Propagation)
前向传播是将输入数据通过神经网络进行传递,生成输出的过程。具体步骤如下:
-
输入层:输入数据(特征向量)传递到网络的输入层。
-
隐藏层:输入数据经过各个隐藏层的线性变换(加权求和)和非线性激活函数,逐层传递。
-
输出层:隐藏层的输出经过最后一层的线性变换,生成最终的预测结果。
在前向传播过程中,网络的每一层都会生成激活值,这些值将作为下一层的输入,最终得到网络的输出。对于分类任务,输出层通常会使用 Softmax 函数将结果转换为概率分布。
反向传播 (Backward Propagation)
反向传播是通过计算损失函数相对于各个参数(权重和偏置)的梯度,调整参数以最小化损失函数的过程。具体步骤如下:
-
计算损失:使用前向传播的输出和真实标签计算损失。常用的损失函数有均方误差(MSE)和交叉熵损失(CrossEntropyLoss)。
-
反向传播误差:从输出层开始,逐层向后计算每个参数对损失函数的梯度。
-
计算输出层的梯度:计算损失函数对输出层输出的梯度。
-
逐层计算隐藏层的梯度:通过链式法则(链式求导)将梯度逐层传递到隐藏层,计算每层的参数(权重和偏置)的梯度。
-
更新参数:使用优化器(如梯度下降、Adam等)根据计算得到的梯度调整每个参数,以减小损失函数的值。
前向传播和反向传播的关系
-
前向传播:输入数据通过网络计算输出,并用输出计算损失。
-
反向传播:通过计算损失函数的梯度,调整网络参数,以使得损失最小化。
这种循环会持续多个训练周期(epochs),直到模型收敛,即损失函数值不再显著下降。通过这两个过程,神经网络能够逐渐学习到输入数据与输出标签之间的映射关系,提高在新数据上的预测准确率。
reference
https://www.cnblogs.com/zhaopengpeng/p/13668727.html
https://blog.csdn.net/weixin_43486940/article/details/123229290
https://blog.csdn.net/weixin_46470894/article/details/107145207
https://www.cnblogs.com/LyShark/p/17134954.html
https://github.com/adysec/top_1m_domains
https://github.com/bfilar/URLTran
https://blog.csdn.net/weixin_42475060/article/details/128862411
https://blog.51cto.com/u_15127601/2758687
https://github.com/dufq/malicious-URL-detection
https://blog.csdn.net/Fluentwater/article/details/130119720
https://www.jianshu.com/p/036e5057f0b9
文章首发于:渗透测试安全攻防
原文始发于微信公众号(渗透测试安全攻防):Deepwall:从零开始构建一个AI防火墙
- 左青龙
- 微信扫一扫
-
- 右白虎
- 微信扫一扫
-
评论