欢迎访问 生活随笔!

ag凯发k8国际

当前位置: ag凯发k8国际 > 人工智能 > pytorch >内容正文

pytorch

脑电信号特征提取算法c语言-ag凯发k8国际

发布时间:2024/10/8 pytorch 0 豆豆
ag凯发k8国际 收集整理的这篇文章主要介绍了 脑电信号特征提取算法c语言_应用深度学习eegnet来处理脑电信号 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

文章来源于"脑机接口社区"

应用深度学习eegnet来处理脑电信号​mp.weixin.qq.com

本篇文章内容主要包括:

  • eegnet论文;
  • eegnet的实现。
  • eegnet论文简介

    脑机接口(bci)使用神经活动作为控制信号,实现与计算机的直接通信。这种神经信号通常是从各种研究透彻的脑电图(eeg)信号中挑选出来的。卷积神经网络(cnn)主要用来自动特征提取和分类,其在计算机视觉和语音识别领域中的使用已经很广泛。cnn已成功应用于基于eeg的bci;但是,cnn主要应用于单个bci范式,在其他范式中的使用比较少,论文作者提出是否可以设计一个cnn架构来准确分类来自不同bci范式的eeg信号,同时尽可能地紧凑(定义为模型中的参数数量)。该论文介绍了eegnet,这是一种用于基于eeg的bci的紧凑型卷积神经网络。论文介绍了使用深度和可分离卷积来构建特定于eeg的模型,该模型封装了脑机接口中常见的eeg特征提取概念。论文通过四种bci范式(p300视觉诱发电位、错误相关负性反应(ern)、运动相关皮层电位(mrcp)和感觉运动节律(smr)),将eegnet在主体内和跨主体分类方面与目前最先进的方法进行了比较。结果显示,在训练数据有限的情况下,eegnet比参考算法具有更强的泛化能力和更高的性能。同时论文也证明了eegnet可以有效地推广到erp和基于振荡的bci。

    网络结构图如下:

    实验结果如下图,p300数据集的所有cnn模型之间的差异非常小,但是mrcp数据集却存在显著的差异,两个eegnet模型的性能都优于所有其他模型。对于ern数据集来说,两个eegnet模型的性能都优于其他所有模型(p < 0.05)。

    如下图每个模型的p300,ern和mrcp数据集的分类性能平均为30倍。对于p300和mrcp数据集,deepconvnet和eegnet模型之间的差异很小,两个模型的性能均优于shallowconvnet。对于ern数据集,参考算法(xdawn rg)明显优于所有其他模型。

    下图是对eegnet-4,1模型配置获得的特征进行可视化。

    (a)每个空间过滤器的空间拓扑。

    (b)每个滤波器的目标试验和非目标试验之间的平均小波时频差。

    下图中第一排是使用deeplift针对mrcp数据集的三个不同测试试验,对使用cross-subject训练的eegnet-8,2模型进行的单次试验脑电特征相关性:

    (a)高可信度,正确预测左手运动;

    (b)高可信度,正确预测右手运动;

    (c)低可信度,错误预测左手运动。

    标题包括真实的类别标签和该标签的预测概率。

    第二排是在两个时间点的相关性空间分布图:按钮按下后大约50毫秒和150毫秒。与预期的一样,高可信度试验显示出分别对应左(a)和右(b)按钮对应的对侧运动皮层的正确相关性。对于低置信度的试验,可以看到相关性更加混杂且分布广泛,而运动皮质没有明确的空间定位。

    eegnet网络实现

    作者提供的代码用的是旧版本的pytorch,所以有一些错误。rose小哥基于作者提供的代码在pytorch 1.3.1(only cpu)版本下修改,经测试,在rose小哥环境下可以运行[不排除在其他环境可能会存在不兼容的问题]

    # 导入工具包 import numpy as np from sklearn.metrics import roc_auc_score, precision_score, recall_score, accuracy_score import torch import torch.nn as nn import torch.optim as optim from torch.autograd import variable import torch.nn.functional as f import torch.optim as optim

    eegnet网络模型参数如下:

    定义网络模型:

    class eegnet(nn.module):def __init__(self):super(eegnet, self).__init__()self.t = 120# layer 1self.conv1 = nn.conv2d(1, 16, (1, 64), padding = 0)self.batchnorm1 = nn.batchnorm2d(16, false)# layer 2self.padding1 = nn.zeropad2d((16, 17, 0, 1))self.conv2 = nn.conv2d(1, 4, (2, 32))self.batchnorm2 = nn.batchnorm2d(4, false)self.pooling2 = nn.maxpool2d(2, 4)# layer 3self.padding2 = nn.zeropad2d((2, 1, 4, 3))self.conv3 = nn.conv2d(4, 4, (8, 4))self.batchnorm3 = nn.batchnorm2d(4, false)self.pooling3 = nn.maxpool2d((2, 4))# 全连接层# 此维度将取决于数据中每个样本的时间戳数。# i have 120 timepoints. self.fc1 = nn.linear(4*2*7, 1)def forward(self, x):# layer 1x = f.elu(self.conv1(x))x = self.batchnorm1(x)x = f.dropout(x, 0.25)x = x.permute(0, 3, 1, 2)# layer 2x = self.padding1(x)x = f.elu(self.conv2(x))x = self.batchnorm2(x)x = f.dropout(x, 0.25)x = self.pooling2(x)# layer 3x = self.padding2(x)x = f.elu(self.conv3(x))x = self.batchnorm3(x)x = f.dropout(x, 0.25)x = self.pooling3(x)# 全连接层x = x.view(-1, 4*2*7)x = f.sigmoid(self.fc1(x))return x

    定义评估指标:

    acc:准确率

    auc:auc 即 roc 曲线对应的面积

    recall:召回率

    precision:精确率

    fmeasure:f值

    def evaluate(model, x, y, params = ["acc"]):results = []batch_size = 100predicted = []for i in range(len(x)//batch_size):s = i*batch_sizee = i*batch_size batch_sizeinputs = variable(torch.from_numpy(x[s:e]))pred = model(inputs)predicted.append(pred.data.cpu().numpy())inputs = variable(torch.from_numpy(x))predicted = model(inputs)predicted = predicted.data.cpu().numpy()"""设置评估指标:acc:准确率auc:auc 即 roc 曲线对应的面积recall:召回率precision:精确率fmeasure:f值"""for param in params:if param == 'acc':results.append(accuracy_score(y, np.round(predicted)))if param == "auc":results.append(roc_auc_score(y, predicted))if param == "recall":results.append(recall_score(y, np.round(predicted)))if param == "precision":results.append(precision_score(y, np.round(predicted)))if param == "fmeasure":precision = precision_score(y, np.round(predicted))recall = recall_score(y, np.round(predicted))results.append(2*precision*recall/ (precision recall))return results

    构建网络eegnet,并设置二分类交叉熵和adam优化器

    # 定义网络 net = eegnet() # 定义二分类交叉熵 (binary cross entropy) criterion = nn.bceloss() # 定义adam优化器 optimizer = optim.adam(net.parameters())

    创建数据集

    """ 生成训练数据集,数据集有100个样本 训练数据x_train:为[0,1)之间的随机数; 标签数据y_train:为0或1 """ x_train = np.random.rand(100, 1, 120, 64).astype('float32') y_train = np.round(np.random.rand(100).astype('float32')) """ 生成验证数据集,数据集有100个样本 验证数据x_val:为[0,1)之间的随机数; 标签数据y_val:为0或1 """ x_val = np.random.rand(100, 1, 120, 64).astype('float32') y_val = np.round(np.random.rand(100).astype('float32')) """ 生成测试数据集,数据集有100个样本 测试数据x_test:为[0,1)之间的随机数; 标签数据y_test:为0或1 """ x_test = np.random.rand(100, 1, 120, 64).astype('float32') y_test = np.round(np.random.rand(100).astype('float32'))

    训练并验证

    batch_size = 32 # 训练 循环 for epoch in range(10): print("nepoch ", epoch)running_loss = 0.0for i in range(len(x_train)//batch_size-1):s = i*batch_sizee = i*batch_size batch_sizeinputs = torch.from_numpy(x_train[s:e])labels = torch.floattensor(np.array([y_train[s:e]]).t*1.0)# wrap them in variableinputs, labels = variable(inputs), variable(labels)# zero the parameter gradientsoptimizer.zero_grad()# forward backward optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss = loss.item()# 验证params = ["acc", "auc", "fmeasure"]print(params)print("training loss ", running_loss)print("train - ", evaluate(net, x_train, y_train, params))print("validation - ", evaluate(net, x_val, y_val, params))print("test - ", evaluate(net, x_test, y_test, params))

    epoch 0
    ['acc', 'auc', 'fmeasure']
    training loss 1.6107637286186218
    train - [0.52, 0.5280448717948718, 0.6470588235294118]
    validation - [0.55, 0.450328407224959, 0.693877551020408]
    test - [0.54, 0.578926282051282, 0.6617647058823529]
    epoch 1
    ['acc', 'auc', 'fmeasure']
    training loss 1.5536684393882751
    train - [0.45, 0.41145833333333337, 0.5454545454545454]
    validation - [0.55, 0.4823481116584565, 0.6564885496183207]
    test - [0.65, 0.6530448717948717, 0.7107438016528926]
    epoch 2
    ['acc', 'auc', 'fmeasure']
    training loss 1.5197088718414307
    train - [0.49, 0.5524839743589743, 0.5565217391304348]
    validation - [0.53, 0.5870279146141215, 0.5436893203883495]
    test - [0.57, 0.5428685897435898, 0.5567010309278351]
    epoch 3
    ['acc', 'auc', 'fmeasure']
    training loss 1.4534167051315308
    train - [0.53, 0.5228365384615385, 0.4597701149425287]
    validation - [0.5, 0.48152709359605916, 0.46808510638297873]
    test - [0.61, 0.6502403846153847, 0.5517241379310345]
    epoch 4
    ['acc', 'auc', 'fmeasure']
    training loss 1.3821702003479004
    train - [0.46, 0.4651442307692308, 0.3076923076923077]
    validation - [0.47, 0.5977011494252874, 0.29333333333333333]
    test - [0.52, 0.5268429487179488, 0.35135135135135137]
    epoch 5
    ['acc', 'auc', 'fmeasure']
    training loss 1.440490186214447
    train - [0.56, 0.516025641025641, 0.35294117647058826]
    validation - [0.36, 0.3801313628899836, 0.2]
    test - [0.53, 0.6113782051282052, 0.27692307692307694]
    epoch 6
    ['acc', 'auc', 'fmeasure']
    training loss 1.4722238183021545
    train - [0.47, 0.4194711538461539, 0.13114754098360656]
    validation - [0.46, 0.5648604269293925, 0.2285714285714286]
    test - [0.5, 0.5348557692307693, 0.10714285714285714]
    epoch 7
    ['acc', 'auc', 'fmeasure']
    training loss 1.3460421562194824
    train - [0.51, 0.44871794871794873, 0.1694915254237288]
    validation - [0.44, 0.4490968801313629, 0.2]
    test - [0.53, 0.4803685897435898, 0.14545454545454545]
    epoch 8
    ['acc', 'auc', 'fmeasure']
    training loss 1.3336675763130188
    train - [0.54, 0.4130608974358974, 0.20689655172413793]
    validation - [0.39, 0.40394088669950734, 0.14084507042253522]
    test - [0.51, 0.5400641025641025, 0.19672131147540983]
    epoch 9
    ['acc', 'auc', 'fmeasure']
    training loss 1.438510239124298
    train - [0.53, 0.5392628205128205, 0.22950819672131148]
    validation - [0.42, 0.4848111658456486, 0.09375]
    test - [0.56, 0.5420673076923076, 0.2413793103448276]

    参考

    应用深度学习eegnet来处理脑电信号

    总结

    以上是ag凯发k8国际为你收集整理的脑电信号特征提取算法c语言_应用深度学习eegnet来处理脑电信号的全部内容,希望文章能够帮你解决所遇到的问题。

    如果觉得ag凯发k8国际网站内容还不错,欢迎将ag凯发k8国际推荐给好友。

    • 上一篇:
    • 下一篇:
    网站地图