#-*- coding:utf-8 -*-
import Image as im
import numpy 
from sys import argv
from pylab import *
import string
import cv
import os
from random import *
from math import *

infinity = 1000000000


#paths for cover and stego
path_stego = '/home/couturie/ajeter/stego/'
path_cover = '/home/couturie/ajeter/cover/'


def dec(ch,n):    
    l = len(ch)
    acc = 0
    for i in xrange(l):
        if ch[i]==1:
            acc = acc + 2**(n-i-1)        
    return acc


def bin(elem,n):
    q = -1
    res = [0 for i in xrange(n)]
    i = 1
    while q != 0:
        q = elem // 2
        r = elem % 2
        res[n-i] =  r
        elem = q
        i+=1
    return res



def xorb(a,b):
    return 1 if a != b else 0

def xor(e1,e2,h):
    e1b,e2b  = bin(e1,h),bin(e2,h)
    d = dec([xorb(e1b[j],e2b[j]) for j in xrange(h)],h)
    return d

def lit(d,(indx,indy)):
    if (indx,indy) in d :
        return d[(indx,indy)]
    else :
        return 0



# for STC algorithm
def forward(H_hat,x,message,lnm,rho):
    (h,w) = int(log(max(H_hat),2))+1, len(H_hat)
    path = dict() 
    nbblock = lnm
    wght = [infinity for _ in xrange(int(2**h))] 
    wght[0]=0    
    newwght = [0 for _ in xrange(int(2**h))]
#    rho = 1
#    rho= [1 for _ in xrange(len(x))]
    indx,indm = 0,0
    i=0
    while i < nbblock: # for each bit in the message 
        for j in xrange(w):   # for each column in H_hat
            k = 0
            while k < int(2**h): # for each line  of H 
                w0 = wght[k] + x[indx]*rho[indx]
                w1 = wght[xor(k,H_hat[j],h)] + (1-x[indx])*rho[indx]
                if w1 < w0 :
                    path[(indx,k)] = 1 
                else : 
                    if (indx,k) in path:
                        del path[(indx,k)]
                newwght[k] = min(w0,w1)
                k +=1 
            indx +=1
            wght = [t for t in newwght]


        for j in xrange(int(2**(h-1))):   # for each column in H
            wght[j] = wght[2*j + message[indm]]
        wght = wght[:int(pow(2,h-1))] + [infinity for _ in xrange(int(pow(2,h)-pow(2,h-1)))]
        indm +=1
        i +=1
    start = np.argmin(wght)
    return (start,path)

# for STC algorithm
def backward(start,H_hat,x,message,lnm,path):
    (h,w) = int(log(max(H_hat),2))+1, len(H_hat)
    indx,indm = len(x)-1,lnm-1
    state = 2*start + message[indm]
    indm -=1
    nbblock = lnm
    y=np.zeros(len(x))
    i=0
    while i < nbblock:
        l = range(w)
        l.reverse()
        for j in l:   # for each column of H_hat
            y[indx] = lit(path,(indx,state))
            state = xor(state,y[indx]*H_hat[j],h)
            indx -=1
        state = 2*state + message[indm]
        indm -=1 
        i +=1
    return [int(t) for t in y]

    
 

# for STC algorithm
def trouve_H_hat(n,m,h):
    assert h ==7 
    alpha = float(n)/m
    assert alpha >= 1 
    index = min(int(alpha),9)
    mat = {
        1 : [127],
        2 : [71,109],
        3 : [95, 101, 121],
        4 : [81, 95, 107, 121],
        5 : [75, 95, 97, 105, 117],
        6 : [73, 83, 95, 103, 109, 123],
        7 : [69, 77, 93, 107, 111, 115, 121],
        8 : [69, 79, 81, 89, 93, 99, 107, 119],
        9 : [69, 79, 81, 89, 93, 99, 107, 119, 125]
        }[index]
    return(mat, index*m)

# STC Algorithm
def stc(x,rho,message):
    lnm = len(message)
    (mat,taille_suff) = trouve_H_hat(len(x),len(message),7)
    x_b = x[:taille_suff]
    (start,path) = forward(mat,x_b,message,lnm,rho)
    return (x_b,backward(start,mat,x_b,message,lnm,path),mat)





def nbdif(x,y):
    r,it = 0,0
    l = len(y)
    while it < l :
        if x[it] != y[it] :
            r +=1
        it += 1
    return float(r)/l 
        




# for decoding purpose
def prod(H_hat,lnm,y):
    (h,w) = int(log(max(H_hat),2))+1, len(H_hat)
    i=0
    H =[]
    V=[0 for _ in range(len(y))]
    sol=[]
    while i < lnm: # pour chaque ligne 
        V=[0 for _ in range(len(y))]    
        k = max([(i-h+1)*w,0])
        dec = max([i-h+1,0])
        for j in xrange(min([i+1,h])): #nbre de blocks presents sur la ligne i
            for l in xrange(w): # pour chaque collone de H_hat
                V[k] = bin(H_hat[l],h)[h-i-1+j+dec]
                k+=1
                
        sol.append(np.dot(np.array(V),np.array(y)))
        i+=1
        #H += [V]
    #H = np.array(H)    
    #y = np.array(y)
    #print "dot",np.dot(H,y),H.shape
    #print "sol",sol
    return sol#list(np.dot(H,y))
    
def equiv(x,y): 
    lx = len(x)
    assert lx == len(y)
    i=0
    while i < lx :
        if x[i] % 2 != y[i]%2 : 
            return False
        i += 1
    return True
        





def conversion(nombre, base, epsilon = 0.00001 ):
    ''' Soit nombre écrit en base 10, le retourne en base base'''
    if not 2 <= base <= 36:
        raise ValueError, "La base doit être entre 2 et 36"
    if not base == 2 and '.' in str(nombre):
        raise ValueError, "La partie décimale n'est pas gérée pour les bases\
                           différentes de 2."
    # IMPROVE : Convertir aussi la partie décimale, quand la base n'est pas égale
    # à 2.
    abc = string.digits + string.letters
    result = ''
    if nombre < 0:
        nombre = -nombre
        sign = '-'
    else:
        sign = ''
    if '.' in str(nombre):
        entier,decimal = int(str(nombre).split('.')[0]),\
                         float('.'+str(nombre).split('.')[1])
    else:
        entier,decimal = int(str(nombre)),0
    while entier !=0 :
        entier, rdigit = divmod( entier, base )
        result = abc[rdigit] + result
    flotante, decimalBin = 1./float(base),''
    while flotante > epsilon :
        if decimal >= flotante:
            decimalBin+='1'
            decimal-=flotante
        else :
            decimalBin+='0'    
        flotante = flotante/float(base)
    if '1' in decimalBin :
        reste = '.'+decimalBin
        while reste[-1]=='0':
            reste = reste[:-1]
    else :
        reste = ''
    return sign + result + reste

    
def getBit(X,pos):
    ''' retrieve the bit indexed by pos in X
    For instance getBit(8,1) = 0
    '''
    assert pos != 0
    entier = conversion(X,2)
    if '.' in entier:
        entier, decimal = entier.split('.')  
        if decimal == '0':
            decimal = ''  
    else:
        decimal = ''
    if '-' in entier:
        entier = entier.replace('-','')
    entier  = entier.zfill(abs(pos))
    decimal = (decimal+'0'*abs(pos))[:max(len(decimal),abs(pos))]

    return int(entier[len(entier)-pos]) if pos >0 else int(decimal[-pos-1])


def setBit(X,pos,y):
    '''set the bit pos of  X with value y.
    '''
    assert pos != 0
    entier = conversion(X,2)
    if '.' in entier:
        entier, decimal = entier.split('.')    
    else:
        decimal = ''
    entier  = list(entier.zfill(abs(pos)))
    decimal = list((decimal+'0'*abs(pos))[:max(len(decimal),abs(pos))])
    if pos>0:
        entier[len(entier)-pos]=str(int(y))
    else:
        decimal[-pos-1] = str(int(y))
    if decimal == []:
        return int(''.join(entier),2)
    else:
        S=0
        for k in range(len(decimal)):
            S += 1./2**(k+1)*int(decimal[k])
        return float(str(int(''.join(entier),2))+'.'+str(S).split('.')[1])


def a2b(a): 
    ai = ord(a) 
    return ''.join('01'[(ai >> x) & 1] for x in xrange(7, -1, -1)) 



def a2b_list(L):
    LL=[]
    for i in L:
        for j in list(a2b(i)):
            LL.append(j)
    return LL



def toDecimal(x):
    return sum(map(lambda z: int(x[z]) and 2**(len(x) - z - 1),
                   range(len(x)-1, -1, -1)))            

def conv_list_bit(L):
    L2=[]
    for j in range(len(L)/8):
        L2.append(chr(toDecimal("".join(L[j*8:(j+1)*8]))))
    return ''.join(L2)

def Denary2Binary(n):
    '''convert denary integer n to binary string bStr'''
    bStr = ''
    if n < 0:  raise ValueError, "must be a positive integer"
    if n == 0: return '0'
    while n > 0:
        bStr = str(n % 2) + bStr
        n = n >> 1
    return bStr


def compute_filter_sobel(level,image):
    level2=level.copy()
    level2= array(level2.getdata()).flatten()
    l=0
    for x in level2:
        level2[l]=(x/2)*2
        l+=1
    level2_im=im.new('L',image.size)
    level2_im.putdata(level2)

    cv_im = cv.CreateImageHeader(image.size, cv.IPL_DEPTH_8U, 1)
    cv.SetData(cv_im, level2_im.tostring())
    dst16 = cv.CreateImage(cv.GetSize(cv_im), cv.IPL_DEPTH_16S, 1)

    laplace = cv.Sobel(cv_im, dst16,1, 1,7)
    
    dst8 = cv.CreateImage (cv.GetSize(cv_im), cv.IPL_DEPTH_8U, 1)
    cv.ConvertScale(dst16,dst8)
    processed=im.fromstring("L", cv.GetSize(dst8), dst8.tostring())
#    cv.ShowImage ('canny', dst8)
#    cv.WaitKey()

    return processed



def compute_list_bit_to_change(threshold,processed):
    List=[]
    nb=0
    l=0
    for i in processed:
        if (processed[l]>=threshold):
            #if nb%2==0:
                List.append(l)
            #nb+=1
        l+=1
    return List


def compute_filter_canny(level,image):
    level2=level.copy()
    level2= array(level2.getdata()).flatten()
    l=0
    for x in level2:
        level2[l]=(x/2)*2
        l+=1
    level2_im=im.new('L',image.size)
    level2_im.putdata(level2)
    level2_im=im.merge("RGB",(level2_im,level2_im,level2_im))

    mean=numpy.mean(level2)
    std=numpy.std(level2)

    cv_im = cv.CreateImageHeader(image.size, cv.IPL_DEPTH_8U, 3)
    cv.SetData(cv_im, level2_im.tostring())

    yuv = cv.CreateImage(cv.GetSize(cv_im), 8, 3)
    gray = cv.CreateImage(cv.GetSize(cv_im), 8, 1)
    cv.CvtColor(cv_im, yuv, cv.CV_BGR2YCrCb)
    cv.Split(yuv, gray, None, None, None)

    canny = cv.CreateImage(cv.GetSize(cv_im), 8, 1)

    List_bit_to_change=set([])
    Weight=[]


    
    cv.Canny(gray, canny, mean-1*std, mean+1*std,3)  #avant 10 255
    processed=im.fromstring("L", cv.GetSize(canny), canny.tostring())
    processed= array(processed.getdata()).flatten()
    List3=set(compute_list_bit_to_change(100,processed))

    cv.Canny(gray, canny, mean-1*std, mean+1*std,5)  #avant 10 255
    processed=im.fromstring("L", cv.GetSize(canny), canny.tostring())
    processed= array(processed.getdata()).flatten()
    List5=set(compute_list_bit_to_change(100,processed))

    cv.Canny(gray, canny, mean-1*std, mean+1*std,7)  #avant 10 255
    processed=im.fromstring("L", cv.GetSize(canny), canny.tostring())
    processed= array(processed.getdata()).flatten()
    List7=set(compute_list_bit_to_change(100,processed))

    nb_bit_embedded=(512*512/10)+40
    AvailablePixel3=List3
    AvailablePixel5=AvailablePixel3.union(List5)
    AvailablePixel7=AvailablePixel5.union(List7)
    if len(AvailablePixel3)>nb_bit_embedded:
        step=1
        WorkingPixel=AvailablePixel3
    elif len(AvailablePixel5)>nb_bit_embedded:
        step=2
        WorkingPixel=AvailablePixel5
    elif len(AvailablePixel7)>nb_bit_embedded:
        step=3
        WorkingPixel=AvailablePixel7
    else:
        step=4
        WorkingPixel=range(len(level2))

    print "avail P3",len(AvailablePixel3)
    print "avail P5",len(AvailablePixel5)
    print "avail P7",len(AvailablePixel7)

    print "size WorkingPixel",len(WorkingPixel)
    Weight=[0 for _ in WorkingPixel]

    l=0
    for i in WorkingPixel:
        if step>=1 and i in List3:
            Weight[l]=1
        if step>=2 and i in List5 and Weight[l]==0:
            Weight[l]=10
        if step>=3 and i in List7 and Weight[l]==0:
            Weight[l]=100
        if step>=4 and Weight[l]==0:
            Weight[l]=1000
        l+=1

            
        

    List_bit_to_change=WorkingPixel
        
        



    return [List_bit_to_change,Weight]











def mystego(filein, fileout):
    dd = im.open(filein)
    dd = dd.convert('RGB') 
    red, green, blue = dd.split()
    level=red.copy()

    [List_bit_to_change,Weight]=compute_filter_canny(level,dd)
    level= array(level.getdata()).flatten()


    bit_to_read=1  






#parameters for BBS
    M=18532395500947174450709383384936679868383424444311405679463280782405796233163977*39688644836832882526173831577536117815818454437810437210221644553381995813014959
    X=18532395500947174450709383384936679868383424444311






    l=0
    message="Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete   Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete  Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete   Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete  Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete   Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete  Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete   Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete   Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete  Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete   Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete Ce que j'écris est très original... Bref, je suis un poete   Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete Ce que j'écris est très original... Bref, je suis un poete   Salut christophe, arrives tu à lire ce message? normalement tu dois lire cela. Bon voici un test avec un message un peu plus long. Bon j'en rajoute pour voir. Ce que j'écris est très original... Bref, je suis un poete voila c'est la fin blablabla:-)"
    message=message[0:len(message)/1]
    leng_msg=len(message)
    message=message+((leng_msg+7)/8*8-leng_msg)*" "
    leng_msg=len(message)

    leng='%08d'%len(message)




    len_leng=len(leng)
    leng_error=int(len_leng)
    leng_cor=leng
    List_pack=a2b_list(leng_cor)

    List_random=[]
    while len(List_random)<len(List_bit_to_change):
        X=(X*X)%M
        List_random.extend(Denary2Binary(X))

    size=0
    for i in range(leng_msg/8):
        m=message[i*8:(i+1)*8]
        m_cor=m
        m_bin=a2b_list(m_cor)
        size=size+len(m_bin)
        List_pack.extend(m_bin) 






   

    Support=[getBit(level[l],bit_to_read) for l in List_bit_to_change]


    Message=[(int(List_pack[l])^int(List_random[l])) for l in xrange(len(List_pack))]

    print "support",len(List_bit_to_change)
    print "message",len(Message)
    print "weight",len(Weight)

    (x_b,Stc_message,H_hat) = stc(Support,Weight,Message)

    print "modification in %",nbdif(x_b,Stc_message)
    print "size of the STC message",len(Stc_message)



    l=0
    size_mesg=0
    val_mod=0

    

    for l in List_bit_to_change:
        if(size_mesg<len(Stc_message)):
            b=getBit(level[l],bit_to_read)
            if b!=Stc_message[size_mesg]:
                val_mod+=1
            level[l]=float64(setBit(level[l],bit_to_read,Stc_message[size_mesg]))
            size_mesg+=1

    print 'size mesg',size_mesg
    print 'val mod',val_mod
    print 'len message',len(Message),len(List_pack)



    zz3=im.new('L',dd.size)
    zz3.putdata(level)

    zz3.save(fileout)




listing = os.listdir(path_cover)

print listing

list_nb_bit=[]
l=0



for infile in listing:
    print "current file is: " + infile, path_stego+infile
    mystego(path_cover+infile,path_stego+infile)
    l+=1

