Wie Sie mit diesem C-code zum multiplizieren von zwei Matrizen mit Strassen ' s Algorithmus?

Ich war auf der Suche für eine Implementierung von Strassen ' s Algorithmus in C, und ich habe festgestellt, das der code am Ende.

Verwenden multiply Funktion:

void multiply(int n, matrix a, matrix b, matrix c, matrix d);

welche multipliziert zwei Matrizen a, b und legt das Ergebnis in c (d ist ein Vermittler-matrix). Matrizen a und b sollte die folgenden Typ:

typedef union _matrix 
{
    double **d;
    union _matrix **p;
} *matrix;

Habe ich dynamisch zugewiesen werden vier Matrizen a, b, c, d (zwei-dimensionale arrays von doubles) und zugeordnet haben Ihre Adressen aus dem Feld _matrix.d:

#include "strassen.h"

#define SIZE 50 

int main(int argc, char *argv[])
{
    double ** matA, ** matB, ** matC, ** matD;
    union _matrix ma, mb, mc, md; 
    int i = 0, j = 0, n;

    matA = (double **) malloc(sizeof(double *) * SIZE);
    for (i = 0; i < SIZE; i++)
        matA[i] = (double *) malloc(sizeof(double) * SIZE); 
    //Do the same for matB, matC, matD.

    ma.d = matA;
    mb.d = matB;
    mc.d = matC;
    md.d = matD;

    //Initialize matC and matD to 0.

    //Read n.

    //Read matA and matB.

    multiply(n, &ma, &mb, &mc, &md);
    return 0;
}

Dieser code erfolgreich kompiliert, aber stürzt mit n > BREAK.

strassen.c :

#include "strassen.h"

/* c = a * b */
void multiply(int n, matrix a, matrix b, matrix c, matrix d)
{
    if (n <= BREAK) {
      double sum, **p = a->d, **q = b->d, **r = c->d;
      int i, j, k;

      for (i = 0; i < n; i++)
         for (j = 0; j < n; j++) {
            for (sum = 0., k = 0; k < n; k++)
               sum += p[i][k] * q[k][j];
            r[i][j] = sum;
         }
    } else {
        n /= 2;
        sub(n, a12, a22, d11);
        add(n, b21, b22, d12);
        multiply(n, d11, d12, c11, d21);
        sub(n, a21, a11, d11);
        add(n, b11, b12, d12);
        multiply(n, d11, d12, c22, d21);
        add(n, a11, a12, d11);
        multiply(n, d11, b22, c12, d12);
        sub(n, c11, c12, c11);
        sub(n, b21, b11, d11);
        multiply(n, a22, d11, c21, d12);
        add(n, c21, c11, c11);
        sub(n, b12, b22, d11);
        multiply(n, a11, d11, d12, d21);
        add(n, d12, c12, c12);
        add(n, d12, c22, c22);
        add(n, a21, a22, d11);
        multiply(n, d11, b11, d12, d21);
        add(n, d12, c21, c21);
        sub(n, c22, d12, c22);
        add(n, a11, a22, d11);
        add(n, b11, b22, d12);
        multiply(n, d11, d12, d21, d22);
        add(n, d21, c11, c11);
        add(n, d21, c22, c22);
    }
}

/* c = a + b */
void add(int n, matrix a, matrix b, matrix c)
{
    if (n <= BREAK) {
        double **p = a->d, **q = b->d, **r = c->d;
        int i, j;

        for (i = 0; i < n; i++)
           for (j = 0; j < n; j++)
              r[i][j] = p[i][j] + q[i][j];
    } else {
        n /= 2;
        add(n, a11, b11, c11);
        add(n, a12, b12, c12);
        add(n, a21, b21, c21);
        add(n, a22, b22, c22);
    }
}

/* c = a - b */
void sub(int n, matrix a, matrix b, matrix c)
{
    if (n <= BREAK) {
        double **p = a->d, **q = b->d, **r = c->d;
        int i, j;

        for (i = 0; i < n; i++)
           for (j = 0; j < n; j++)
              r[i][j] = p[i][j] - q[i][j];
    } else {
        n /= 2;
        sub(n, a11, b11, c11);
        sub(n, a12, b12, c12);
        sub(n, a21, b21, c21);
        sub(n, a22, b22, c22);
    }
}

strassen.h:

#define BREAK 8   

typedef union _matrix {
    double **d;
    union _matrix **p;
} *matrix;

/* Notational shorthand to access submatrices for matrices named a, b, c, d */

#define a11 a->p[0]
#define a12 a->p[1]
#define a21 a->p[2]
#define a22 a->p[3]
#define b11 b->p[0]
#define b12 b->p[1]
#define b21 b->p[2]
#define b22 b->p[3]
#define c11 c->p[0]
#define c12 c->p[1]
#define c21 c->p[2]
#define c22 c->p[3]
#define d11 d->p[0]
#define d12 d->p[1]
#define d21 d->p[2]
#define d22 d->p[3]

Meine Frage ist, wie man die Funktion multiply (wie die Umsetzung der matrix).

strassen.h

strassen.c

  • Nicht gegossen, die den Rückgabewert von malloc() in C.
  • Anstelle von dumping solche großen code-Stück, Ecke bitte das problem und klar zu erklären, was es ist! Und auch sagen, was Sie versucht haben und was Sie Verdacht? Die aktuelle version der Frage könnte machen Menschen juckende
  • n ist uninitialiazed in Ihrem main
  • Überprüfen Sie dieses großartige Dokument auf die Umsetzung der strassen-Algorithmus software.intel.com/file/24473implementation und auch den code gepostet von @Tudor : software.intel.com/file/24473
InformationsquelleAutor obo | 2012-03-02
Schreibe einen Kommentar